Driving Environments with a Custom Policy#

In this example, we create an environment instance and show how to drive it efficiently using a custom policy. This approach removes the need to create a trainer object, making evaluation much more efficient.

Imports#

import jax
import jax.numpy as jnp
from flax import nnx

import jaxdem as jdem
import jaxdem.rl as rl
from jaxdem import utils

from pathlib import Path

Variables#

First, we define all the variables needed for the example.

frames_dir = Path("/tmp/frames")
key = jax.random.key(1)
N = 24
save_every = 40
T = 4000
batches = T // save_every
num_envs = 40

The Policy#

Next, we define a callable that takes the observations and some keyword arguments, and returns the corresponding actions. For more information, see env_step().

In this example, we drive the environment with a model from JaxDEM using nnx. However, model can be any JIT-compatible function.

def model(obs, key, graphdef, graphstate):
    base_model = nnx.merge(graphdef, graphstate)
    pi, value = base_model(obs, sequence=False)
    action = pi.sample(seed=key)
    return action

Model and Environment#

Now we create a model and an environment to use in the example. We will not perform any training here, since the goal is to show how to drive the environment directly.

A trained model could be loaded in the same way using CheckpointModelLoader.

env = rl.Environment.create("MultiNavigator", N=N)

key, subkey = jax.random.split(key)
base_model = rl.Model.create(
    "SharedActorCritic",
    key=nnx.Rngs(subkey),
    observation_space_size=env.observation_space_size,
    action_space_size=env.action_space_size,
)
base_model.eval()
graphdef, graphstate = nnx.split(base_model)

Environment Vectorization#

JaxDEM supports vectorized environments, allowing multiple simulations to run in parallel for significant speedups. This is usefull for gathering statistics about the environmentt.

subkeys = jax.random.split(key, num_envs)
env = jax.vmap(lambda _: env)(jnp.arange(num_envs))
env = rl.vectorise_env(env)
env = env.reset(env, subkeys)

Driving the Environment#

There are two main ways to drive an environment. The first is by stepping it manually for a fixed number of steps:

key, subkey = jax.random.split(key)
env = utils.env_step(
    env,
    model,
    subkey,
    graphdef=graphdef,
    graphstate=graphstate,
    n=save_every,
)

The second approach is to roll out a trajectory, collecting data every stride steps:

key, subkey = jax.random.split(key)
env, env_traj = utils.env_trajectory_rollout(
    env,
    model,
    subkey,
    graphdef=graphdef,
    graphstate=graphstate,
    n=batches,
    stride=save_every,
)

Saving Data#

Finally, we can use JaxDEM’s VTKWriter to save the full rollout to disk in a single call:

writer = jdem.VTKWriter(directory=frames_dir)
writer.save(env_traj.state, env_traj.system, trajectory=True)

Total running time of the script: (0 minutes 22.553 seconds)