Intro to JaxDEM Reinforcement Learning#
In this example, we’ll train a simple agent using JaxDEM’s reinforcement learning tools.
The agent is a humble sphere that moves inside a box with reflective boundaries; the objective is
to reach a target location. We train it with Proximal Policy Optimization (PPO)
(jaxdem.rl.trainer.PPOTrainer) and a shared-parameters actor–critic MLP
(jaxdem.rl.model.SharedActorCritic).
Imports#
import distrax
import jax
import jax.numpy as jnp
import jaxdem as jdem
import jaxdem.rl as rl
from dataclasses import replace
from flax import nnx
import optax
Save = False  # set to True to write VTK files for visualization
Environment#
First, we create a single-agent navigation environment with reflective boundaries; it uses
sensible internal defaults for the domain and time step. See jaxdem.rl.environment.SingleNavigator
for details.
env = rl.Environment.create("singleNavigator")
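The environment exposes the observation and action sizes the policy will work with; the model
below is built from these same attributes. A quick, purely illustrative check:
print("observation size:", env.observation_space_size, "action size:", env.action_space_size)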
Model#
Next, we build a shared-parameters actor–critic MLP. We can use a bijector to constrain the action space.
key = jax.random.key(1)
key, subkey = jax.random.split(key)
model = rl.Model.create(
    "SharedActorCritic",
    key=nnx.Rngs(subkey),
    observation_space_size=env.observation_space_size,
    action_space_size=env.action_space_size,
    architecture=[32, 32],  # two hidden layers of 32 units
    action_space=rl.ActionSpace.create("maxNorm", max_norm=0.4),  # bound the action norm to 0.4
)
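The "maxNorm" action space keeps the policy's output inside a ball of radius max_norm. As an
illustration of the idea only (the actual bijector used by JaxDEM may differ), an unbounded
action can be squashed smoothly into that ball like this:
def squash_to_max_norm(raw_action, max_norm=0.4):
    # Smoothly map an unbounded action into the ball of radius max_norm
    # by squashing its Euclidean norm with tanh.
    norm = jnp.linalg.norm(raw_action)
    safe_norm = jnp.maximum(norm, 1e-12)
    return raw_action / safe_norm * max_norm * jnp.tanh(safe_norm)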
Trainer (PPO)#
Then, we construct the PPO trainer (jaxdem.rl.trainer.PPOTrainer); feel free to tweak the
learning rate, num_epochs, and so on. The parameters below are chosen so training runs very
fast, not for policy quality. Because the bijector already bounds the actions, we don't need
to clip them; if we wanted to, we could pass that option to the trainer.
key, subkey = jax.random.split(key)
tr = rl.Trainer.create(
    "PPO",
    env=env,
    model=model,
    key=subkey,
    num_epochs=100,  # e.g. 300 for a longer, higher-quality run
    num_envs=512,
    num_steps_epoch=32,
    num_minibatches=4,
    learning_rate=5e-2,  # e.g. 1e-1
    optimizer=optax.adam,  # optax.contrib.muon is currently broken; use adam so the build passes (muon performs much better)
)
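Because the optimizer factory and the learning rate are passed separately, trying a different
optax optimizer is a one-line change. Here is a hypothetical variant (not part of the original
example, and assuming Trainer.create builds the optimizer from these two arguments) using AdamW:
# Hypothetical variant: identical PPO settings, but with optax.adamw instead of adam.
key, subkey = jax.random.split(key)
tr_adamw = rl.Trainer.create(
    "PPO",
    env=env,
    model=model,
    key=subkey,
    num_epochs=100,
    num_envs=512,
    num_steps_epoch=32,
    num_minibatches=4,
    learning_rate=5e-2,
    optimizer=optax.adamw,
)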
Training#
Train the policy; tr.train returns the updated trainer with the learned parameters. This method
is just a convenience training loop. If desired, you can iterate manually with
jaxdem.rl.trainer.epoch().
tr, _ = tr.train(tr, verbose=False)
steps/s: 8.38e+05, final avg_score: -0.80
Testing the New Policy#
Now that we have a trained agent, let's play around with it.
We spawn the agent and periodically move the target it needs to reach, so the agent keeps chasing the objective. When saving the simulation state, we add a small sphere that marks where the agent needs to go.
key, subkey = jax.random.split(key)
tr = replace(tr, env=env.reset(env, subkey))
if Save:
    writer = jdem.VTKWriter()
    state = tr.env.state.add(
        tr.env.state, pos=tr.env.env_params["objective"], rad=tr.env.state.rad / 5
    )
    writer.save(state, tr.env.system)

for i in range(1, 2000):
    # Advance the environment one step using the trained policy
    tr, _ = tr.step(tr)
    if i % 10 == 0 and Save:
        state = tr.env.state.add(
            tr.env.state,
            pos=tr.env.env_params["objective"],
            rad=tr.env.state.rad / 5,
        )
        writer.save(state, tr.env.system)
    # Change the objective without moving the agent
    if i % 500 == 0:
        key, subkey = jax.random.split(key)
        min_pos = tr.env.state.rad[0] * jnp.ones_like(tr.env.system.domain.box_size)
        objective = jax.random.uniform(
            subkey,
            (tr.env.max_num_agents, tr.env.state.dim),
            minval=min_pos,
            maxval=tr.env.system.domain.box_size - min_pos,
            dtype=float,
        )
        tr.env.env_params["objective"] = objective
Total running time of the script: (0 minutes 24.438 seconds)