jaxdem.rl.environments#
Reinforcement-learning environment interface.
Classes
Environment – Defines the interface for reinforcement-learning environments.
- class jaxdem.rl.environments.Environment(state: State, system: System, env_params: dict[str, Any])#
Bases: Factory, ABC
Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.
Required shapes:
- Observation: (A, observation_space_size)
- Action (input to step()): (A, action_space_size)
- Reward: (A,)
- Done: scalar boolean for the whole environment
Todo:
- Truncated data field: per-agent termination flag
- Render method
Example:#
To define a custom environment, inherit from Environment and implement the abstract methods:
>>> @Environment.register("Environment")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomEnv(Environment):
...
- env_params: dict[str, Any]#
Environment-specific parameters.
- abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
Initialize the environment to a valid start state.
- Parameters:
env ('MyCustomEnv') – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
Conditionally reset the environment when it has reached a terminal state.
This method checks the done flag and, if it is True, calls the environment's reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).
- Return type:
Environment
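The conditional-reset pattern above can be sketched in plain Python. This is a hedged stand-in: `reset_fn` and the dict-based "environment" are hypothetical, and the real implementation expresses the branch with a JAX conditional so it stays traceable under jit.

```python
# Plain-Python sketch of the reset_if_done semantics described above.
# `reset_fn` and the dict "environment" are hypothetical stand-ins.
def reset_if_done(env, done, reset_fn, key):
    # If the episode ended, reinitialize; otherwise return env unchanged.
    return reset_fn(env, key) if done else env

def fake_reset(env, key):
    # Hypothetical reset: start a fresh episode.
    return {"step_count": 0}

running = {"step_count": 41}
unchanged = reset_if_done(running, False, fake_reset, key=None)  # same object
fresh = reset_if_done(running, True, fake_reset, key=None)       # reset state
```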
- abstractmethod static step(env: Environment, action: Array) Environment[source]#
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- abstractmethod static observation(env: Environment) Array[source]#
Returns the per-agent observation vector.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector corresponding to the environment observation.
- Return type:
jax.Array
- abstractmethod static reward(env: Environment) Array[source]#
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector corresponding to all the agent’s rewards based on the current environment state.
- Return type:
jax.Array
- abstractmethod static done(env: Environment) Array[source]#
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
A boolean indicating whether the environment has ended.
- Return type:
jax.Array
- static info(env: Environment) dict[str, Any][source]#
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override to provide environment specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
dict[str, Any]
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
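Taken together, the methods above compose into a standard rollout loop. A minimal sketch follows; `ToyEnv` is a hypothetical mutable stand-in for a real (functional, jit-friendly) Environment subclass, not a jaxdem class.

```python
# Minimal rollout sketch against the interface above; `ToyEnv` is a
# hypothetical stand-in, not a jaxdem class. A = 1 agent, 1-D actions.
class ToyEnv:
    def __init__(self):
        self.pos, self.steps, self.max_steps = 0.0, 0, 3

    def observation(self):        # shape (A, observation_space_size)
        return [[self.pos]]

    def step(self, action):       # action shape (A, action_space_size)
        self.pos += action[0][0]
        self.steps += 1
        return self

    def reward(self):             # shape (A,)
        return [-abs(self.pos - 1.0)]

    def done(self):               # scalar boolean for the whole environment
        return self.steps >= self.max_steps

env, total = ToyEnv(), 0.0
while not env.done():
    obs = env.observation()
    action = [[0.5]]              # a fixed "policy" for illustration
    env = env.step(action)
    total += env.reward()[0]
```

The real interface is functional (each call returns a new Environment pytree); the loop structure is the same.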
Bases: Environment
Multi-agent 2-D navigation with cooperative rewards.
Each agent controls a force vector applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added every step. Objectives are assigned one-to-one via a random permutation. Each agent receives a random priority scalar at reset for symmetry breaking.
Reward
\[R_i = w_s\,(e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 - \bar{r}_i\]
where \(l_j\) and \(\hat{r}_j\) are the LiDAR readings and ray directions respectively, and \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights (\(w_s, w_g, w_c, w_w, \alpha, f\)) are constructor parameters stored in env_params.
Notes
The observation vector per agent is:
Feature                          Size
Unit direction to objective      dim
Clamped displacement             dim
Velocity                         dim
Own priority                     1
LiDAR proximity (normalised)     n_lidar_rays
Radial relative velocity         n_lidar_rays
LiDAR neighbour priority         n_lidar_rays
- n_lidar_rays: int#
Number of angular bins for each LiDAR sensor.
Create a multi-agent navigator environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound for the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound for the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal.
goal_weight (float) – Bonus \(w_g\) for being on target.
crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR proximity sum.
work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).
goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).
alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
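The shaping and baseline terms in the reward above can be checked numerically. A hedged sketch of the arithmetic only (the helper names and weight values are illustrative, not jaxdem code or defaults):

```python
import math

# Potential-based shaping term w_s * (exp(-2 d) - exp(-2 d_prev)):
# positive when the agent moved closer to its objective.
def shaping(d, d_prev, w_s=1.5):
    return w_s * (math.exp(-2.0 * d) - math.exp(-2.0 * d_prev))

# EMA baseline r_bar, updated with factor alpha and subtracted from the reward.
def update_baseline(r_bar, r, alpha=0.07):
    return (1.0 - alpha) * r_bar + alpha * r

r = shaping(d=0.5, d_prev=1.0)   # agent moved from d=1.0 to d=0.5: positive
r_bar = update_baseline(0.0, r)
```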
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.
- Returns:
The environment with a fresh episode state.
- Return type:
Advance the environment by one physics step.
Applies force actions with viscous drag. After integration the method updates LiDAR sensors, displacement caches, and computes the reward with a differential baseline.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Force actions for every agent, shape (N * dim,).
- Returns:
Updated environment after physics integration, sensor updates, and reward computation.
- Return type:
Build the per-agent observation vector from cached sensors.
All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.
- Returns:
Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.
- Return type:
jax.Array
Return the reward cached by step().
- Returns:
Reward vector of shape (N,).
- Return type:
jax.Array
Return True when the episode has exceeded max_steps.
Number of scalar actions per agent (equal to dim).
Shape of a single agent's action ((dim,)).
Dimensionality of a single agent's observation vector.
- class jaxdem.rl.environments.MultiRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)#
Bases: Environment
Multi-agent 3-D rolling environment with cooperative rewards.
Each agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). Viscous drag -friction * vel and angular damping -ang_damping * ang_vel are applied every step. Objectives are assigned one-to-one via a random permutation. Each agent receives a random priority scalar at reset for symmetry breaking.
Reward
\[R_i = w_s\,(e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 - \bar{r}_i\]
where \(l_j\) and \(\hat{r}_j\) are the LiDAR readings and ray directions respectively, and \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights (\(w_s, w_g, w_c, w_w, \alpha, f\)) are constructor parameters stored in env_params.
Notes
The observation vector per agent is:
Feature                              Size
Unit direction to objective (x, y)   2
Clamped displacement (x, y)          2
Velocity (x, y)                      2
Angular velocity                     3
Own priority                         1
LiDAR proximity (normalised)         n_lidar_rays
Radial relative velocity             n_lidar_rays
LiDAR neighbour priority             n_lidar_rays
- n_lidar_rays: int#
Number of angular bins for each LiDAR sensor.
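The angular binning implied by n_lidar_rays can be sketched as follows. This is an assumption about the binning convention (uniform bins over \([-\pi, \pi)\)), not jaxdem's exact code:

```python
import math

# Map a neighbour's planar bearing to one of n_rays bins covering [-pi, pi).
def lidar_bin(dx, dy, n_rays):
    theta = math.atan2(dy, dx)                  # bearing in (-pi, pi]
    frac = (theta + math.pi) / (2.0 * math.pi)  # normalise to [0, 1]
    return min(int(frac * n_rays), n_rays - 1)  # clamp theta == pi into last bin

lidar_bin(1.0, 0.0, 8)   # bearing 0 lands in the middle bin (index 4)
```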
- classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, shaping_weight: float = 1.5, goal_weight: float = 0.001, crowding_weight: float = 0.005, work_weight: float = 0.0005, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.3, n_lidar_rays: int = 8) MultiRoller[source]#
Create a multi-agent roller environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound for the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound for the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ang_damping (float) – Angular damping coefficient applied as -ang_damping * ang_vel.
shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal.
goal_weight (float) – Bonus \(w_g\) for being on target.
crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR crowding vector norm.
work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).
goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).
alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
MultiRoller
- static reset(env: MultiRoller, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.
- Returns:
The environment with a fresh episode state.
- Return type:
Environment
- static step(env: MultiRoller, action: Array) Environment[source]#
Advance the environment by one physics step.
Applies torque actions with angular damping and viscous drag. After integration the method updates LiDAR sensors, displacement caches, and computes the reward with a differential baseline.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Torque actions for every agent, shape (N * 3,).
- Returns:
Updated environment after physics integration, sensor updates, and reward computation.
- Return type:
Environment
- static observation(env: MultiRoller) Array[source]#
Build the per-agent observation vector from cached sensors.
All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.
- Returns:
Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.
- Return type:
jax.Array
- static reward(env: MultiRoller) Array[source]#
Return the reward cached by step().
- Returns:
Reward vector of shape (N,).
- Return type:
jax.Array
- static done(env: MultiRoller) Array[source]#
Return True when the episode has exceeded max_steps.
Bases: Environment
Single-agent navigation environment toward a fixed target.
The agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. The reward uses exponential potential-based shaping:
\[\mathrm{rew} = e^{-2\,d} - e^{-2\,d^{\mathrm{prev}}}\]
Notes
The observation vector per agent is:
Feature                       Size
Unit direction to objective   dim
Clamped displacement          dim
Velocity                      dim
Create a single-agent navigator environment.
- Parameters:
dim (int) – Spatial dimensionality (2 or 3).
min_box_size (float) – Lower bound for the random square domain side length.
max_box_size (float) – Upper bound for the random square domain side length.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
work_weight (float) – Penalty coefficient for large actions.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
Initialize the environment with a randomly placed particle and velocity.
- Parameters:
env ('SingleNavigator') – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Advance one step. Actions are forces; simple drag is applied (-friction * vel).
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Build per-agent observations.
Contents per agent:
- Unit vector to objective (shape (dim,)) – direction.
- Clamped delta to objective (shape (dim,)) – local precision.
- Velocity (shape (dim,)).
- Returns:
Array of shape (N, 3 * dim).
- Return type:
jax.Array
Returns a vector of per-agent rewards.
Reward:
\[\mathrm{rew}_i = e^{-2 \cdot d_i} - e^{-2 \cdot d_i^{\mathrm{prev}}}\]
where \(d_i\) is the distance from agent \(i\) to the objective.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
Original per-agent action shape (useful for reshaping inside the environment).
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.SingleRoller3D(state: State, system: System, env_params: dict[str, Any])#
Bases: Environment
Single-agent 3-D navigation via torque-controlled rolling.
The agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). A viscous drag -friction * vel and a fixed angular damping of 0.05 * ang_vel are applied each step.
The reward uses exponential potential-based shaping:
\[\mathrm{rew} = e^{-2\,d} - e^{-2\,d^{\mathrm{prev}}}\]
Notes
The observation vector per agent is:
Feature                       Size
Unit direction to objective   2
Clamped displacement (x, y)   2
Velocity (x, y)               2
Angular velocity              3
- classmethod Create(min_box_size: float = 2.0, max_box_size: float = 2.0, max_steps: int = 1000, friction: float = 0.2, work_weight: float = 0.0) SingleRoller3D[source]#
Create a single-agent roller environment.
- Parameters:
min_box_size (float) – Lower bound for the random square domain side length.
max_box_size (float) – Upper bound for the random square domain side length.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
work_weight (float) – Penalty coefficient for large actions.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
SingleRoller3D
- static reset(env: SingleRoller3D, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
Randomly place the agent and objective on the floor.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX PRNG key.
- Returns:
Freshly initialised environment.
- Return type:
Environment
- static step(env: SingleRoller3D, action: Array) Environment[source]#
Apply a torque action, advance physics by one step.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – 3-D torque vector per agent.
- Returns:
Updated environment after one physics step.
- Return type:
Environment
- static observation(env: SingleRoller3D) Array[source]#
Per-agent observation vector.
Contents per agent:
- Unit displacement to objective projected to x-y (shape (2,)).
- Clamped displacement to objective projected to x-y (shape (2,)).
- Velocity projected to x-y (shape (2,)).
- Angular velocity (shape (3,)).
- Returns:
Shape (N, 9).
- Return type:
jax.Array
- static reward(env: SingleRoller3D) Array[source]#
Returns a vector of per-agent rewards.
Exponential potential-based shaping:
\[\mathrm{rew}_i = e^{-2 \cdot d_i} - e^{-2 \cdot d_i^{\mathrm{prev}}}\]
- Returns:
Shape (N,).
- Return type:
jax.Array
- static done(env: SingleRoller3D) Array[source]#
True when step_count exceeds max_steps.
Bases: Environment
Multi-agent 2-D swarm navigation with potential-based rewards.
Each agent controls a force vector applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added every step. Objectives are shared among all agents; each agent dynamically tracks its k nearest objectives. The potential-based shaping signal is computed independently for each of the k objectives and summed. Occupancy is determined via strict symmetry breaking: only the closest agent to each objective within the activation threshold may claim it.
Reward
\[R_i = w_s\,\sum_{j \in \text{top-}k} (e^{-2d_{ij}} - e^{-2d_{ij}^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 + w_v\,\mathbf{1}[\text{all }k\text{ occupied}] - \bar{r}_i\]
where \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights are constructor parameters stored in env_params.
Notes
The observation vector per agent is:
Feature                              Size
Velocity                             dim
LiDAR proximity                      n_lidar_rays
LiDAR radial relative velocity       n_lidar_rays
LiDAR objective proximity            n_lidar_rays
Unit direction to top k objectives   k_objectives * dim
Clamped displacement to top k        k_objectives * dim
Occupancy status of top k            k_objectives
- n_lidar_rays: int#
Number of angular bins for the agent-to-agent LiDAR sensor.
- k_objectives: int#
Number of closest objectives tracked per agent.
Create a swarm navigator environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound for the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound for the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal summed over the k nearest objectives.
goal_weight (float) – Bonus \(w_g\) for uniquely claiming a target.
crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR crowding vector norm.
work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).
vacancy_weight (float) – Reward \(w_v\) granted when all k nearest objectives are occupied.
goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).
alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
k_objectives (int) – Number of closest objectives tracked per agent.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
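The strict symmetry-breaking occupancy rule described above can be sketched per objective. A hedged, hypothetical helper (not a jaxdem function): only the single closest agent within the activation threshold claims the objective.

```python
# For one objective: the closest agent claims it, but only if it is within
# the activation threshold d < f * r; otherwise the objective is unclaimed.
def claiming_agent(dists, threshold):
    best = min(range(len(dists)), key=dists.__getitem__)
    return best if dists[best] < threshold else None

claiming_agent([0.9, 0.2, 0.5], threshold=0.3)   # agent 1 claims it
claiming_agent([0.9, 0.4, 0.5], threshold=0.3)   # nobody is close enough
```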
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.
- Returns:
The environment with a fresh episode state.
- Return type:
Advance the environment by one physics step.
Applies force actions with viscous drag. After integration the method updates all sensor caches and computes the reward with a differential baseline. The shaping signal is summed over the k nearest objectives.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Force actions for every agent, shape (N * dim,).
- Returns:
Updated environment after physics integration, sensor updates, and reward computation.
- Return type:
Build the per-agent observation vector from cached sensors.
All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.
- Returns:
Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.
- Return type:
jax.Array
Return the reward cached by step().
- Returns:
Reward vector of shape (N,).
- Return type:
jax.Array
Return True when the episode has exceeded max_steps.
Number of scalar actions per agent (equal to dim).
Shape of a single agent's action ((dim,)).
Dimensionality of a single agent's observation vector.
- class jaxdem.rl.environments.SwarmRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, k_objectives: int)#
Bases: Environment
Multi-agent 3-D rolling environment with potential-based rewards.
Each agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). Viscous drag -friction * vel and angular damping -ang_damping * ang_vel are applied every step.
Objectives are shared among all agents; each agent dynamically tracks its k nearest objectives. The potential-based shaping signal is computed independently for each of the k objectives and summed. Occupancy is determined via strict symmetry breaking: only the closest agent to each objective within the activation threshold may claim it.
Reward
\[R_i = w_s\,\sum_{j \in \text{top-}k} (e^{-2d_{ij}} - e^{-2d_{ij}^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 + w_v\,\mathbf{1}[\text{all }k\text{ occupied}] - \bar{r}_i\]
where \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights are constructor parameters stored in env_params.
Notes
The observation vector per agent is:
Feature                              Size
Velocity (x, y)                      2
Angular velocity                     3
LiDAR proximity                      n_lidar_rays
LiDAR radial relative velocity       n_lidar_rays
LiDAR objective proximity            n_lidar_rays
Unit direction to top k objectives   k_objectives * 2
Clamped displacement to top k        k_objectives * 2
Occupancy status of top k            k_objectives
- n_lidar_rays: int#
Number of angular bins for the agent-to-agent LiDAR sensor.
- k_objectives: int#
Number of closest objectives tracked per agent.
- classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 20.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, shaping_weight: float = 1.0, goal_weight: float = 0.001, crowding_weight: float = 0.005, work_weight: float = 0.0005, vacancy_weight: float = 0.005, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.4, n_lidar_rays: int = 8, k_objectives: int = 5) SwarmRoller[source]#
Create a swarm roller environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound for the random square domain side length sampled at each reset().
max_box_size (float) – Upper bound for the random square domain side length sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ang_damping (float) – Angular damping coefficient applied as -ang_damping * ang_vel.
shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal summed over the k nearest objectives.
goal_weight (float) – Bonus \(w_g\) for uniquely claiming a target.
crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR crowding vector norm.
work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).
vacancy_weight (float) – Reward \(w_v\) granted when all k nearest objectives are occupied.
goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).
alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
k_objectives (int) – Number of closest objectives tracked per agent.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
SwarmRoller
- static reset(env: SwarmRoller, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.
- Returns:
The environment with a fresh episode state.
- Return type:
Environment
- static step(env: SwarmRoller, action: Array) Environment[source]#
Advance the environment by one physics step.
Applies torque actions with angular damping and viscous drag. After integration the method updates all sensor caches and computes the reward with a differential baseline. The shaping signal is summed over the k nearest objectives.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Torque actions for every agent, shape (N * 3,).
- Returns:
Updated environment after physics integration, sensor updates, and reward computation.
- Return type:
Environment
- static observation(env: SwarmRoller) Array[source]#
Build the per-agent observation vector from cached sensors.
All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.
- Returns:
Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.
- Return type:
jax.Array
- static reward(env: SwarmRoller) Array[source]#
Return the reward cached by step().
- Returns:
Reward vector of shape (N,).
- Return type:
jax.Array
- static done(env: SwarmRoller) Array[source]#
Return True when the episode has exceeded max_steps.
- class jaxdem.rl.environments.SwarmRoller3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int, k_objectives: int, n_objectives: int)#
Bases: Environment
Multi-agent 3-D rolling environment with magnetic interaction and pyramid objectives. Extends the swarm roller with two additions:
Magnet action. Each agent has an extra binary magnet action. When two nearby agents both activate their magnets, the mutual attraction is twice as strong:
\[\mathbf{F}_{ij}^{\text{mag}} = -w_{\text{mag}} \, (m_i + m_j) \, \max\!\bigl(0,\; 1 - d/r_{\text{mag}}\bigr) \, \hat{n}_{ij}\]
where \(m_i \in \{0, 1\}\) is the magnet flag for agent \(i\), \(d = \|r_{ij}\|\), and \(r_{\text{mag}}\) is magnet_range.
Pyramid objectives. Objectives are arranged in a pyramid: a base layer on the floor and elevated apex targets. Agents must stack on top of one another to reach elevated targets. Occupancy uses the full 3-D distance to prevent false apex claims.
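The magnet force magnitude can be checked numerically. A hedged sketch of the formula above; the helper is hypothetical, with the default-looking values (40.0, 0.12) taken from the Create signature below:

```python
# Magnitude of the pairwise magnet attraction: linear falloff to zero at
# r_mag, doubled when both agents switch their magnets on (m_i, m_j in {0, 1}).
def magnet_force(m_i, m_j, d, w_mag=40.0, r_mag=0.12):
    return w_mag * (m_i + m_j) * max(0.0, 1.0 - d / r_mag)

magnet_force(1, 1, 0.06)   # both on at half range: 40 * 2 * 0.5 = 40.0
magnet_force(1, 0, 0.20)   # beyond magnet_range: force is 0.0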
Reward
\[R_i = w_s\,\sum_{j \in \text{top-}k} (e^{-2d_{ij}} - e^{-2d_{ij}^{\mathrm{prev}}}) + w_{th}\,\frac{1}{N}\sum_{m=1}^{N} z_m + w_g\,\mathbf{1}[\text{on target}] - w_w\,\|a_i\|^2 - w_{\mathrm{vel}}\,\|v_i\|^2 - \bar{r}_i\]
where \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\), \(w_{th}\) scales the reward for the average team height, \(w_g\) is the bonus for being on a target, and \(w_{\mathrm{vel}}\) penalises high agent velocity. All weights are constructor parameters stored in env_params.
Notes
The observation vector per agent is:
Feature                              Size
Velocity (x, y, z)                   3
Angular velocity                     3
Magnet flag                          1
LiDAR proximity (normalised)         n_lidar_rays * n_lidar_elevation
Radial relative velocity             n_lidar_rays * n_lidar_elevation
Objective LiDAR proximity            n_lidar_rays * n_lidar_elevation
Unit direction to top k objectives   k_objectives * 3
Clamped displacement to top k        k_objectives * 3
Occupancy status of top k            k_objectives
- n_lidar_rays: int#
Number of azimuthal bins for the 3-D LiDAR sensor.
- n_lidar_elevation: int#
Number of elevation bins for the 3-D LiDAR sensor.
- k_objectives: int#
Number of closest objectives tracked per agent.
- n_objectives: int#
Number of shared objectives.
- classmethod Create(N: int = 5, n_objectives: int = 5, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 0.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, shaping_weight: float = 2.0, team_height_weight: float = 1.0, goal_weight: float = 0.0, work_weight: float = 0.0, velocity_weight: float = 0.018, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.4, n_lidar_rays: int = 6, n_lidar_elevation: int = 6, k_objectives: int = 4, magnet_strength: float = 40.0, magnet_range: float = 0.12) SwarmRoller3D[source]#
Create a swarm roller 3-D environment.
- Parameters:
N (int) – Number of agents.
n_objectives (int) – Number of shared objectives.
min_box_size (float) – Range for the random square domain side length sampled at each
reset().max_box_size (float) – Range for the random square domain side length sampled at each
reset().box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as
-friction * vel.ang_damping (float) – Angular damping coefficient applied as
-ang_damping * ang_vel.shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal summed over the k nearest objectives.
team_height_weight (float) – Weight \(w_{th}\) scaling the average z-height of the swarm as a global reward.
goal_weight (float) – Bonus \(w_g\) for being positioned on a target.
work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).
velocity_weight (float) – Penalty \(w_{\mathrm{vel}}\) on the squared velocity magnitude \(\|v_i\|^2\).
goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).
alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of azimuthal LiDAR bins spanning \([-\pi, \pi)\).
n_lidar_elevation (int) – Number of elevation LiDAR bins spanning \([-\pi/2, \pi/2]\).
k_objectives (int) – Number of closest objectives tracked per agent.
magnet_strength (float) – Magnitude of the magnetic attraction force.
magnet_range (float) – Maximum range for magnetic interaction (beyond this the force is zero).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
SwarmRoller3D
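The differential baseline controlled by alpha_r_bar follows the standard exponential-moving-average update. A minimal sketch of the idea (written in NumPy for clarity; the exact update used inside the environment is an assumption here):

```python
import numpy as np

def differential_reward(raw_reward, r_bar, alpha=0.07):
    """Subtract an EMA baseline r_bar from the raw reward.

    Assumed update rule (standard EMA; the environment's exact form
    may differ): r_bar <- (1 - alpha) * r_bar + alpha * raw_reward.
    """
    new_r_bar = (1.0 - alpha) * r_bar + alpha * raw_reward
    return raw_reward - new_r_bar, new_r_bar

# With a constant raw reward, the baseline converges to it and the
# differential reward decays toward zero.
rewards = np.array([1.0, 1.0, 1.0])
r_bar = np.zeros(3)
for _ in range(100):
    shaped, r_bar = differential_reward(rewards, r_bar)
```

This is why a constant, always-attainable reward contributes nothing at convergence: only *changes* relative to the running baseline are rewarded.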
- static reset(env: SwarmRoller3D, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.
- Returns:
The environment with a fresh episode state.
- Return type:
Environment
- static step(env: SwarmRoller3D, action: Array) Environment[source]#
Advance the environment by one physics step.
Applies torque actions with angular damping, viscous drag, and pairwise magnetic attraction. After integration the method updates all sensor caches and computes the reward with a differential baseline. The shaping signal is summed over the k nearest objectives.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Actions for every agent, shape (N * 4,) (3-D torque + magnet flag).
- Returns:
Updated environment after physics integration, sensor updates, and reward computation.
- Return type:
Environment
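The pairwise magnetic attraction with a hard range cutoff can be sketched as follows. This is a simplified, assumption-laden version (constant-magnitude attraction along the separation direction, zero beyond magnet_range; the environment's actual force law may differ), written in NumPy for clarity:

```python
import numpy as np

def magnet_forces(pos, magnet_strength=40.0, magnet_range=0.12):
    """Constant-magnitude pairwise attraction within magnet_range.

    pos: (N, 3) particle positions. Returns (N, 3) net forces.
    """
    rij = pos[None, :, :] - pos[:, None, :]   # displacement i -> j
    dist = np.linalg.norm(rij, axis=-1)       # (N, N) pair distances
    np.fill_diagonal(dist, np.inf)            # exclude self-interaction
    in_range = dist < magnet_range            # hard cutoff mask
    unit = rij / dist[..., None]              # unit vectors (diag -> 0)
    return magnet_strength * np.sum(in_range[..., None] * unit, axis=1)

pos = np.array([[0.0, 0.0, 0.0],
                [0.1, 0.0, 0.0],
                [1.0, 0.0, 0.0]])
F = magnet_forces(pos)
# Particles 0 and 1 attract each other; particle 2 is out of range.
```

In the environment the per-agent magnet flag (the fourth action component) would additionally gate this force on or off.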
- static observation(env: SwarmRoller3D) Array[source]#
Build the per-agent observation vector from cached sensors. All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.
- Returns:
Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.
- Return type:
jax.Array
- static reward(env: SwarmRoller3D) Array[source]#
Return the reward cached by step().
- Returns:
Reward vector of shape (N,).
- Return type:
jax.Array
- static done(env: SwarmRoller3D) Array[source]#
Return True when the episode has exceeded max_steps.
- class jaxdem.rl.environments.SwarmStacking3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int)#
Bases:
Environment
Multi-agent 3-D stacking environment with periodic boundaries.
Agents must stack on top of one another to reach as high as possible.
Reward
\[R_i = w_{climb} (0.8 \cdot z_i + 0.2 \cdot \bar{z}_t) + w_{cohesion} \sum \text{lidar} - w_w\,\|\tau_i\|^2 - w_{\mathrm{vel}}\,\|v_i\|^2 - \bar{r}_i\]
where \(\bar{z}_t\) is the average height of the swarm.
Boundary Conditions:
- Periodic in X and Y.
- Frictional floor at Z=0.
- Effectively unbounded Z (large box size).
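The reward formula above can be evaluated term by term. A sketch of the per-agent computation (the LiDAR cohesion sum and the EMA baseline \(\bar{r}_i\) are taken as given inputs here, since they are cached environment state):

```python
import numpy as np

def stacking_reward(z, lidar_sum, torque, vel, r_bar,
                    w_climb=20.0, w_cohesion=0.05, w_work=0.0, w_vel=2.0):
    """Per-agent reward for SwarmStacking3D, following the docstring formula.

    z: (N,) agent heights; lidar_sum: (N,) summed LiDAR returns;
    torque: (N, 3) actions; vel: (N, 3) velocities; r_bar: (N,) EMA baseline.
    """
    climb = w_climb * (0.8 * z + 0.2 * z.mean())   # individual + team height
    cohesion = w_cohesion * lidar_sum              # stay near neighbours
    work = w_work * np.sum(torque**2, axis=-1)     # action penalty
    speed = w_vel * np.sum(vel**2, axis=-1)        # velocity penalty
    return climb + cohesion - work - speed - r_bar

z = np.array([0.0, 1.0])
zeros3 = np.zeros((2, 3))
R = stacking_reward(z, np.zeros(2), zeros3, zeros3, np.zeros(2))
# The higher agent earns more, but both share the team-height term.
```

Note how the 0.8/0.2 split means every agent benefits from the swarm's average height, which rewards supporting agents at the bottom of the stack, not only the climber on top.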
- n_lidar_rays: int#
Number of azimuthal bins for the 3-D LiDAR sensor.
- n_lidar_elevation: int#
Number of elevation bins for the 3-D LiDAR sensor.
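The azimuth/elevation binning implied by these two fields can be sketched as follows: each neighbor direction maps to one (azimuth, elevation) cell, with azimuth spanning \([-\pi, \pi)\) and elevation \([-\pi/2, \pi/2]\). The bin-edge conventions here are assumptions; the actual sensor presumably also records the nearest distance per cell up to lidar_range.

```python
import numpy as np

def lidar_bins(rel, n_rays=8, n_elev=8):
    """Map relative neighbor positions to (azimuth, elevation) bin indices.

    rel: (M, 3) neighbor positions relative to the sensing agent.
    Azimuth bins span [-pi, pi); elevation bins span [-pi/2, pi/2].
    """
    az = np.arctan2(rel[:, 1], rel[:, 0])                    # [-pi, pi)
    el = np.arcsin(rel[:, 2] / np.linalg.norm(rel, axis=1))  # [-pi/2, pi/2]
    az_idx = np.floor((az + np.pi) / (2 * np.pi) * n_rays).astype(int)
    el_idx = np.floor((el + np.pi / 2) / np.pi * n_elev).astype(int)
    return np.clip(az_idx, 0, n_rays - 1), np.clip(el_idx, 0, n_elev - 1)

rel = np.array([[1.0, 0.0, 0.0],   # straight ahead, level
                [0.0, 0.0, 1.0]])  # directly above
az_idx, el_idx = lidar_bins(rel)
```

A neighbor straight ahead lands in the middle azimuth/elevation cells; one directly overhead saturates the top elevation bin.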
- classmethod Create(N: int = 16, min_box_size: float = 0.5, max_box_size: float = 0.5, box_padding: float = 0.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, climb_weight: float = 20.0, cohesion_weight: float = 0.05, work_weight: float = 0.0, velocity_weight: float = 2.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.5, n_lidar_rays: int = 8, n_lidar_elevation: int = 8, magnet_strength: float = 40.0, magnet_range: float = 0.12) SwarmStacking3D[source]#
Create a swarm stacking 3-D environment.
- static reset(env: SwarmStacking3D, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source]#
- static step(env: SwarmStacking3D, action: Array) Environment[source]#
- static observation(env: SwarmStacking3D) Array[source]#
- static reward(env: SwarmStacking3D) Array[source]#
- static done(env: SwarmStacking3D) Array[source]#
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
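A common pattern inside step() is to unflatten the per-agent action using action_space_shape. A generic sketch with hypothetical shapes (the (4,) shape mirrors the 3-D torque + magnet flag layout described above, but is not tied to a specific environment):

```python
import numpy as np

A = 5                    # number of agents
action_shape = (4,)      # hypothetical per-agent action_space_shape

# Policies emit flat actions of shape (A, action_space_size) ...
flat = np.zeros((A, int(np.prod(action_shape))))

# ... and the environment reshapes them back to the per-agent layout.
actions = flat.reshape((A, *action_shape))
```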
Modules
- Multi-agent 2-D navigation with collision avoidance and cooperative rewards.
- Multi-agent 3-D rolling environment with LiDAR sensing.
- Environment where a single agent navigates towards a target.
- Environment where a single agent rolls towards a target on the floor.
- Multi-agent 2-D swarm navigation with potential-based rewards.
- Multi-agent 3-D swarm rolling environment with potential-based rewards.
- Multi-agent 3-D swarm rolling environment with magnetic interaction and pyramid objectives.
- Multi-agent 3-D swarm stacking environment with periodic boundaries.