jaxdem.rl.environments#

Reinforcement-learning environment interface.

Classes

Environment(state, system, env_params)

Defines the interface for reinforcement-learning environments.

class jaxdem.rl.environments.Environment(state: State, system: System, env_params: dict[str, Any])#

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment
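These conventions can be illustrated with a small shape sketch (the agent count, the sizes, and the per-agent action shape `(2, 3)` below are arbitrary illustrations, not taken from any particular environment):

```python
import jax.numpy as jnp

A = 4                          # number of agents
action_space_shape = (2, 3)    # hypothetical per-agent action tensor shape
action_space_size = 2 * 3      # its flattened size
observation_space_size = 7     # hypothetical flattened observation size

obs = jnp.zeros((A, observation_space_size))  # observation: (A, observation_space_size)
action = jnp.zeros((A, action_space_size))    # action passed to step(): (A, action_space_size)
reward = jnp.zeros((A,))                      # reward: (A,)
done = jnp.asarray(False)                     # done: scalar boolean for the whole environment

# Inside the environment, flattened actions can be reshaped back:
action_tensors = action.reshape((A, *action_space_shape))
print(action_tensors.shape)  # (4, 2, 3)
```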

Todo:

  • Truncated data field: per-agent termination flag

  • Render method

Example:#

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomEnv(Environment):
...     ...
state: State#

Simulation state.

system: System#

Simulation system configuration.

env_params: dict[str, Any]#

Environment-specific parameters.

abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Conditionally resets the environment if the environment has reached a terminal state.

This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
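A minimal sketch of this pattern with jax.lax.cond, using a toy scalar state in place of the full Environment pytree (the real method resets the whole environment, not a single value):

```python
import jax
import jax.numpy as jnp

def reset(env, key):
    # Stand-in reset: a real environment would reinitialize its state using `key`.
    del key
    return jnp.zeros_like(env)

def reset_if_done(env, done, key):
    # jax.lax.cond keeps this jit-compatible: both branches are traced,
    # but only the selected one is executed.
    return jax.lax.cond(done, lambda e: reset(e, key), lambda e: e, env)

key = jax.random.PRNGKey(0)
env = jnp.asarray(7.0)
print(reset_if_done(env, jnp.asarray(True), key))   # reset branch -> 0.0
print(reset_if_done(env, jnp.asarray(False), key))  # unchanged    -> 7.0
```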

abstractmethod static step(env: Environment, action: Array) Environment[source]#

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

abstractmethod static observation(env: Environment) Array[source]#

Returns the per-agent observation vector.

Parameters:

env (Environment) – The current environment.

Returns:

Vector corresponding to the environment observation.

Return type:

jax.Array

abstractmethod static reward(env: Environment) Array[source]#

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Vector corresponding to all the agent’s rewards based on the current environment state.

Return type:

jax.Array

abstractmethod static done(env: Environment) Array[source]#

Returns a boolean indicating whether the environment has ended.

Parameters:

env (Environment) – The current environment.

Returns:

A boolean indicating whether the environment has ended.

Return type:

jax.Array

static info(env: Environment) dict[str, Any][source]#

Return auxiliary diagnostic information.

By default, returns an empty dict. Subclasses may override this to provide environment-specific information.

Parameters:

env (Environment) – The current state of the environment.

Returns:

A dictionary with additional information about the environment.

Return type:

dict[str, Any]

property num_envs: int[source]#

Number of batched environments.

property max_num_agents: int[source]#

Maximum number of active agents in the environment.

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.MultiNavigator(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)#

Bases: Environment

Multi-agent 2-D navigation with cooperative rewards.

Each agent controls a force vector applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added every step. Objectives are assigned one-to-one via a random permutation. Each agent receives a random priority scalar at reset for symmetry breaking.

Reward

\[R_i = w_s\,(e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 - \bar{r}_i\]

where \(l_j\) and \(\hat{r}_j\) are the LiDAR readings and ray directions respectively, and \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights (\(w_s, w_g, w_c, w_w, \alpha, f\)) are constructor parameters stored in env_params.
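The shaping and baseline terms can be sketched in isolation (scalar distances and illustrative weight values; the crowding and work penalties are omitted, and the exact EMA update order is an assumption):

```python
import math

w_s, alpha = 1.5, 0.07  # illustrative shaping weight and EMA factor

def shaping(d, d_prev):
    # Potential-based term: positive when the agent moved closer to its objective.
    return w_s * (math.exp(-2.0 * d) - math.exp(-2.0 * d_prev))

r_bar = 0.0  # EMA baseline \bar{r}
for d_prev, d in [(1.0, 0.8), (0.8, 0.7), (0.7, 0.75)]:
    raw = shaping(d, d_prev)
    reward = raw - r_bar                         # differential reward handed to the agent
    r_bar = (1.0 - alpha) * r_bar + alpha * raw  # assumed EMA update
    print(round(reward, 4))
```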

Notes

The observation vector per agent is:

  • Unit direction to objective: dim

  • Clamped displacement: dim

  • Velocity: dim

  • Own priority: 1

  • LiDAR proximity (normalised): n_lidar_rays

  • Radial relative velocity: n_lidar_rays

  • LiDAR neighbour priority: n_lidar_rays

n_lidar_rays: int#

Number of angular bins for each LiDAR sensor.

classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, friction: float = 0.2, shaping_weight: float = 1.5, goal_weight: float = 0.001, crowding_weight: float = 0.005, work_weight: float = 0.0005, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.3, n_lidar_rays: int = 8) MultiNavigator[source]#

Create a multi-agent navigator environment.

Parameters:
  • N (int) – Number of agents.

  • min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().

  • max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal.

  • goal_weight (float) – Bonus \(w_g\) for being on target.

  • crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR proximity sum.

  • work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).

  • goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).

  • alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

MultiNavigator

static reset(env: MultiNavigator, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: MultiNavigator, action: Array) Environment[source]#

Advance the environment by one physics step.

Applies force actions with viscous drag. After integration the method updates LiDAR sensors, displacement caches, and computes the reward with a differential baseline.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Force actions for every agent, shape (N * dim,).

Returns:

Updated environment after physics integration, sensor updates, and reward computation.

Return type:

Environment

static observation(env: MultiNavigator) Array[source]#

Build the per-agent observation vector from cached sensors.

All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.

Returns:

Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.

Return type:

jax.Array

static reward(env: MultiNavigator) Array[source]#

Return the reward cached by step().

Returns:

Reward vector of shape (N,).

Return type:

jax.Array

static done(env: MultiNavigator) Array[source]#

Return True when the episode has exceeded max_steps.

property action_space_size: int[source]#

Number of scalar actions per agent (equal to dim).

property action_space_shape: tuple[int][source]#

Shape of a single agent’s action ((dim,)).

property observation_space_size: int[source]#

Dimensionality of a single agent’s observation vector.

class jaxdem.rl.environments.MultiRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)#

Bases: Environment

Multi-agent 3-D rolling environment with cooperative rewards.

Each agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). Viscous drag -friction * vel and angular damping -ang_damping * ang_vel are applied every step. Objectives are assigned one-to-one via a random permutation. Each agent receives a random priority scalar at reset for symmetry breaking.

Reward

\[R_i = w_s\,(e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 - \bar{r}_i\]

where \(l_j\) and \(\hat{r}_j\) are the LiDAR readings and ray directions respectively, and \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights (\(w_s, w_g, w_c, w_w, \alpha, f\)) are constructor parameters stored in env_params.

Notes

The observation vector per agent is:

  • Unit direction to objective (x, y): 2

  • Clamped displacement (x, y): 2

  • Velocity (x, y): 2

  • Angular velocity: 3

  • Own priority: 1

  • LiDAR proximity (normalised): n_lidar_rays

  • Radial relative velocity: n_lidar_rays

  • LiDAR neighbour priority: n_lidar_rays

n_lidar_rays: int#

Number of angular bins for each LiDAR sensor.

classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, shaping_weight: float = 1.5, goal_weight: float = 0.001, crowding_weight: float = 0.005, work_weight: float = 0.0005, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.3, n_lidar_rays: int = 8) MultiRoller[source]#

Create a multi-agent roller environment.

Parameters:
  • N (int) – Number of agents.

  • min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().

  • max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ang_damping (float) – Angular damping coefficient applied as -ang_damping * ang_vel.

  • shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal.

  • goal_weight (float) – Bonus \(w_g\) for being on target.

  • crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR crowding vector norm.

  • work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).

  • goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).

  • alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

MultiRoller

static reset(env: MultiRoller, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: MultiRoller, action: Array) Environment[source]#

Advance the environment by one physics step.

Applies torque actions with angular damping and viscous drag. After integration the method updates LiDAR sensors, displacement caches, and computes the reward with a differential baseline.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Torque actions for every agent, shape (N * 3,).

Returns:

Updated environment after physics integration, sensor updates, and reward computation.

Return type:

Environment

static observation(env: MultiRoller) Array[source]#

Build the per-agent observation vector from cached sensors.

All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.

Returns:

Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.

Return type:

jax.Array

static reward(env: MultiRoller) Array[source]#

Return the reward cached by step().

Returns:

Reward vector of shape (N,).

Return type:

jax.Array

static done(env: MultiRoller) Array[source]#

Return True when the episode has exceeded max_steps.

property action_space_size: int[source]#

Number of scalar actions per agent (3-D torque).

property action_space_shape: tuple[int][source]#

Shape of a single agent’s action ((3,)).

property observation_space_size: int[source]#

Dimensionality of a single agent’s observation vector.

class jaxdem.rl.environments.SingleNavigator(state: State, system: System, env_params: dict[str, Any])#

Bases: Environment

Single-agent navigation environment toward a fixed target.

The agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. The reward uses exponential potential-based shaping:

\[\mathrm{rew} = e^{-2\,d} - e^{-2\,d^{\mathrm{prev}}}\]

Notes

The observation vector per agent is:

  • Unit direction to objective: dim

  • Clamped displacement: dim

  • Velocity: dim

classmethod Create(dim: int = 2, min_box_size: float = 2.0, max_box_size: float = 2.0, max_steps: int = 1000, friction: float = 0.2, work_weight: float = 0.0001) SingleNavigator[source]#

Create a single-agent navigator environment.

Parameters:
  • dim (int) – Spatial dimensionality (2 or 3).

  • min_box_size (float) – Lower bound of the random square domain side length.

  • max_box_size (float) – Upper bound of the random square domain side length.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • work_weight (float) – Penalty coefficient for large actions.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SingleNavigator

static reset(env: SingleNavigator, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Initialize the environment with a randomly placed particle and velocity.

Parameters:
  • env (SingleNavigator) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static step(env: SingleNavigator, action: Array) Environment[source]#

Advance one step. Actions are forces; simple drag is applied (-friction * vel).

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

static observation(env: SingleNavigator) Array[source]#

Build per-agent observations.

Contents per agent#

  • Unit vector to objective (shape (dim,)) → Direction

  • Clamped delta to objective (shape (dim,)) → Local precision

  • Velocity (shape (dim,))

Returns:

Array of shape (N, 3 * dim)

Return type:

jax.Array

static reward(env: SingleNavigator) Array[source]#

Returns a vector of per-agent rewards.

Reward:

\[\mathrm{rew}_i = e^{-2 \cdot d_i} - e^{-2 \cdot d_i^{\mathrm{prev}}}\]

where \(d_i\) is the distance from agent \(i\) to the objective.

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: SingleNavigator) Array[source]#

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.SingleRoller3D(state: State, system: System, env_params: dict[str, Any])#

Bases: Environment

Single-agent 3D navigation via torque-controlled rolling.

The agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). A viscous drag -friction * vel and a fixed angular damping of 0.05 * ang_vel are applied each step.

The reward uses exponential potential-based shaping:

\[\mathrm{rew} = e^{-2\,d} - e^{-2\,d^{\mathrm{prev}}}\]

Notes

The observation vector per agent is:

  • Unit direction to objective: 2

  • Clamped displacement (x, y): 2

  • Velocity (x, y): 2

  • Angular velocity: 3

classmethod Create(min_box_size: float = 2.0, max_box_size: float = 2.0, max_steps: int = 1000, friction: float = 0.2, work_weight: float = 0.0) SingleRoller3D[source]#

Create a single-agent roller environment.

Parameters:
  • min_box_size (float) – Lower bound of the random square domain side length.

  • max_box_size (float) – Upper bound of the random square domain side length.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • work_weight (float) – Penalty coefficient for large actions.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SingleRoller3D

static reset(env: SingleRoller3D, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Randomly place the agent and objective on the floor.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX PRNG key.

Returns:

Freshly initialised environment.

Return type:

Environment

static step(env: SingleRoller3D, action: Array) Environment[source]#

Apply a torque action, advance physics by one step.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – 3-D torque vector per agent.

Returns:

Updated environment after one physics step.

Return type:

Environment

static observation(env: SingleRoller3D) Array[source]#

Per-agent observation vector.

Contents per agent:

  • Unit displacement to objective projected to x-y (shape (2,)).

  • Clamped displacement to objective projected to x-y (shape (2,)).

  • Velocity projected to x-y (shape (2,)).

  • Angular velocity (shape (3,)).

Returns:

Shape (N, 9).

Return type:

jax.Array

static reward(env: SingleRoller3D) Array[source]#

Returns a vector of per-agent rewards.

Exponential potential-based shaping:

\[\mathrm{rew}_i = e^{-2 \cdot d_i} - e^{-2 \cdot d_i^{\mathrm{prev}}}\]
Returns:

Shape (N,).

Return type:

jax.Array

static done(env: SingleRoller3D) Array[source]#

True when step_count exceeds max_steps.

property action_space_size: int[source]#

Per-agent flattened action dimensionality (3-D torque).

property action_space_shape: tuple[int][source]#

Per-agent action tensor shape.

property observation_space_size: int[source]#

Per-agent flattened observation dimensionality (9).

class jaxdem.rl.environments.SwarmNavigator(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, k_objectives: int)#

Bases: Environment

Multi-agent 2-D swarm navigation with potential-based rewards.

Each agent controls a force vector applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added every step. Objectives are shared among all agents; each agent dynamically tracks its k nearest objectives. The potential-based shaping signal is computed independently for each of the k objectives and summed. Occupancy is determined via strict symmetry breaking: only the closest agent to each objective within the activation threshold may claim it.
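The strict symmetry-breaking occupancy rule can be sketched in plain Python (hypothetical positions and threshold; the real implementation is vectorised over JAX arrays):

```python
def claim_objectives(agent_pos, objective_pos, threshold):
    # Each objective is claimed by at most one agent: its single closest agent,
    # and only if that agent lies inside the activation threshold.
    claims = {}
    for j, obj in enumerate(objective_pos):
        dists = [sum((a - o) ** 2 for a, o in zip(agent, obj)) ** 0.5
                 for agent in agent_pos]
        closest = min(range(len(agent_pos)), key=dists.__getitem__)
        if dists[closest] < threshold:
            claims[j] = closest
    return claims

agents = [(0.0, 0.0), (0.05, 0.0), (2.0, 2.0)]
objectives = [(0.0, 0.02), (2.0, 2.1)]
print(claim_objectives(agents, objectives, threshold=0.2))  # {0: 0, 1: 2}
```

Agent 1 is near objective 0 but does not claim it, because agent 0 is strictly closer; an objective with no agent inside the threshold stays unclaimed.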

Reward

\[R_i = w_s\,\sum_{j \in \text{top-}k} (e^{-2d_{ij}} - e^{-2d_{ij}^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 + w_v\,\mathbf{1}[\text{all }k\text{ occupied}] - \bar{r}_i\]

where \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights are constructor parameters stored in env_params.

Notes

The observation vector per agent is:

  • Velocity: dim

  • LiDAR proximity: n_lidar_rays

  • LiDAR radial relative velocity: n_lidar_rays

  • LiDAR objective proximity: n_lidar_rays

  • Unit direction to top k objectives: k_objectives * dim

  • Clamped displacement to top k: k_objectives * dim

  • Occupancy status of top k: k_objectives

n_lidar_rays: int#

Number of angular bins for the agent-to-agent LiDAR sensor.

k_objectives: int#

Number of closest objectives tracked per agent.

classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 20.0, max_steps: int = 5760, friction: float = 0.2, shaping_weight: float = 2.0, goal_weight: float = 0.001, crowding_weight: float = 0.005, work_weight: float = 0.0005, vacancy_weight: float = 0.005, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.4, n_lidar_rays: int = 8, k_objectives: int = 5) SwarmNavigator[source]#

Create a swarm navigator environment.

Parameters:
  • N (int) – Number of agents.

  • min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().

  • max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal summed over the k nearest objectives.

  • goal_weight (float) – Bonus \(w_g\) for uniquely claiming a target.

  • crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR crowding vector norm.

  • work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).

  • vacancy_weight (float) – Reward \(w_v\) granted when all k nearest objectives are occupied.

  • goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).

  • alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

  • k_objectives (int) – Number of closest objectives tracked per agent.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SwarmNavigator

static reset(env: SwarmNavigator, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: SwarmNavigator, action: Array) Environment[source]#

Advance the environment by one physics step.

Applies force actions with viscous drag. After integration the method updates all sensor caches and computes the reward with a differential baseline. The shaping signal is summed over the k nearest objectives.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Force actions for every agent, shape (N * dim,).

Returns:

Updated environment after physics integration, sensor updates, and reward computation.

Return type:

Environment

static observation(env: SwarmNavigator) Array[source]#

Build the per-agent observation vector from cached sensors.

All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.

Returns:

Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.

Return type:

jax.Array

static reward(env: SwarmNavigator) Array[source]#

Return the reward cached by step().

Returns:

Reward vector of shape (N,).

Return type:

jax.Array

static done(env: SwarmNavigator) Array[source]#

Return True when the episode has exceeded max_steps.

property action_space_size: int[source]#

Number of scalar actions per agent (equal to dim).

property action_space_shape: tuple[int][source]#

Shape of a single agent’s action ((dim,)).

property observation_space_size: int[source]#

Dimensionality of a single agent’s observation vector.

class jaxdem.rl.environments.SwarmRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, k_objectives: int)#

Bases: Environment

Multi-agent 3-D rolling environment with potential-based rewards.

Each agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). Viscous drag -friction * vel and angular damping -ang_damping * ang_vel are applied every step.

Objectives are shared among all agents; each agent dynamically tracks its k nearest objectives. The potential-based shaping signal is computed independently for each of the k objectives and summed. Occupancy is determined via strict symmetry breaking: only the closest agent to each objective within the activation threshold may claim it.

Reward

\[R_i = w_s\,\sum_{j \in \text{top-}k} (e^{-2d_{ij}} - e^{-2d_{ij}^{\mathrm{prev}}}) + w_g\,\mathbf{1}[d_i < f \cdot r_i] - w_c\,\left\|\sum_j l_j\,\hat{r}_j\right\| - w_w\,\|a_i\|^2 + w_v\,\mathbf{1}[\text{all }k\text{ occupied}] - \bar{r}_i\]

where \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\). All weights are constructor parameters stored in env_params.

Notes

The observation vector per agent is:

  • Velocity (x, y): 2

  • Angular velocity: 3

  • LiDAR proximity: n_lidar_rays

  • LiDAR radial relative velocity: n_lidar_rays

  • LiDAR objective proximity: n_lidar_rays

  • Unit direction to top k objectives: k_objectives * 2

  • Clamped displacement to top k: k_objectives * 2

  • Occupancy status of top k: k_objectives

n_lidar_rays: int#

Number of angular bins for the agent-to-agent LiDAR sensor.

k_objectives: int#

Number of closest objectives tracked per agent.

classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 20.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, shaping_weight: float = 1.0, goal_weight: float = 0.001, crowding_weight: float = 0.005, work_weight: float = 0.0005, vacancy_weight: float = 0.005, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.4, n_lidar_rays: int = 8, k_objectives: int = 5) SwarmRoller[source]#

Create a swarm roller environment.

Parameters:
  • N (int) – Number of agents.

  • min_box_size (float) – Lower bound of the random square domain side length sampled at each reset().

  • max_box_size (float) – Upper bound of the random square domain side length sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ang_damping (float) – Angular damping coefficient applied as -ang_damping * ang_vel.

  • shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal summed over the k nearest objectives.

  • goal_weight (float) – Bonus \(w_g\) for uniquely claiming a target.

  • crowding_weight (float) – Penalty \(w_c\) per unit of LiDAR crowding vector norm.

  • work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).

  • vacancy_weight (float) – Reward \(w_v\) granted when all k nearest objectives are occupied.

  • goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).

  • alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

  • k_objectives (int) – Number of closest objectives tracked per agent.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SwarmRoller

static reset(env: SwarmRoller, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: SwarmRoller, action: Array) Environment[source]#

Advance the environment by one physics step.

Applies torque actions with angular damping and viscous drag. After integration the method updates all sensor caches and computes the reward with a differential baseline. The shaping signal is summed over the k nearest objectives.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Torque actions for every agent, shape (N * 3,).

Returns:

Updated environment after physics integration, sensor updates, and reward computation.

Return type:

Environment

static observation(env: SwarmRoller) Array[source]#

Build the per-agent observation vector from cached sensors.

All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.

Returns:

Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.

Return type:

jax.Array

static reward(env: SwarmRoller) Array[source]#

Return the reward cached by step().

Returns:

Reward vector of shape (N,).

Return type:

jax.Array

static done(env: SwarmRoller) Array[source]#

Return True when the episode has exceeded max_steps.

property action_space_size: int[source]#

Number of scalar actions per agent (3-D torque).

property action_space_shape: tuple[int][source]#

Shape of a single agent’s action ((3,)).

property observation_space_size: int[source]#

Dimensionality of a single agent’s observation vector.

class jaxdem.rl.environments.SwarmRoller3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int, k_objectives: int, n_objectives: int)#

Bases: Environment

Multi-agent 3-D rolling environment with magnetic interaction and pyramid objectives. Extends the swarm roller with two additions:

  1. Each agent has an extra binary magnet action. When two nearby agents both activate their magnets the mutual attraction is twice as strong:

    \[\mathbf{F}_{ij}^{\text{mag}} = -w_{\text{mag}} \, (m_i + m_j) \, \max\!\bigl(0,\; 1 - d/r_{\text{mag}}\bigr) \, \hat{n}_{ij}\]

    where \(m_i \in \{0, 1\}\) is the magnet flag for agent i, \(d = \|r_{ij}\|\), and \(r_{\text{mag}}\) is magnet_range.

  2. Pyramid objectives. Objectives are arranged in a pyramid: base layer on the floor and elevated apex targets. Agents must stack on top of one another to reach elevated targets. Occupancy uses full 3-D distance to prevent false apex claims.
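The magnetic force above can be written down directly. The NumPy sketch below is illustrative (the function name and signature are hypothetical, not part of the jaxdem API); `w_mag` and `r_mag` mirror the `magnet_strength` and `magnet_range` parameters of `Create()`:

```python
import numpy as np

def magnet_force(r_i, r_j, m_i, m_j, w_mag=40.0, r_mag=0.12):
    """Pairwise magnetic attraction on agent i from agent j (sketch of the
    formula above). m_i, m_j are the binary magnet flags in {0, 1}."""
    r_ij = r_i - r_j
    d = np.linalg.norm(r_ij)
    n_hat = r_ij / d                      # unit vector from j towards i
    # Linear falloff, zero beyond magnet_range; (m_i + m_j) doubles the
    # pull when both magnets are active.
    return -w_mag * (m_i + m_j) * max(0.0, 1.0 - d / r_mag) * n_hat
```

The negative sign makes the force attractive (it points from i towards j), and the `max(0, ...)` term guarantees a hard cutoff at `magnet_range`.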

Reward

\[R_i = w_s\,\sum_{j \in \text{top-}k} (e^{-2d_{ij}} - e^{-2d_{ij}^{\mathrm{prev}}}) + w_{th}\,\frac{1}{N}\sum_{m=1}^{N} z_m + w_g\,\mathbf{1}[\text{on target}] - w_w\,\|a_i\|^2 - w_{\mathrm{vel}}\,\|v_i\|^2 - \bar{r}_i\]

where \(\bar{r}_i\) is an EMA baseline updated with factor \(\alpha\), \(w_{th}\) scales the reward for the average team height, \(w_g\) is the bonus for being on a target, and \(w_{\mathrm{vel}}\) penalises high agent velocity. All weights are constructor parameters stored in env_params.
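The exact EMA update for \(\bar{r}_i\) is not spelled out here; a standard form, shown below as an assumption (the helper is hypothetical, `alpha` mirrors `alpha_r_bar`), nudges the baseline toward the raw reward each step and subtracts the previous baseline:

```python
import numpy as np

def differential_reward(raw_r, r_bar, alpha=0.07):
    """Differential reward with an EMA baseline (sketch). Returns the
    shaped reward and the updated baseline; the update order (subtract
    the old baseline, then update) is an assumption."""
    shaped = raw_r - r_bar
    new_r_bar = (1.0 - alpha) * r_bar + alpha * raw_r
    return shaped, new_r_bar
```

Under this rule a constant raw reward is driven toward zero shaped reward, so only changes in performance are rewarded.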

Notes

The observation vector per agent is:

Feature                               Size
Velocity (x, y, z)                    3
Angular velocity                      3
Magnet flag                           1
LiDAR proximity (normalised)          n_lidar_rays * n_lidar_elevation
Radial relative velocity              n_lidar_rays * n_lidar_elevation
Objective LiDAR proximity             n_lidar_rays * n_lidar_elevation
Unit direction to top k objectives    k_objectives * 3
Clamped displacement to top k         k_objectives * 3
Occupancy status of top k             k_objectives

n_lidar_rays: int#

Number of azimuthal bins for the 3-D LiDAR sensor.

n_lidar_elevation: int#

Number of elevation bins for the 3-D LiDAR sensor.
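One plausible binning scheme consistent with the documented spans (azimuth in \([-\pi, \pi)\), elevation in \([-\pi/2, \pi/2]\)) is sketched below; the helper is hypothetical and the actual jaxdem implementation may bin differently:

```python
import numpy as np

def lidar_bins(rel, n_rays=6, n_elev=6):
    """Map a relative displacement vector to (azimuth, elevation) LiDAR
    bin indices (sketch). Azimuth spans [-pi, pi), elevation spans
    [-pi/2, pi/2]; the top elevation edge folds into the last bin."""
    az = np.arctan2(rel[1], rel[0])                   # [-pi, pi)
    el = np.arctan2(rel[2], np.linalg.norm(rel[:2]))  # [-pi/2, pi/2]
    i_az = int((az + np.pi) / (2 * np.pi) * n_rays) % n_rays
    i_el = min(int((el + np.pi / 2) / np.pi * n_elev), n_elev - 1)
    return i_az, i_el
```

With the defaults this yields `n_lidar_rays * n_lidar_elevation = 36` cells per sensor channel, matching the observation sizes in the feature table above.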

k_objectives: int#

Number of closest objectives tracked per agent.

n_objectives: int#

Number of shared objectives.

classmethod Create(N: int = 5, n_objectives: int = 5, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 0.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, shaping_weight: float = 2.0, team_height_weight: float = 1.0, goal_weight: float = 0.0, work_weight: float = 0.0, velocity_weight: float = 0.018, goal_radius_factor: float = 1.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.4, n_lidar_rays: int = 6, n_lidar_elevation: int = 6, k_objectives: int = 4, magnet_strength: float = 40.0, magnet_range: float = 0.12) SwarmRoller3D[source]#

Create a swarm roller 3-D environment.

Parameters:
  • N (int) – Number of agents.

  • n_objectives (int) – Number of shared objectives.

  • min_box_size (float) – Minimum side length of the random domain, sampled at each reset().

  • max_box_size (float) – Maximum side length of the random domain, sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ang_damping (float) – Angular damping coefficient applied as -ang_damping * ang_vel.

  • shaping_weight (float) – Multiplier \(w_s\) on the potential-based shaping signal summed over the k nearest objectives.

  • team_height_weight (float) – Weight \(w_{th}\) scaling the average z-height of the swarm as a global reward.

  • goal_weight (float) – Bonus \(w_g\) for being positioned on a target.

  • work_weight (float) – Weight \(w_w\) of the quadratic action penalty \(\|a\|^2\).

  • velocity_weight (float) – Penalty \(w_{\mathrm{vel}}\) on the squared velocity magnitude \(\|v_i\|^2\).

  • goal_radius_factor (float) – Multiplicative factor \(f\) applied to the particle radius to define the goal activation threshold \(d < f \cdot r\).

  • alpha_r_bar (float) – EMA smoothing factor \(\alpha\) for the differential reward baseline \(\bar{r}\).

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of azimuthal LiDAR bins spanning \([-\pi, \pi)\).

  • n_lidar_elevation (int) – Number of elevation LiDAR bins spanning \([-\pi/2, \pi/2]\).

  • k_objectives (int) – Number of closest objectives tracked per agent.

  • magnet_strength (float) – Magnitude of the magnetic attraction force.

  • magnet_range (float) – Maximum range for magnetic interaction (beyond this the force is zero).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SwarmRoller3D

static reset(env: SwarmRoller3D, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (ArrayLike) – PRNG key used to sample the domain, positions, objectives, and initial velocities.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: SwarmRoller3D, action: Array) Environment[source]#

Advance the environment by one physics step.

Applies torque actions with angular damping, viscous drag, and pairwise magnetic attraction. After integration the method updates all sensor caches and computes the reward with a differential baseline. The shaping signal is summed over the k nearest objectives.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Actions for every agent, shape (N * 4,) (3-D torque + magnet flag).

Returns:

Updated environment after physics integration, sensor updates, and reward computation.

Return type:

Environment

static observation(env: SwarmRoller3D) Array[source]#

Build the per-agent observation vector from cached sensors. All state-dependent components are pre-computed in step() and reset(). This method only concatenates cached arrays.

Returns:

Observation matrix of shape (N, obs_dim). See the class docstring for the feature layout.

Return type:

jax.Array

static reward(env: SwarmRoller3D) Array[source]#

Return the reward cached by step().

Returns:

Reward vector of shape (N,).

Return type:

jax.Array

static done(env: SwarmRoller3D) Array[source]#

Return True when the episode has exceeded max_steps.

property action_space_size: int[source]#

Number of scalar actions per agent (3-D torque + magnet).

property action_space_shape: tuple[int][source]#

Shape of a single agent’s action ((4,)).

property observation_space_size: int[source]#

Dimensionality of a single agent’s observation vector.

class jaxdem.rl.environments.SwarmStacking3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int)#

Bases: Environment

Multi-agent 3-D stacking environment with periodic boundaries.

Agents must stack on top of one another to reach as high as possible.

Reward

\[R_i = w_{climb} (0.8 \cdot z_i + 0.2 \cdot \bar{z}_t) + w_{cohesion} \sum \text{lidar} - w_w\,\|\tau_i\|^2 - w_{\mathrm{vel}}\,\|v_i\|^2 - \bar{r}_i\]

where \(\bar{z}_t\) is the average height of the swarm.

Boundary conditions:

  • Periodic in X and Y.

  • Frictional floor at Z=0.

  • Effectively unbounded in Z (large box size).
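The climb term of the reward above mixes each agent's own height with the team average. A NumPy sketch (the helper is hypothetical; `w_climb` mirrors the `climb_weight` parameter):

```python
import numpy as np

def climb_reward(z, w_climb=20.0):
    """Per-agent climb term from the reward above: 0.8 * own height
    plus 0.2 * team-average height (sketch). z has shape (N,)."""
    return w_climb * (0.8 * z + 0.2 * z.mean())
```

The 0.2 team-average share gives every agent a stake in the stack's overall height, so agents at the bottom of the stack are still rewarded for supporting climbers.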

n_lidar_rays: int#

Number of azimuthal bins for the 3-D LiDAR sensor.

n_lidar_elevation: int#

Number of elevation bins for the 3-D LiDAR sensor.

classmethod Create(N: int = 16, min_box_size: float = 0.5, max_box_size: float = 0.5, box_padding: float = 0.0, max_steps: int = 5760, friction: float = 0.2, ang_damping: float = 0.07, climb_weight: float = 20.0, cohesion_weight: float = 0.05, work_weight: float = 0.0, velocity_weight: float = 2.0, alpha_r_bar: float = 0.07, lidar_range: float = 0.5, n_lidar_rays: int = 8, n_lidar_elevation: int = 8, magnet_strength: float = 40.0, magnet_range: float = 0.12) SwarmStacking3D[source]#

Create a swarm stacking 3-D environment.

static reset(env: SwarmStacking3D, key: Array | ndarray | bool | number | int | float | complex | TypedNdArray) Environment[source]#

Reset the environment to a random initial configuration.

static step(env: SwarmStacking3D, action: Array) Environment[source]#

Advance the environment by one physics step.

static observation(env: SwarmStacking3D) Array[source]#

Build the per-agent observation vector from cached sensors.

static reward(env: SwarmStacking3D) Array[source]#

Return the reward cached by step().

static done(env: SwarmStacking3D) Array[source]#

Return True when the episode has exceeded max_steps.
property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

Modules

multi_navigator

Multi-agent 2-D navigation with collision avoidance and cooperative rewards.

multi_roller

Multi-agent 3-D rolling environment with LiDAR sensing.

single_navigator

Environment where a single agent navigates towards a target.

single_roller

Environment where a single agent rolls towards a target on the floor.

swarm_navigator

Multi-agent 2-D swarm navigation with potential-based rewards.

swarm_roller

Multi-agent 3-D swarm rolling environment with potential-based rewards.

swarm_roller_3d

Multi-agent 3-D swarm rolling environment with magnetic interaction and pyramid objectives.

swarm_stacking_3d

Multi-agent 3-D swarm stacking environment with periodic boundaries.