jaxdem.utils.environment
Utility functions to handle environments.
Functions
| env_step | Advance the environment n steps using actions from model. |
| env_trajectory_rollout | Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk. |
| lidar | |
- jaxdem.utils.environment.env_trajectory_rollout(env: Environment, model: Callable, key: jax.Array, *, n: int, stride: int = 1, **kw: Any) → Tuple[Environment, Environment]
Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.
- Parameters:
env (Environment) – Initial environment pytree.
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX PRNG key supplied to model.
n (int) – Number of chunks to roll out. Total internal steps = n * stride.
stride (int) – Steps per chunk between recorded snapshots.
**kw (Any) – Extra keyword arguments passed to model on every step.
- Returns:
Environment – Environment after n * stride steps.
Environment – Stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.
Examples
>>> env, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
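The chunking semantics described above (n snapshots, each taken after stride steps, for n * stride steps in total) can be illustrated with a small standalone sketch. This is not jaxdem's implementation: toy_rollout, toy_model, toy_env_step and the dict-based environment are hypothetical stand-ins, and splitting the key once per step is an assumption made only for the illustration.

```python
import jax
import jax.numpy as jnp


def toy_env_step(env, action):
    # Hypothetical environment pytree: a dict with a position and a step counter.
    return {"pos": env["pos"] + action, "t": env["t"] + 1}


def toy_model(obs, key, **kw):
    # Follows the documented shape model(obs, key, **kw) -> action.
    noise = 0.01 * jax.random.normal(key, obs.shape)
    return 0.1 * (kw["objective"] - obs) + noise


def toy_rollout(env, model, key, *, n, stride=1, **kw):
    def one_step(carry, _):
        env, key = carry
        key, sub = jax.random.split(key)  # assumption: fresh subkey per step
        action = model(env["pos"], sub, **kw)
        return (toy_env_step(env, action), key), None

    def one_chunk(carry, _):
        # Advance `stride` steps, then record the environment as a snapshot.
        carry, _ = jax.lax.scan(one_step, carry, None, length=stride)
        return carry, carry[0]

    (env, _), traj = jax.lax.scan(one_chunk, (env, key), None, length=n)
    return env, traj


env0 = {"pos": jnp.zeros(2), "t": jnp.array(0)}
goal = jnp.array([1.0, -1.0])
final_env, traj = toy_rollout(
    env0, toy_model, jax.random.PRNGKey(0), n=100, stride=5, objective=goal
)
print(traj["pos"].shape, int(final_env["t"]))
```

In this sketch every leaf of traj gains a leading axis of length n, and the last snapshot coincides with the returned final environment.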
- jaxdem.utils.environment.env_step(env: Environment, model: Callable, key: jax.Array, *, n: int = 1, **kw: Any) → Environment
Advance the environment n steps using actions from model.
- Parameters:
env (Environment) – Initial environment pytree (batchable).
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX PRNG key supplied to model.
n (int) – Number of steps to perform.
**kw (Any) – Extra keyword arguments forwarded to model.
- Returns:
Environment after n steps.
- Return type:
Environment
Examples
>>> env = env_step(env, model, key, n=10, objective=goal)
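To make the model(obs, key, **kw) -> action contract and the forwarding of extra keyword arguments concrete, here is a minimal standalone sketch. It does not reproduce jaxdem's internals: toy_env_step, toy_model, the dict-based environment and the sigma keyword are hypothetical, and the per-step key split is an assumption.

```python
import jax
import jax.numpy as jnp


def toy_model(obs, key, *, objective, sigma=0.0):
    # Hypothetical policy: move halfway toward `objective`, plus optional noise.
    return 0.5 * (objective - obs) + sigma * jax.random.normal(key, obs.shape)


def toy_env_step(env, model, key, *, n=1, **kw):
    def body(_, carry):
        env, key = carry
        key, sub = jax.random.split(key)  # assumption: fresh subkey per step
        action = model(env["pos"], sub, **kw)  # **kw is forwarded unchanged
        return {"pos": env["pos"] + action, "t": env["t"] + 1}, key

    env, _ = jax.lax.fori_loop(0, n, body, (env, key))
    return env


env0 = {"pos": jnp.zeros(2), "t": jnp.array(0)}
goal = jnp.array([1.0, -1.0])
env10 = toy_env_step(env0, toy_model, jax.random.PRNGKey(0), n=10, objective=goal)
print(int(env10["t"]), env10["pos"])
```

With sigma left at 0.0 the update is deterministic, so after 10 steps the remaining distance to the goal has halved ten times.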
- jaxdem.utils.environment.lidar(env: Environment) → jax.Array