jaxdem.utils.environment

Utility functions to handle environments.

Functions

env_step(env, model, key, *[, n])

Advance the environment n steps using actions from model.

env_trajectory_rollout(env, model, key, *, n)

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

lidar(env)

jaxdem.utils.environment.env_trajectory_rollout(env: Environment, model: Callable, key: jax.Array, *, n: int, stride: int = 1, **kw: Any) → Tuple[Environment, Environment]

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

Parameters:
  • env (Environment) – Initial environment pytree.

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key.

  • n (int) – Number of chunks to roll out; the environment advances n * stride steps in total.

  • stride (int) – Steps per chunk between recorded snapshots.

  • **kw (Any) – Extra keyword arguments passed to model on every step.

Returns:

  • Environment – Environment after n * stride steps.

  • Environment – Pytree of environments stacked along a leading axis of length n; snapshot i (0-indexed) is taken after (i + 1) * stride steps.

Examples

>>> env, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)

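The chunking semantics described above can be illustrated with a minimal sketch. It assumes a hypothetical `ToyEnv` pytree and `toy_step` update (neither is part of jaxdem) and builds the rollout from two nested `jax.lax.scan` calls: the inner scan advances `stride` steps, and the outer scan repeats that `n` times while stacking the snapshot taken after each chunk. This is one plausible way to produce the documented return values, not necessarily how jaxdem implements it; splitting the key once per step is also an assumption.

```python
from typing import Any, Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class ToyEnv(NamedTuple):
    # Hypothetical stand-in for jaxdem's Environment; NamedTuples are JAX pytrees.
    pos: jax.Array


def toy_step(env: ToyEnv, action: jax.Array) -> ToyEnv:
    # Hypothetical dynamics: the action is added to the position.
    return ToyEnv(pos=env.pos + action)


def rollout_sketch(
    env: ToyEnv, model: Callable, key: jax.Array, *, n: int, stride: int = 1, **kw: Any
) -> Tuple[ToyEnv, ToyEnv]:
    def one_step(carry, _):
        env, key = carry
        key, subkey = jax.random.split(key)  # assumed: fresh key per step
        action = model(env.pos, subkey, **kw)  # obs is just the position here
        return (toy_step(env, action), key), None

    def one_chunk(carry, _):
        # Advance `stride` steps, then emit the environment as this chunk's snapshot.
        carry, _ = jax.lax.scan(one_step, carry, None, length=stride)
        return carry, carry[0]

    # The outer scan stacks the n snapshots along a new leading axis.
    (env, _), traj = jax.lax.scan(one_chunk, (env, key), None, length=n)
    return env, traj


def model(obs: jax.Array, key: jax.Array, *, speed: float) -> jax.Array:
    # Hypothetical policy: constant velocity `speed` in every dimension; key unused.
    return speed * jnp.ones_like(obs)


env, traj = rollout_sketch(
    ToyEnv(pos=jnp.zeros(2)), model, jax.random.PRNGKey(0), n=4, stride=5, speed=1.0
)
print(traj.pos[:, 0])  # positions after 5, 10, 15 and 20 steps
```

Because the outer scan stacks one snapshot per chunk, `traj.pos` has a leading axis of length `n`, and the final environment coincides with the last snapshot.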
jaxdem.utils.environment.env_step(env: Environment, model: Callable, key: jax.Array, *, n: int = 1, **kw: Any) → Environment

Advance the environment n steps using actions from model.

Parameters:
  • env (Environment) – Initial environment pytree (batchable).

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key.

  • n (int) – Number of steps to perform.

  • **kw (Any) – Extra keyword arguments forwarded to model.

Returns:

Environment after n steps.

Return type:

Environment

Examples

>>> env = env_step(env, model, key, n=10, objective=goal)

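A minimal sketch of what advancing `n` steps amounts to, again using a hypothetical `ToyEnv` pytree and `toy_step` update that are not part of jaxdem. It loops with `jax.lax.fori_loop`, draws a fresh subkey on each step (an assumption about key handling), and forwards extra keyword arguments such as `objective` to the model, mirroring the example call. The actual function may be structured differently.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyEnv(NamedTuple):
    # Hypothetical stand-in for jaxdem's Environment pytree.
    pos: jax.Array


def toy_step(env: ToyEnv, action: jax.Array) -> ToyEnv:
    # Hypothetical dynamics: the action is added to the position.
    return ToyEnv(pos=env.pos + action)


def step_sketch(env: ToyEnv, model: Callable, key: jax.Array, *, n: int = 1, **kw: Any) -> ToyEnv:
    def body(i, carry):
        env, key = carry
        key, subkey = jax.random.split(key)  # assumed: fresh key per step
        action = model(env.pos, subkey, **kw)  # extra keyword arguments reach the model
        return toy_step(env, action), key

    env, _ = jax.lax.fori_loop(0, n, body, (env, key))
    return env


def model(obs: jax.Array, key: jax.Array, *, objective: jax.Array) -> jax.Array:
    # Hypothetical policy: move halfway toward `objective`; key unused.
    return 0.5 * (objective - obs)


goal = jnp.array([8.0, 8.0])
env = step_sketch(ToyEnv(pos=jnp.zeros(2)), model, jax.random.PRNGKey(0), n=3, objective=goal)
# Positions visited: 0 -> 4 -> 6 -> 7
```

With a policy that halves the remaining distance to the goal, three steps from 0 toward 8 visit 4, 6 and 7.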
jaxdem.utils.environment.lidar(env: Environment) → jax.Array