jaxdem.rl.environments
Reinforcement-learning environment interface.
Classes
Environment | Defines the interface for reinforcement-learning environments.
- class jaxdem.rl.environments.Environment(state: State, system: System, env_params: dict[str, Any])
Bases: Factory, ABC
Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to recover the original per-agent action shape inside the environment if needed.
Required shapes
- Observation: (A, observation_space_size)
- Action (input to step()): (A, action_space_size)
- Reward: (A,)
- Done: scalar boolean for the whole environment
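The shape contract above can be checked mechanically. The following is a minimal sketch, not part of jaxdem: the helper check_env_shapes is hypothetical, and NumPy arrays stand in for jax.Array.

```python
import numpy as np


def check_env_shapes(obs, action, reward, done, num_agents, obs_size, act_size):
    """Hypothetical helper: validate the per-agent shape contract (A = num_agents)."""
    assert obs.shape == (num_agents, obs_size), f"observation shape {obs.shape}"
    assert action.shape == (num_agents, act_size), f"action shape {action.shape}"
    assert reward.shape == (num_agents,), f"reward shape {reward.shape}"
    # done is one flag for the whole environment, not one per agent.
    done = np.asarray(done)
    assert done.ndim == 0 and done.dtype == np.bool_, "done must be a scalar boolean"
    return True


# Single-agent environments still use A = 1.
A, obs_size, act_size = 1, 6, 2
ok = check_env_shapes(
    obs=np.zeros((A, obs_size)),
    action=np.zeros((A, act_size)),
    reward=np.zeros((A,)),
    done=np.array(False),
    num_agents=A,
    obs_size=obs_size,
    act_size=act_size,
)
```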
Todo:
- Truncated data field: per-agent termination flag
- Render method
Example:
To define a custom environment, inherit from Environment and implement the abstract methods:
>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.rl.environments import Environment
>>> @Environment.register("MyCustomEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomEnv(Environment): ...
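The doctest only declares the class. As a rough sketch of the functional pattern the abstract methods imply (static methods that take an environment value and return arrays or a new environment instead of mutating it), the toy class below runs without jaxdem. ToyEnv, its fields, and its dynamics are hypothetical; NumPy stands in for jax.numpy, and a real subclass would inherit from Environment and be registered as in the example.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyEnv:
    """Hypothetical point-mass environment: A agents moving on a line."""

    pos: np.ndarray     # (A,) agent positions
    target: np.ndarray  # (A,) target positions
    step_count: int
    max_steps: int

    @staticmethod
    def reset(env, key):
        # key plays the role of a PRNG key; here it seeds NumPy's generator.
        rng = np.random.default_rng(key)
        A = env.pos.shape[0]
        return replace(env, pos=rng.uniform(-1, 1, A), target=rng.uniform(-1, 1, A), step_count=0)

    @staticmethod
    def step(env, action):
        # action has shape (A, 1): one flattened action per agent.
        return replace(env, pos=env.pos + 0.1 * action[:, 0], step_count=env.step_count + 1)

    @staticmethod
    def observation(env):
        # shape (A, 1): signed displacement to the target.
        return (env.target - env.pos)[:, None]

    @staticmethod
    def reward(env):
        # shape (A,): negative distance to the target.
        return -np.abs(env.target - env.pos)

    @staticmethod
    def done(env):
        # Scalar boolean for the whole environment.
        return np.asarray(env.step_count >= env.max_steps)


env = ToyEnv(pos=np.zeros(3), target=np.zeros(3), step_count=0, max_steps=2)
env = ToyEnv.reset(env, key=0)
for _ in range(2):
    env = ToyEnv.step(env, action=np.ones((3, 1)))
```

Returning new values rather than mutating in place is what lets such methods be composed with jax.jit and jax.vmap in the real library.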
- env_params: dict[str, Any]
Environment-specific parameters.
- abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Initialize the environment to a valid start state.
- Parameters:
env (Environment) – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Conditionally reset the environment if it has reached a terminal state.
This method checks the done flag and, if it is True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – Boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
The freshly reset environment if done is True, otherwise the unchanged environment.
- Return type:
Environment
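One way to express this conditional reset without Python branching, so that it stays traceable under jax.jit, is a leaf-wise select between a freshly reset environment and the current one. The sketch below assumes a plain dict of NumPy arrays as a stand-in for the environment pytree; reset_fn and reset_if_done_sketch are hypothetical, and jaxdem's actual implementation may use jax.lax.cond or a different mechanism.

```python
import numpy as np


def reset_fn(env, key):
    # Hypothetical reset: zero the positions and the step counter.
    return {"pos": np.zeros_like(env["pos"]), "step_count": np.zeros_like(env["step_count"])}


def reset_if_done_sketch(env, done, key):
    # Build the reset candidate unconditionally, then select leaf by leaf.
    # Under jax.jit the same idea works with jax.tree_util.tree_map and jnp.where.
    fresh = reset_fn(env, key)
    return {name: np.where(done, fresh[name], env[name]) for name in env}


env = {"pos": np.array([[1.0, 2.0], [3.0, 4.0]]), "step_count": np.array(7)}
kept = reset_if_done_sketch(env, np.array(False), key=0)
reset = reset_if_done_sketch(env, np.array(True), key=0)
```

Computing both branches wastes some work when done is False, which is the usual trade-off for branch-free code on accelerators.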
- abstractmethod static step(env: Environment, action: Array) → Environment
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – Per-agent actions of shape (A, action_space_size).
- Returns:
The updated environment state.
- Return type:
Environment
- abstractmethod static observation(env: Environment) → Array
Returns the per-agent observation vectors.
- Parameters:
env (Environment) – The current environment.
- Returns:
Per-agent observations of shape (A, observation_space_size).
- Return type:
jax.Array
- abstractmethod static reward(env: Environment) → Array
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Per-agent rewards of shape (A,) computed from the current environment state.
- Return type:
jax.Array
- abstractmethod static done(env: Environment) → Array
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
Scalar boolean that is True once the episode has ended.
- Return type:
jax.Array
- static info(env: Environment) → dict[str, Any]
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override it to provide environment-specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
dict[str, Any]
- property action_space_size: int
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int]
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
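Since action_space_size is the flattened size per agent, it should equal the product of action_space_shape. The sketch below uses a hypothetical per-agent shape of (2, 3) to show the round trip between the flat (A, action_space_size) layout and the original shape; it is illustrative and not taken from a specific environment.

```python
import math

import numpy as np

# Hypothetical per-agent action shape, e.g. two 3-D vectors per agent.
action_space_shape = (2, 3)
action_space_size = math.prod(action_space_shape)  # flattened size per agent

A = 4
# What step() receives: one flat row per agent, shape (A, action_space_size).
flat_actions = np.arange(A * action_space_size, dtype=float).reshape(A, action_space_size)

# Inside step(), recover the original per-agent layout.
actions = flat_actions.reshape((A, *action_space_shape))
```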
Bases: Environment
Multi-agent navigation environment toward assigned targets.
Each agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. Objectives are sampled and assigned one-to-one via a random permutation.
The reward uses exponential potential-based shaping:
\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]
where \(d_i\) is the distance to the assigned objective, \(K_i\) is the translational kinetic energy of agent \(i\), and \(r_i\) is its radius.
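For concreteness, the formula can be evaluated for all agents at once. This NumPy sketch follows the equation term by term; the function name and argument layout are hypothetical, and the weights in the usage line are borrowed from the MultiRoller.Create defaults purely for illustration.

```python
import numpy as np


def multi_agent_reward(d, d_prev, K, K_prev, radius, w_ke, w_coop, w_near):
    """Evaluate R_i for every agent (hypothetical helper, NumPy stand-in for jax.numpy).

    d, d_prev: (N,) distance to the assigned objective now and at the previous step
    K, K_prev: (N,) translational kinetic energy now and at the previous step
    radius:    (N,) agent radii r_i
    """
    progress = np.exp(-2.0 * d) - np.exp(-2.0 * d_prev)  # e^{-2 d_i} - e^{-2 d_i^prev}
    team = progress.mean()                               # (1/N) sum_j of the same term
    near = (d <= radius).astype(float)                   # indicator 1[d_i <= r_i]
    return progress - w_ke * (K - K_prev) + w_coop * team + w_near * near


# Agent 0 moves from d = 1.0 to 0.6; agent 1 sits still inside its objective.
R = multi_agent_reward(
    d=np.array([0.6, 0.2]),
    d_prev=np.array([1.0, 0.2]),
    K=np.array([0.1, 0.0]),
    K_prev=np.array([0.0, 0.0]),
    radius=np.array([0.5, 0.5]),
    w_ke=0.1,
    w_coop=0.2,
    w_near=0.1,
)
```

Agent 1 makes no progress of its own but still earns the team share of agent 0's progress plus the near-goal bonus.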
Notes
The observation vector per agent is:
Feature | Size
Unit direction to objective | dim
Clamped displacement | dim
Velocity | dim
LiDAR proximity (normalized) | n_lidar_rays
For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
- n_lidar_rays: int
Number of angular bins for each LiDAR sensor.
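The timing note can be checked with a few lines of arithmetic. This sketch assumes that each agent action is held for skip_frames physics steps, which is how the 200 Hz figure is read here; under that assumption the implied physics time step is 1e-4 s.

```python
# Assumption: each agent action is held for skip_frames physics steps.
skip_frames = 50
response_rate_hz = 200.0  # agent decisions per second, from the timing note
num_steps_epoch = 100     # agent decisions per rollout

control_period_s = 1.0 / response_rate_hz       # 5 ms between decisions
physics_dt_s = control_period_s / skip_frames   # implied integrator time step
horizon_s = num_steps_epoch * control_period_s  # simulated time covered by a rollout

print(f"dt = {physics_dt_s:.1e} s, horizon = {horizon_s:.2f} s")
# prints: dt = 1.0e-04 s, horizon = 0.50 s
```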
Create a multi-agent navigator environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
coop_weight (float) – Weight for the shared team-progress bonus.
near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
Initialize the environment with random positions and objectives.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Advance one step. Actions are forces; simple drag is applied (-friction * vel).
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – Per-agent force vectors of shape (A, action_space_size).
- Returns:
The updated environment state.
- Return type:
Environment
Build per-agent observations.
Contents per agent:
- Unit vector to objective (shape (dim,)): direction.
- Clamped delta to objective (shape (dim,)): local precision.
- Velocity (shape (dim,)).
- LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, 3 * dim + n_lidar_rays).
- Return type:
jax.Array
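A minimal sketch of how the four blocks could be concatenated into one row per agent. The clamp bound, the eps guard against division by zero, and the function name are assumptions for illustration; the LiDAR input is taken as already normalized.

```python
import numpy as np


def navigator_observation(pos, objective, vel, lidar_proximity, clamp=1.0, eps=1e-12):
    """Hypothetical sketch of the per-agent observation layout.

    pos, objective, vel: (N, dim); lidar_proximity: (N, n_lidar_rays), already
    normalized by lidar_range. The clamp bound is an assumption for illustration.
    """
    delta = objective - pos
    dist = np.linalg.norm(delta, axis=-1, keepdims=True)
    unit = delta / np.maximum(dist, eps)     # direction, safe when dist is 0
    clamped = np.clip(delta, -clamp, clamp)  # local precision near the goal
    return np.concatenate([unit, clamped, vel, lidar_proximity], axis=-1)


N, dim, n_rays = 4, 2, 8
rng = np.random.default_rng(0)
obs = navigator_observation(
    pos=rng.normal(size=(N, dim)),
    objective=rng.normal(size=(N, dim)),
    vel=rng.normal(size=(N, dim)),
    lidar_proximity=rng.uniform(size=(N, n_rays)),
)
```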
Returns a vector of per-agent rewards.
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), the mean is taken over all agents, \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
Original per-agent action shape (useful for reshaping inside the environment).
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.MultiRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)
Bases: Environment
Multi-agent rolling environment toward assigned targets.
Each agent controls a torque vector that is applied directly to a sphere on a \(z=0\) floor. Translational drag -friction * vel and angular damping -friction * ang_vel are applied each step. Objectives are sampled and assigned one-to-one via a random permutation.
The reward uses exponential potential-based shaping:
\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]
where \(d_i\) is the distance to the assigned objective in the \(xy\) plane, \(K_i\) is the translational kinetic energy of agent \(i\), and \(r_i\) is its radius.
Notes
The observation vector per agent is:
Feature | Size
Unit direction to objective | 2
Clamped displacement | 2
Velocity | 2
LiDAR proximity (normalized) | n_lidar_rays
For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
- n_lidar_rays: int
Number of angular bins for each LiDAR sensor.
- classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 5.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 6.0, n_lidar_rays: int = 16) → MultiRoller
Create a multi-agent roller environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Translational and angular damping coefficient.
ke_weight (float) – Weight for the differential kinetic energy penalty.
coop_weight (float) – Weight for the shared team-progress bonus.
near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
MultiRoller
- static reset(env: MultiRoller, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Initialize the environment with random positions and objectives.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static step(env: MultiRoller, action: Array) → Environment
Advance one step. Actions are torques; simple damping is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – Per-agent torque vectors of shape (A, action_space_size).
- Returns:
The updated environment state.
- Return type:
Environment
- static observation(env: MultiRoller) → Array
Build per-agent observations.
Contents per agent:
- Unit vector to objective in the \(xy\) plane (shape (2,)).
- Clamped objective delta in the \(xy\) plane (shape (2,)).
- Velocity in the \(xy\) plane (shape (2,)).
- LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, 6 + n_lidar_rays).
- Return type:
jax.Array
- static reward(env: MultiRoller) → Array
Returns a vector of per-agent rewards.
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), the mean is taken over all agents, \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
- static done(env: MultiRoller) → Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the environment has ended.
- Return type:
jax.Array
- property action_space_size: int
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int]
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
Bases: Environment
Single-agent navigation environment toward a fixed target.
The agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. The reward uses exponential potential-based shaping:
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
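Because both terms are differences of per-step quantities, the undiscounted sum of these rewards over an episode telescopes: it depends only on the first and last distance and kinetic energy, not on the path in between. The sketch below checks this on a made-up trajectory (the numbers are hypothetical).

```python
import numpy as np

# Hypothetical trajectory: distances d_0..d_T and kinetic energies K_0..K_T.
d = np.array([2.0, 1.5, 1.1, 0.6, 0.3])
K = np.array([0.0, 0.4, 0.3, 0.2, 0.05])
w_ke = 0.1

phi = np.exp(-2.0 * d)                      # potential e^{-2 d_t}
rewards = np.diff(phi) - w_ke * np.diff(K)  # rew_t for t = 1..T
episode_return = rewards.sum()

# Telescoped form: only the endpoints matter.
closed_form = (phi[-1] - phi[0]) - w_ke * (K[-1] - K[0])
```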
Notes
The observation vector per agent is:
Feature | Size
Unit direction to objective | dim
Clamped displacement | dim
Velocity | dim
For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
Create a single-agent navigator environment.
- Parameters:
dim (int) – Spatial dimensionality (2 or 3).
min_box_size (float) – Lower bound of the random square domain side length.
max_box_size (float) – Upper bound of the random square domain side length.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
Initialize the environment with a randomly placed particle and velocity.
- Parameters:
env (SingleNavigator) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Advance one step. Actions are forces; simple drag is applied (-friction * vel).
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – Per-agent force vectors of shape (A, action_space_size).
- Returns:
The updated environment state.
- Return type:
Environment
Build per-agent observations.
Contents per agent:
- Unit vector to objective (shape (dim,)): direction.
- Clamped delta to objective (shape (dim,)): local precision.
- Velocity (shape (dim,)).
- Returns:
Array of shape (N, 3 * dim).
- Return type:
jax.Array
Returns a vector of per-agent rewards.
Reward:
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
Original per-agent action shape (useful for reshaping inside the environment).
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.SingleRoller3D(state: State, system: System, env_params: dict[str, Any])
Bases: Environment
Single-agent 3D navigation via torque-controlled rolling.
The agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). A viscous drag -friction * vel and a fixed angular damping of -friction * ang_vel are applied each step.
The reward uses exponential potential-based shaping:
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
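To see why a torque can move the agent at all, consider the idealized limit of rolling without slipping; the environment itself resolves contact through frictional_wall_force(), so this kinematic relation is an assumption used for intuition only. If the contact point at −R ẑ is at rest, then v + ω × (−R ẑ) = 0, so the center of mass moves with v = R (ω × ẑ): spinning about +y rolls the sphere along +x, while spinning about z only turns it in place.

```python
import numpy as np


def rolling_velocity(ang_vel, radius):
    """Center-of-mass velocity of a sphere rolling without slipping on z = 0.

    No-slip means the contact point at -radius * z_hat is at rest:
    v + ang_vel x (-radius * z_hat) = 0  =>  v = radius * (ang_vel x z_hat).
    """
    z_hat = np.array([0.0, 0.0, 1.0])
    return radius * np.cross(ang_vel, z_hat)


v_y_spin = rolling_velocity(np.array([0.0, 2.0, 0.0]), radius=0.5)  # spin about +y
v_z_spin = rolling_velocity(np.array([0.0, 0.0, 3.0]), radius=0.5)  # spin about +z
```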
Notes
The observation vector per agent is:
Feature | Size
Unit direction to objective | 2
Clamped displacement (x, y) | 2
Velocity (x, y) | 2
Angular velocity | 3
For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
- classmethod Create(min_box_size: float = 40.0, max_box_size: float = 40.0, max_steps: int = 20000, friction: float = 0.2, ke_weight: float = 0.1) → SingleRoller3D
Create a single-agent roller environment.
- Parameters:
min_box_size (float) – Lower bound of the random square domain side length.
max_box_size (float) – Upper bound of the random square domain side length.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel (also used for the angular damping -friction * ang_vel).
ke_weight (float) – Weight for the differential kinetic energy penalty.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
SingleRoller3D
- static reset(env: SingleRoller3D, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Randomly place the agent and objective on the floor.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX PRNG key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static step(env: SingleRoller3D, action: Array) → Environment
Apply a torque action and advance the physics by one step.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – 3-D torque vector per agent.
- Returns:
Updated environment after one physics step.
- Return type:
Environment
- static observation(env: SingleRoller3D) → Array
Per-agent observation vector.
Contents per agent:
- Unit displacement to objective projected to x-y (shape (2,)).
- Clamped displacement to objective projected to x-y (shape (2,)).
- Velocity projected to x-y (shape (2,)).
- Angular velocity (shape (3,)).
- Returns:
Shape (N, 9).
- Return type:
jax.Array
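A sketch of the 9-feature layout for a single agent, projecting the displacement and velocity onto the x-y plane before appending the full 3-D angular velocity. The clamp bound and function name are hypothetical; the expected row in the test is worked out by hand for a 3-4-5 displacement.

```python
import numpy as np


def roller_observation(pos, objective, vel, ang_vel, clamp=1.0, eps=1e-12):
    """Hypothetical sketch of the 9-feature layout: (N, 2 + 2 + 2 + 3)."""
    delta_xy = (objective - pos)[:, :2]              # project displacement to x-y
    dist = np.linalg.norm(delta_xy, axis=-1, keepdims=True)
    unit_xy = delta_xy / np.maximum(dist, eps)
    clamped_xy = np.clip(delta_xy, -clamp, clamp)    # clamp bound is an assumption
    return np.concatenate([unit_xy, clamped_xy, vel[:, :2], ang_vel], axis=-1)


pos = np.array([[0.0, 0.0, 0.5]])
objective = np.array([[3.0, 4.0, 0.5]])
vel = np.array([[0.1, -0.2, 0.0]])
ang_vel = np.array([[0.4, 0.2, 0.0]])
obs = roller_observation(pos, objective, vel, ang_vel)
```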
- static reward(env: SingleRoller3D) → Array
Returns a vector of per-agent rewards.
Exponential potential-based shaping:
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]
- Returns:
Shape (N,).
- Return type:
jax.Array
- static done(env: SingleRoller3D) → Array
True when step_count exceeds max_steps.
Bases: Environment
Multi-agent navigation environment toward nearby shared targets.
Each agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. Objectives are sampled globally, and each agent observes objective LiDAR and agent LiDAR.
At reset, a small subset of agents is spawned in the central objective region while the rest are spawned in the outer padding ring.
The reward uses exponential potential-based shaping:
\[R_i = (S_i - S_i^{\mathrm{prev}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_m (S_m - S_m^{\mathrm{prev}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]
where \(d_i\) is the distance to the closest objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, and \(S_i = \sum_{r \in \text{obj-LiDAR}} e^{-2 d_{ir}}\) sums exponential shaping over objectives detected by objective LiDAR rays.
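The potential S_i can be computed from per-ray objective distances. This sketch assumes each objective-LiDAR ray reports the distance to the nearest objective in its sector, with np.inf (or anything beyond lidar_range) meaning nothing was detected; that encoding and the function name are assumptions, not jaxdem's internal representation.

```python
import numpy as np


def lidar_shaping_potential(ray_dist, lidar_range):
    """S_i = sum of exp(-2 d_ir) over objective-LiDAR rays that detect an objective.

    ray_dist: (N, n_lidar_rays) distance to the nearest objective in each ray's
    sector; np.inf or any value beyond lidar_range means "nothing detected"
    (an assumed encoding).
    """
    detected = ray_dist <= lidar_range
    return np.where(detected, np.exp(-2.0 * ray_dist), 0.0).sum(axis=-1)


ray_dist = np.array([
    [1.0, np.inf, 0.5, 12.0],          # two objectives in range, one beyond lidar_range
    [np.inf, np.inf, np.inf, np.inf],  # nothing detected
])
S = lidar_shaping_potential(ray_dist, lidar_range=10.0)
```

The progress term of the reward is then S - S_prev, with S_prev computed the same way from the previous step's rays.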
Notes
The observation vector per agent is:
Feature | Size
Velocity | dim
Objective LiDAR proximity | n_lidar_rays
Agent LiDAR proximity | n_lidar_rays
- n_lidar_rays: int
Number of angular bins for each LiDAR sensor.
Create a swarm navigator environment.
- Parameters:
N (int) – Number of agents and number of sampled objectives.
min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius. The padding region is used as the outer spawn ring.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
coop_weight (float) – Weight for the shared team-progress bonus.
near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its closest objective.
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
Initialize the environment with random positions and objectives.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Advance one step. Actions are forces; simple drag is applied (-friction * vel).
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – Per-agent force vectors of shape (A, action_space_size).
- Returns:
The updated environment state.
- Return type:
Environment
Build per-agent observations.
Contents per agent:
- Velocity (shape (dim,)).
- Objective LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Agent LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, dim + 2 * n_lidar_rays).
- Return type:
jax.Array
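A possible way to produce the LiDAR blocks: split [−π, π) into n_lidar_rays equal sectors and keep the closest target per sector. The proximity definition (lidar_range − d) / lidar_range, the zero fill for empty sectors, and the function itself are assumptions for illustration; jaxdem's exact normalization may differ.

```python
import numpy as np


def lidar_proximity_2d(pos, targets, lidar_range, n_rays, exclude_self=False):
    """Hypothetical 2-D LiDAR: closest target per angular sector, as proximity in [0, 1].

    Sectors split [-pi, pi) into n_rays equal bins. Proximity is taken here as
    (lidar_range - d) / lidar_range for the closest in-range target in each
    sector, and 0 for empty sectors.
    """
    delta = targets[None, :, :] - pos[:, None, :]     # (N, M, 2)
    dist = np.linalg.norm(delta, axis=-1)             # (N, M)
    angle = np.arctan2(delta[..., 1], delta[..., 0])  # in [-pi, pi]
    bins = np.floor((angle + np.pi) / (2 * np.pi) * n_rays).astype(int) % n_rays
    valid = dist <= lidar_range
    if exclude_self:  # agent LiDAR should not see the agent itself
        valid &= ~np.eye(len(pos), len(targets), dtype=bool)
    prox = np.where(valid, (lidar_range - dist) / lidar_range, 0.0)
    out = np.zeros((len(pos), n_rays))
    for i in range(len(pos)):
        np.maximum.at(out[i], bins[i], prox[i])       # keep the closest per sector
    return out


pos = np.array([[0.0, 0.0]])
targets = np.array([[1.0, 1.0], [-1.0, 2.0], [-3.0, -3.0]])  # last one is out of range
out = lidar_proximity_2d(pos, targets, lidar_range=4.0, n_rays=4)
```

The same function could serve for both blocks: objective LiDAR with targets set to the objectives, and agent LiDAR with targets set to the agent positions and exclude_self=True.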
Returns a vector of per-agent rewards.
\[\mathrm{rew}_t = (S_t - S_t^{\mathrm{prev}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(S_t - S_t^{\mathrm{prev}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]
where \(d_t\) is the distance to the closest objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), \(S_t\) is the per-agent sum of \(e^{-2d}\) over objectives detected by objective LiDAR rays, the mean is taken over all agents, \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the environment has ended.
- Return type:
jax.Array
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
Original per-agent action shape (useful for reshaping inside the environment).
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.SwarmRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)
Bases: Environment
Multi-agent rolling environment toward nearby shared targets.
Each agent controls a torque vector that is applied directly to a sphere on a \(z=0\) floor. Translational drag -friction * vel and angular damping -friction * ang_vel are applied each step. Objectives are sampled globally, and each agent observes objective LiDAR and agent LiDAR.
At reset, a small subset of agents is spawned in the central objective region while the rest are spawned in the outer padding ring.
The reward uses exponential potential-based shaping:
\[R_i = (S_i - S_i^{\mathrm{prev}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_m (S_m - S_m^{\mathrm{prev}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]
where \(d_i\) is the distance to the closest objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, and \(S_i = \sum_{r \in \text{obj-LiDAR}} e^{-2 d_{ir}}\) sums exponential shaping over objectives detected by objective LiDAR rays.
Notes
The observation vector per agent is:
Feature | Size
Velocity | dim
Objective LiDAR proximity | n_lidar_rays
Agent LiDAR proximity | n_lidar_rays
- n_lidar_rays: int
Number of angular bins for each LiDAR sensor.
- classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 20.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 10.0, n_lidar_rays: int = 24) → SwarmRoller
Create a swarm roller environment.
- Parameters:
N (int) – Number of agents and number of sampled objectives.
min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius. The padding region is used as the outer spawn ring.
max_steps (int) – Episode length in physics steps.
friction (float) – Translational and angular damping coefficient.
ke_weight (float) – Weight for the differential kinetic energy penalty.
coop_weight (float) – Weight for the shared team-progress bonus.
near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its closest objective.
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
SwarmRoller
- static reset(env: SwarmRoller, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Initialize the environment with random positions and objectives.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static step(env: SwarmRoller, action: Array) → Environment
Advance one step. Actions are torques; simple damping is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – Per-agent torque vectors of shape (A, action_space_size).
- Returns:
The updated environment state.
- Return type:
Environment
- static observation(env: SwarmRoller) → Array
Build per-agent observations.
Contents per agent:
- Velocity (shape (dim,)).
- Objective LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Agent LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, dim + 2 * n_lidar_rays).
- Return type:
jax.Array
- static reward(env: SwarmRoller) → Array
Returns a vector of per-agent rewards.
\[\mathrm{rew}_t = (S_t - S_t^{\mathrm{prev}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(S_t - S_t^{\mathrm{prev}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]
where \(d_t\) is the distance to the closest objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), \(S_t\) is the per-agent sum of \(e^{-2d}\) over objectives detected by objective LiDAR rays, the mean is taken over all agents, \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
- static done(env: SwarmRoller) → Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the environment has ended.
- Return type:
jax.Array
- property action_space_size: int
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int]
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.SwarmRoller3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int)
Bases: Environment
Multi-agent 3-D rolling environment with magnetic interaction and pyramid objectives.
- n_lidar_rays: int
Number of azimuthal bins for the 3-D LiDAR sensor.
- n_lidar_elevation: int
Number of elevation bins for the 3-D LiDAR sensor.
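With separate azimuthal and elevation bins, a 3-D LiDAR has n_lidar_rays × n_lidar_elevation sectors. The sketch below maps a direction to one flattened sector index; the angular ranges ([−π, π) in azimuth, [−π/2, π/2] in elevation) and the elevation-major ordering are assumptions, since the binning convention is not documented here.

```python
import numpy as np


def lidar_bin_3d(direction, n_azimuth, n_elevation):
    """Hypothetical mapping of a 3-D direction to a flattened (elevation, azimuth) bin."""
    x, y, z = direction
    azimuth = np.arctan2(y, x)                 # [-pi, pi]
    elevation = np.arctan2(z, np.hypot(x, y))  # [-pi/2, pi/2]
    az_bin = int(np.floor((azimuth + np.pi) / (2 * np.pi) * n_azimuth)) % n_azimuth
    # Clip so that straight up (elevation = pi/2) falls in the top bin.
    el_bin = min(int(np.floor((elevation + np.pi / 2) / np.pi * n_elevation)), n_elevation - 1)
    return el_bin * n_azimuth + az_bin


n_az, n_el = 8, 8  # the SwarmRoller3D.Create defaults
total_bins = n_az * n_el
idx_up = lidar_bin_3d(np.array([0.0, 0.0, 1.0]), n_az, n_el)
idx_ahead = lidar_bin_3d(np.array([1.0, 0.1, 0.2]), n_az, n_el)
```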
- classmethod Create(N: int = 5, min_box_size: float = 10.0, max_box_size: float = 10.0, box_padding: float = 0.0, max_steps: int = 100000, friction: float = 0.2, lidar_range: float = 10.0, n_lidar_rays: int = 8, n_lidar_elevation: int = 8, magnet_strength: float = 4.0, magnet_range: float = 3.0, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1) → SwarmRoller3D
Create a swarm roller 3-D environment.
- static reset(env: SwarmRoller3D, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Reset the environment to a random initial configuration.
- static step(env: SwarmRoller3D, action: Array) → Environment
Advance the environment by one physics step.
- static observation(env: SwarmRoller3D) → Array
Build per-agent observations.
- static reward(env: SwarmRoller3D) → Array
Return per-agent rewards.
- static done(env: SwarmRoller3D) → Array
Return True when the episode has exceeded max_steps.
- class jaxdem.rl.environments.SwarmStacking3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int)#
Bases:
EnvironmentMulti-agent 3-D stacking environment with periodic boundaries.
- n_lidar_rays: int#
Number of azimuthal bins for the 3-D LiDAR sensor.
- n_lidar_elevation: int#
Number of elevation bins for the 3-D LiDAR sensor.
- classmethod Create(N: int = 16, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 20.0, max_steps: int = 5760, friction: float = 0.2, lidar_range: float = 10.0, n_lidar_rays: int = 8, n_lidar_elevation: int = 8, magnet_strength: float = 40.0, magnet_range: float = 2.4, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1) SwarmStacking3D[source]#
Create a swarm stacking 3-D environment.
- static reset(env: SwarmStacking3D, key: Array | ndarray | bool | number | bool | int | float | complex) Environment[source]#
- static step(env: SwarmStacking3D, action: Array) Environment[source]#
- static observation(env: SwarmStacking3D) Array[source]#
- static reward(env: SwarmStacking3D) Array[source]#
- static done(env: SwarmStacking3D) Array[source]#
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to
step()have shape(A, action_space_size).
- property action_space_shape: tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
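action_space_shape is described as useful for reshaping inside the environment. The sketch below performs that reshape. It assumes only what the docs state: actions arrive as (A, action_space_size), and action_space_size is the product of action_space_shape. unflatten_actions is a hypothetical helper, not part of jaxdem.

```python
import math
import jax.numpy as jnp

def unflatten_actions(action, action_space_shape):
    """Reshape flat actions (A, action_space_size) to (A, *action_space_shape)."""
    action_space_size = math.prod(action_space_shape)
    # Shapes are static under jit, so this check is safe inside a traced step().
    assert action.shape[-1] == action_space_size, "last axis must equal action_space_size"
    return action.reshape((action.shape[0],) + tuple(action_space_shape))

flat = jnp.arange(24.0).reshape(4, 6)      # A = 4 agents, 6 flat action entries each
out = unflatten_actions(flat, (2, 3))      # per-agent shape (2, 3)
```

Inside step(), pass env.action_space_shape as the second argument. The flat layout stays what the policy sees, and the structured view stays local to the physics code.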
- class jaxdem.rl.environments.ThreeGears(state: State, system: System, env_params: dict[str, Any])#
Bases: Environment
Multi-agent 2-D environment with three gears.
The environment consists of three active gears composed of spheres. All gears can apply torque to themselves. The shared objective is to navigate the gears to form a triangular structure defined by a randomized target position.
Note
As with the TwoGears environment, skip_frames = 50 is a realistic training setting: it gives a response rate of 200 Hz, so num_steps_epoch = 100 gives a horizon of 0.5 seconds.
- classmethod Create(box_size: float = 10.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.001) ThreeGears[source]#
Create a three-gears 2-D environment.
- Parameters:
box_size (float) – Size of the square bounding box.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
ThreeGears
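The class Note states that skip_frames = 50 gives 200 Hz and num_steps_epoch = 100 gives 0.5 s. Those figures follow from the physics time step dt. If each action is held for skip_frames steps, the response rate is 1/(skip_frames · dt) and one rollout spans num_steps_epoch · skip_frames · dt seconds. The value dt = 1e-4 s below is inferred from the stated 200 Hz; it is not read from jaxdem. control_timing is a hypothetical helper.

```python
def control_timing(dt, skip_frames, num_steps_epoch):
    """Return (response rate in Hz, rollout horizon in seconds).

    Assumes each agent action is held for skip_frames physics steps of size dt.
    """
    control_period = skip_frames * dt            # seconds between agent decisions
    response_rate = 1.0 / control_period         # decisions per second (Hz)
    horizon = num_steps_epoch * control_period   # seconds covered by one rollout
    return response_rate, horizon

# dt = 1e-4 s is an assumption inferred from the Note (50 frames per 1/200 s).
rate, horizon = control_timing(dt=1e-4, skip_frames=50, num_steps_epoch=100)
```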
- static reset(env: ThreeGears, key: Array) Environment[source]#
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (jax.Array) – PRNG key used to sample the initial positions and objective triangle.
- Returns:
The environment with a fresh episode state.
- Return type:
Environment
- static step(env: ThreeGears, action: Array) Environment[source]#
Advance the environment by one physics step.
Applies torque to the active gears, computes inter-gear forces, and applies viscous drag.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Actions for all gears.
- Returns:
Updated environment after physics integration and sensor updates.
- Return type:
Environment
- static observation(env: ThreeGears) Array[source]#
Build the observation vector.
The observation vector contains 22 features:
Feature | Size
Distance to floor | 1
Distance to left/right walls | 2
Unit vector to target | 2
Clamped displacement to target | 2
Unit vector to neighbor j | 2
Clamped displacement to neighbor j | 2
\(\sin(\Delta\theta_j)\) | 1
\(\cos(\Delta\theta_j)\) | 1
Unit vector to neighbor k | 2
Clamped displacement to neighbor k | 2
\(\sin(\Delta\theta_k)\) | 1
\(\cos(\Delta\theta_k)\) | 1
Velocity (x, y) | 2
Angular velocity | 1
- Returns:
Observation vector of size 22.
- Return type:
jax.Array
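When inspecting or debugging a policy, it helps to know where each feature sits in the 22-entry vector. The sketch below derives slice offsets from the sizes in the table. It assumes the vector is laid out in table order. The feature names are hypothetical labels, not identifiers from jaxdem.

```python
from itertools import accumulate

# Feature sizes in the order of the ThreeGears observation table (names are illustrative).
THREE_GEARS_FEATURES = [
    ("floor_distance", 1),
    ("wall_distances", 2),
    ("target_unit", 2),
    ("target_disp", 2),
    ("neighbor_j_unit", 2),
    ("neighbor_j_disp", 2),
    ("sin_dtheta_j", 1),
    ("cos_dtheta_j", 1),
    ("neighbor_k_unit", 2),
    ("neighbor_k_disp", 2),
    ("sin_dtheta_k", 1),
    ("cos_dtheta_k", 1),
    ("velocity", 2),
    ("angular_velocity", 1),
]

def feature_slices(features):
    """Map each feature name to a slice into the flat per-agent observation."""
    sizes = [size for _, size in features]
    ends = list(accumulate(sizes))                       # cumulative end offsets
    starts = [end - size for end, size in zip(ends, sizes)]
    return {name: slice(s, e) for (name, _), s, e in zip(features, starts, ends)}

SLICES = feature_slices(THREE_GEARS_FEATURES)
```

With an observation batch obs of shape (A, 22), obs[:, SLICES["velocity"]] selects the (A, 2) velocity block.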
- static reward(env: ThreeGears) Array[source]#
Compute the cooperative reward.
The shared reward is based on the differential distance to the objective minus a penalty for the change in kinetic energy:
\[R_t = (d_{t-1} - d_t) - w_{\text{ke}} \sum_i (K_t^i - K_{t-1}^i)\]
where \(d_t\) is the total distance to the objective at step \(t\), \(K_t^i\) is the kinetic energy of agent \(i\), and \(w_{\text{ke}}\) is the kinetic energy weight.
- Returns:
Reward value, identical for all agents.
- Return type:
jax.Array
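Below is a direct transcription of the formula. It assumes the caller already has the total distance d and the per-agent kinetic energies K^i at both steps; how jaxdem computes those quantities is not shown here. cooperative_reward is a hypothetical name. The output is broadcast to shape (A,) to match the base-class reward shape. The default ke_weight matches the Create default of 0.001.

```python
import jax.numpy as jnp

def cooperative_reward(d_prev, d_curr, ke_prev, ke_curr, ke_weight=0.001):
    """R_t = (d_{t-1} - d_t) - w_ke * sum_i (K_t^i - K_{t-1}^i), shared by all agents.

    d_prev, d_curr: scalars, total distance to the objective at steps t-1 and t.
    ke_prev, ke_curr: arrays of shape (A,), per-agent kinetic energy at t-1 and t.
    """
    progress = d_prev - d_curr                               # positive when moving closer
    ke_penalty = ke_weight * jnp.sum(ke_curr - ke_prev)      # summed over agents i
    shared = progress - ke_penalty
    return jnp.broadcast_to(shared, ke_curr.shape)           # identical value for every agent

r = cooperative_reward(2.0, 1.5, jnp.zeros(3), jnp.array([1.0, 2.0, 1.0]), ke_weight=0.1)
```

Because the penalty uses the change in kinetic energy, braking (K decreasing) earns a small bonus rather than merely avoiding a penalty.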
- static done(env: ThreeGears) Array[source]#
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.TwoGears(state: State, system: System, env_params: dict[str, Any])#
Bases: Environment
Two-dimensional environment with two gears.
The environment consists of two gears composed of spheres. One gear is frozen on the floor, and the other is an active agent that can apply torque to itself. The objective is to navigate the active gear to a specified target position above the frozen gear. The active gear is attracted to the frozen gear by a magnetic force.
Note
Experiments show that the maximum torque must be at least 4.0 * mgr (a multiple of the gravitational torque m g r) for the gear to climb correctly, and the attraction at least 1 * mg. For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 gives a horizon of 0.5 seconds.
- classmethod Create(box_size: float = 10.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, attraction_mag: float = 4.0) TwoGears[source]#
Create a two-gears 2-D environment.
- Parameters:
box_size (float) – Size of the square bounding box.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
attraction_mag (float) – Magnitude of the attraction force between the two gears.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
TwoGears
- static reset(env: TwoGears, key: Array) Environment[source]#
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (jax.Array) – PRNG key used to sample the initial positions and objective.
- Returns:
The environment with a fresh episode state.
- Return type:
Environment
- static step(env: TwoGears, action: Array) Environment[source]#
Advance the environment by one step.
Applies torque to the active agent, computes the attraction force between the gears, and applies viscous drag.
The attraction force is defined as:
\[\mathbf{F}_{\text{attraction}} = -\frac{C}{d^3}\,\hat{n}, \qquad d < 3r,\]
where \(d\) is the distance between the centers, \(\hat{n}\) is the unit vector from the frozen gear to the active gear, \(r\) is the gear radius, and \(C = m_{\text{attr}} (2r)^3\) with \(m_{\text{attr}}\) given by attraction_mag. With this choice of \(C\), the force magnitude at contact (\(d = 2r\)) equals attraction_mag.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Actions for the active gear.
- Returns:
Updated environment after physics integration and sensor updates.
- Return type:
Environment
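The attraction law above can be sketched for a single pair of gear centres. The docs only define the force for d < 3r, so returning zero outside that range is an assumption. attraction_force is a hypothetical helper and does not reproduce jaxdem's internal code.

```python
import jax.numpy as jnp

def attraction_force(pos_active, pos_frozen, r, attraction_mag=4.0):
    """Force on the active gear: F = -(C / d^3) * n_hat for d < 3r, C = attraction_mag * (2r)^3.

    n_hat points from the frozen gear to the active gear, so the minus sign pulls
    the active gear toward the frozen one. Zero force for d >= 3r is an assumption.
    """
    delta = pos_active - pos_frozen
    d = jnp.linalg.norm(delta)
    safe_d = jnp.maximum(d, 1e-12)              # avoid division by zero at coincident centres
    n_hat = delta / safe_d
    C = attraction_mag * (2.0 * r) ** 3
    force = -(C / safe_d ** 3) * n_hat
    return jnp.where(d < 3.0 * r, force, jnp.zeros_like(force))

f_contact = attraction_force(jnp.array([1.0, 0.0]), jnp.array([0.0, 0.0]), r=0.5, attraction_mag=4.0)
f_far = attraction_force(jnp.array([1.6, 0.0]), jnp.array([0.0, 0.0]), r=0.5)
```

At contact (d = 2r = 1.0) the magnitude equals attraction_mag. This is why the Note's "attraction at least 1 * mg" can be read directly as a lower bound on attraction_mag.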
- static observation(env: TwoGears) Array[source]#
Build the observation vector.
The observation vector contains 16 features:
Feature | Size
Distance to floor | 1
Distance to left/right walls | 2
Unit vector to target | 2
Clamped displacement to target | 2
Unit vector to frozen gear | 2
Clamped displacement to frozen gear | 2
\(\sin(\Delta\theta)\) | 1
\(\cos(\Delta\theta)\) | 1
Velocity (x, y) | 2
Angular velocity | 1
- Returns:
Observation vector of size 16.
- Return type:
jax.Array
- static reward(env: TwoGears) Array[source]#
Compute the reward.
The reward is based on the differential distance to the objective minus a penalty for the change in kinetic energy:
\[R_t = (d_{t-1} - d_t) - w_{\text{ke}} (K_t - K_{t-1})\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
- Returns:
Reward value for the active agent.
- Return type:
jax.Array
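Here is a worked instance with made-up values (not from a simulation) and the Create default \(w_{\text{ke}} = 0.1\). The gear moves 0.2 closer to the target but gains 0.5 units of kinetic energy, so the penalty removes a quarter of the progress term.

```latex
% Illustrative values: d_{t-1} = 2.0,\ d_t = 1.8,\ K_{t-1} = 0.3,\ K_t = 0.8,\ w_{\text{ke}} = 0.1
R_t = (2.0 - 1.8) - 0.1\,(0.8 - 0.3) = 0.2 - 0.05 = 0.15
```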
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
Modules
Environment where multiple agents navigate towards assigned targets.
Environment where multiple rolling agents navigate towards assigned targets.
Environment where a single agent navigates towards a target.
Environment where a single agent rolls towards a target on the floor.
Environment where multiple agents navigate towards nearby shared targets.
Environment where multiple rolling agents navigate towards nearby shared targets.
Multi-agent 3-D swarm rolling environment with magnetic interaction and pyramid objectives.
Multi-agent 3-D swarm stacking environment with periodic boundaries.
Multi-agent 2-D environment with three gears and a triangle objective.
Two-dimensional environment with two gears for RL training.