jaxdem.rl.environments

Reinforcement-learning environment interface.

Classes

Environment(state, system, env_params)

Defines the interface for reinforcement-learning environments.

class jaxdem.rl.environments.Environment(state: State, system: System, env_params: dict[str, Any])

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment
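The shape contract above can be illustrated with a small stand-in. The sketch below is not jaxdem code: `ToyEnv`, its fields, and its dynamics are hypothetical, and NumPy arrays stand in for JAX arrays. Only the static-method interface and the documented shapes are taken from this page.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyEnv:
    """Hypothetical stand-in for an Environment with A agents in 2-D."""

    pos: np.ndarray      # (A, 2) agent positions
    target: np.ndarray   # (A, 2) per-agent targets
    step_count: int = 0
    max_steps: int = 10

    @staticmethod
    def step(env: "ToyEnv", action: np.ndarray) -> "ToyEnv":
        # Actions arrive flattened per agent: (A, action_space_size).
        assert action.shape == env.pos.shape
        return replace(env, pos=env.pos + 0.1 * action, step_count=env.step_count + 1)

    @staticmethod
    def observation(env: "ToyEnv") -> np.ndarray:
        # One flat vector per agent: (A, observation_space_size).
        return np.concatenate([env.pos, env.target - env.pos], axis=-1)

    @staticmethod
    def reward(env: "ToyEnv") -> np.ndarray:
        # One scalar per agent: (A,).
        return -np.linalg.norm(env.target - env.pos, axis=-1)

    @staticmethod
    def done(env: "ToyEnv") -> np.ndarray:
        # A single boolean for the whole environment.
        return np.asarray(env.step_count >= env.max_steps)


A = 3
env = ToyEnv(pos=np.zeros((A, 2)), target=np.ones((A, 2)))
env = ToyEnv.step(env, np.ones((A, 2)))
obs, rew, done = ToyEnv.observation(env), ToyEnv.reward(env), ToyEnv.done(env)
```

Here `obs` has shape (A, 4), `rew` has shape (A,), and `done` is a zero-dimensional boolean, matching the list above with observation_space_size = 4.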

Todo:

  • Truncated data field: per-agent termination flag

  • Render method

Example:

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> from dataclasses import dataclass
>>> import jax
>>> from jaxdem.rl.environments import Environment
>>> @Environment.register("Environment")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomEnv(Environment):
...     ...

state: State

Simulation state.

system: System

Simulation system configuration.

env_params: dict[str, Any]

Environment-specific parameters.

abstractmethod static reset(env: Environment, key: ArrayLike) → Environment

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: ArrayLike) → Environment

Resets the environment if it has reached a terminal state.

This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
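Because a traced boolean cannot drive a Python `if` inside jitted code, a conditional reset is usually written as an array-level selection. The sketch below shows that idea with plain dicts and NumPy; `reset_if_done_sketch` and the field names are hypothetical, and this page does not say how jaxdem implements reset_if_done internally.

```python
import numpy as np


def reset_if_done_sketch(current: dict, fresh: dict, done) -> dict:
    """Pick every field from `fresh` when `done` is True, else from `current`."""
    done = np.asarray(done, dtype=bool)
    # np.where broadcasts the scalar flag over each array field.
    return {name: np.where(done, fresh[name], current[name]) for name in current}


current = {"pos": np.array([[1.0, 2.0]]), "step_count": np.array(7)}
fresh = {"pos": np.zeros((1, 2)), "step_count": np.array(0)}

kept = reset_if_done_sketch(current, fresh, done=False)
restarted = reset_if_done_sketch(current, fresh, done=True)
```

Both branches are evaluated and one is selected, so the fresh state has to be computed even when done is False. That cost is the usual trade-off of a select-style reset.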

abstractmethod static step(env: Environment, action: Array) → Environment

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent actions, shape (A, action_space_size).

Returns:

The updated environment.

Return type:

Environment

abstractmethod static observation(env: Environment) → Array

Returns the per-agent observation vectors.

Parameters:

env (Environment) – The current environment.

Returns:

Per-agent observations, shape (A, observation_space_size).

Return type:

jax.Array

abstractmethod static reward(env: Environment) → Array

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Per-agent rewards for the current environment state, shape (A,).

Return type:

jax.Array

abstractmethod static done(env: Environment) → Array

Returns a boolean indicating whether the episode has ended.

Parameters:

env (Environment) – The current environment.

Returns:

Scalar boolean that is True once the episode has ended.

Return type:

jax.Array

static info(env: Environment) → dict[str, Any]

Return auxiliary diagnostic information.

By default, returns an empty dict. Subclasses may override this method to provide environment-specific information.

Parameters:

env (Environment) – The current environment.

Returns:

A dictionary with additional information about the environment.

Return type:

dict[str, Any]

property num_envs: int

Number of batched environments.

property max_num_agents: int

Maximum number of active agents in the environment.

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.MultiNavigator(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)

Bases: Environment

Multi-agent navigation environment toward assigned targets.

Each agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. Objectives are sampled and assigned one-to-one via a random permutation.

The reward uses exponential potential-based shaping:

\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]

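where \(d_i\) is the distance from agent \(i\) to its assigned objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, \(N\) is the number of agents, and the superscript \(\mathrm{prev}\) marks the value at the previous step. The weights \(w_{\mathrm{ke}}\), \(w_{\mathrm{coop}}\), and \(w_{\mathrm{near}}\) are set by ke_weight, coop_weight, and near_goal_bonus.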
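As a worked instance of the formula above, the following NumPy sketch evaluates \(R_i\) for two agents. The function name and argument layout are hypothetical, NumPy stands in for JAX, and the default weights are the Create defaults listed on this page. Treating \(r_i\) as the particle radius is an assumption based on the near_goal_bonus description.

```python
import numpy as np


def multi_agent_reward(d, d_prev, ke, ke_prev, radius,
                       ke_weight=0.1, coop_weight=0.2, near_goal_bonus=0.1):
    """Evaluate R_i for all agents; every array argument has shape (N,)."""
    progress = np.exp(-2.0 * d) - np.exp(-2.0 * d_prev)   # shaping term per agent
    ke_penalty = ke_weight * (ke - ke_prev)                # differential KE penalty
    team = coop_weight * progress.mean()                   # shared team progress
    near = near_goal_bonus * (d <= radius)                 # indicator 1[d_i <= r_i]
    return progress - ke_penalty + team + near


d_prev = np.array([1.0, 2.0])
d = np.array([0.5, 2.0])          # agent 0 moved closer, agent 1 did not move
ke = ke_prev = np.zeros(2)        # no change in kinetic energy
rewards = multi_agent_reward(d, d_prev, ke, ke_prev, radius=np.full(2, 0.5))
```

Agent 1 makes no progress of its own, yet it still receives the cooperative share of agent 0's progress. Agent 0 additionally collects its own shaping term and the near-goal bonus.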

Notes

The observation vector per agent is:

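============================  ============
Feature                       Size
============================  ============
Unit direction to objective   dim
Clamped displacement          dim
Velocity                      dim
LiDAR proximity (normalised)  n_lidar_rays
============================  ============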

For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
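The arithmetic behind these numbers can be spelled out as follows. This is only a sketch: it assumes skip_frames counts physics steps per agent action, which this page does not define, and the implied physics time step is derived from that assumption rather than quoted from jaxdem.

```python
skip_frames = 50          # assumed meaning: physics steps per agent action
control_rate_hz = 200.0   # response rate stated in the note
num_steps_epoch = 100     # agent actions per epoch

# Horizon covered by one epoch of agent actions.
horizon_s = num_steps_epoch / control_rate_hz

# Physics time step implied by 200 actions per second at 50 steps per action.
implied_dt = 1.0 / (control_rate_hz * skip_frames)

# Total physics steps simulated in one epoch.
physics_steps_per_epoch = num_steps_epoch * skip_frames
```

Under that assumption the horizon is 0.5 seconds, the implied physics time step is 1e-4 seconds, and one epoch spans 5000 physics steps.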

n_lidar_rays: int

Number of angular bins for each LiDAR sensor.

classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 5.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 6.0, n_lidar_rays: int = 16) → MultiNavigator

Create a multi-agent navigator environment.

Parameters:
  • N (int) – Number of agents.

  • min_box_size (float) – Lower bound of the range from which the square domain side length is sampled at each reset().

  • max_box_size (float) – Upper bound of the range from which the square domain side length is sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

  • coop_weight (float) – Weight for the shared team-progress bonus.

  • near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

MultiNavigator
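To make the lidar_range and n_lidar_rays parameters concrete, the sketch below bins neighbours of one agent into angular sectors spanning \([-\pi, \pi)\) and keeps the closest hit per sector. It is a 2-D illustration in NumPy, not the jaxdem sensor: the function name is hypothetical, and the proximity convention (0 for nothing in range, values near 1 for very close objects) is an assumption.

```python
import numpy as np


def lidar_proximity(rel_pos, lidar_range=6.0, n_lidar_rays=16):
    """Proximity per angular bin for one agent; rel_pos has shape (M, 2)."""
    dist = np.linalg.norm(rel_pos, axis=-1)
    angle = np.arctan2(rel_pos[:, 1], rel_pos[:, 0])       # in [-pi, pi]
    # Map [-pi, pi) onto bin indices 0 .. n_lidar_rays - 1.
    bins = np.floor((angle + np.pi) / (2.0 * np.pi) * n_lidar_rays).astype(int)
    bins = np.clip(bins, 0, n_lidar_rays - 1)              # angle == pi goes to the last bin
    # Assumed convention: 0 means nothing in range, values near 1 mean very close.
    prox = np.clip(1.0 - dist / lidar_range, 0.0, 1.0)
    out = np.zeros(n_lidar_rays)
    np.maximum.at(out, bins, prox)                         # keep the closest hit per bin
    return out


neighbours = np.array([[3.0, 0.0], [0.0, -1.5], [10.0, 10.0]])
readings = lidar_proximity(neighbours)
```

The third neighbour lies beyond lidar_range and therefore leaves its bin at zero, so only two of the sixteen readings are non-zero.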

static reset(env: MultiNavigator, key: ArrayLike) → Environment

Initialize the environment with random positions and objectives.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment
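The class description states that objectives are sampled and assigned one-to-one via a random permutation. A minimal sketch of that assignment is shown below. It uses NumPy's generator as a stand-in for a JAX PRNG key, and the uniform sampling inside a box of side 20.0 is an assumption made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)   # stand-in for a JAX PRNG key
N, dim, box = 4, 2, 20.0

positions = rng.uniform(0.0, box, size=(N, dim))    # random agent positions
candidates = rng.uniform(0.0, box, size=(N, dim))   # N sampled objective points
perm = rng.permutation(N)                           # random one-to-one assignment
objectives = candidates[perm]                       # objectives[i] belongs to agent i
```

Because `perm` contains every index exactly once, no two agents share an objective and no objective is left unassigned.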

static step(env: MultiNavigator, action: Array) → Environment

Advance one step. Actions are forces; simple drag is applied (-friction * vel).

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent forces, shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

static observation(env: MultiNavigator) → Array

Build per-agent observations.

Contents per agent:

  • Unit vector to objective (shape (dim,)): direction.

  • Clamped delta to objective (shape (dim,)): local precision.

  • Velocity (shape (dim,)).

  • LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).

Returns:

Array of shape (N, 3 * dim + n_lidar_rays).

Return type:

jax.Array
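The layout of this observation can be sketched by stacking the four blocks in order. The function below is hypothetical and uses NumPy instead of JAX. The clamp value and the per-component clamp rule are assumptions, since this page does not specify how the displacement is clamped, and the LiDAR block is passed in precomputed.

```python
import numpy as np


def navigator_observation(pos, vel, objective, lidar, clamp=1.0):
    """Stack the four documented blocks into shape (N, 3 * dim + n_lidar_rays)."""
    delta = objective - pos                                   # (N, dim)
    dist = np.linalg.norm(delta, axis=-1, keepdims=True)      # (N, 1)
    unit = delta / np.maximum(dist, 1e-12)                    # direction, safe at dist == 0
    clamped = np.clip(delta, -clamp, clamp)                   # assumed per-component clamp
    return np.concatenate([unit, clamped, vel, lidar], axis=-1)


N, dim, n_lidar_rays = 5, 2, 16
obs = navigator_observation(
    pos=np.zeros((N, dim)),
    vel=np.zeros((N, dim)),
    objective=np.full((N, dim), 3.0),
    lidar=np.zeros((N, n_lidar_rays)),
)
```

With dim = 2 and n_lidar_rays = 16 the result has 3 * 2 + 16 = 22 features per agent.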

static reward(env: MultiNavigator) → Array

Returns a vector of per-agent rewards.

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]

where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: MultiNavigator) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.MultiRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)

Bases: Environment

Multi-agent rolling environment toward assigned targets.

Each agent controls a torque vector that is applied directly to a sphere on a \(z=0\) floor. Translational drag -friction * vel and angular damping -friction * ang_vel are applied each step. Objectives are sampled and assigned one-to-one via a random permutation.

The reward uses exponential potential-based shaping:

\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]

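where \(d_i\) is the distance from agent \(i\) to its assigned objective in the \(xy\) plane, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, \(N\) is the number of agents, and the superscript \(\mathrm{prev}\) marks the value at the previous step. The weights \(w_{\mathrm{ke}}\), \(w_{\mathrm{coop}}\), and \(w_{\mathrm{near}}\) are set by ke_weight, coop_weight, and near_goal_bonus.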

Notes

The observation vector per agent is:

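============================  ============
Feature                       Size
============================  ============
Unit direction to objective   2
Clamped displacement          2
Velocity                      2
LiDAR proximity (normalised)  n_lidar_rays
============================  ============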

For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.

n_lidar_rays: int

Number of angular bins for each LiDAR sensor.

classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 5.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 6.0, n_lidar_rays: int = 16) → MultiRoller

Create a multi-agent roller environment.

Parameters:
  • N (int) – Number of agents.

  • min_box_size (float) – Lower bound of the range from which the square domain side length is sampled at each reset().

  • max_box_size (float) – Upper bound of the range from which the square domain side length is sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Translational and angular damping coefficient.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

  • coop_weight (float) – Weight for the shared team-progress bonus.

  • near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

MultiRoller

static reset(env: MultiRoller, key: ArrayLike) → Environment

Initialize the environment with random positions and objectives.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static step(env: MultiRoller, action: Array) → Environment

Advance one step. Actions are torques; simple damping is applied.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent torques, shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

static observation(env: MultiRoller) → Array

Build per-agent observations.

Contents per agent:

  • Unit vector to objective in the \(xy\) plane (shape (2,)).

  • Clamped objective delta in the \(xy\) plane (shape (2,)).

  • Velocity in the \(xy\) plane (shape (2,)).

  • LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).

Returns:

Array of shape (N, 6 + n_lidar_rays).

Return type:

jax.Array

static reward(env: MultiRoller) → Array

Returns a vector of per-agent rewards.

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]

where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: MultiRoller) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the environment has ended.

Return type:

jax.Array

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.SingleNavigator(state: State, system: System, env_params: dict[str, Any])

Bases: Environment

Single-agent navigation environment toward a fixed target.

The agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. The reward uses exponential potential-based shaping:

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]

where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
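One consequence of this difference-based form is that the undiscounted sum of rewards over an episode telescopes, so it depends only on the first and last states. The sketch below checks this numerically on a made-up trajectory. The distance and kinetic-energy values are hypothetical, and it assumes that \(d_t^{\mathrm{prev}}\) equals \(d_{t-1}\).

```python
import math

ke_weight = 0.1
# Hypothetical trajectory: distance to the objective and kinetic energy per step.
dist = [3.0, 2.4, 1.7, 0.9, 0.2]
ke = [0.0, 0.5, 0.8, 0.6, 0.1]


def step_reward(d, d_prev, k, k_prev):
    """rew_t = (e^{-2 d_t} - e^{-2 d_prev}) - w_ke * (K_t - K_prev)."""
    return (math.exp(-2.0 * d) - math.exp(-2.0 * d_prev)) - ke_weight * (k - k_prev)


rewards = [step_reward(dist[t], dist[t - 1], ke[t], ke[t - 1]) for t in range(1, len(dist))]
episode_return = sum(rewards)

# Both differences telescope, so only the endpoints of the trajectory matter.
endpoint_value = (
    (math.exp(-2.0 * dist[-1]) - math.exp(-2.0 * dist[0]))
    - ke_weight * (ke[-1] - ke[0])
)
```

With a discount factor below one the sum no longer collapses exactly, so this identity holds only for the undiscounted return.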

Notes

The observation vector per agent is:

===========================  ====
Feature                      Size
===========================  ====
Unit direction to objective  dim
Clamped displacement         dim
Velocity                     dim
===========================  ====

For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.

classmethod Create(dim: int = 2, min_box_size: float = 40.0, max_box_size: float = 40.0, max_steps: int = 20000, friction: float = 0.2, ke_weight: float = 0.1) → SingleNavigator

Create a single-agent navigator environment.

Parameters:
  • dim (int) – Spatial dimensionality (2 or 3).

  • min_box_size (float) – Lower bound of the range for the random square domain side length.

  • max_box_size (float) – Upper bound of the range for the random square domain side length.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SingleNavigator

static reset(env: SingleNavigator, key: ArrayLike) → Environment

Initialize the environment with a randomly placed particle and velocity.

Parameters:
  • env (SingleNavigator) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static step(env: SingleNavigator, action: Array) → Environment

Advance one step. Actions are forces; simple drag is applied (-friction * vel).

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent forces, shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

static observation(env: SingleNavigator) → Array

Build per-agent observations.

Contents per agent:

  • Unit vector to objective (shape (dim,)): direction.

  • Clamped delta to objective (shape (dim,)): local precision.

  • Velocity (shape (dim,)).

Returns:

Array of shape (N, 3 * dim).

Return type:

jax.Array

static reward(env: SingleNavigator) → Array

Returns a vector of per-agent rewards.

Reward:

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]

where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: SingleNavigator) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.SingleRoller3D(state: State, system: System, env_params: dict[str, Any])

Bases: Environment

Single-agent 3D navigation via torque-controlled rolling.

The agent is a sphere resting on a \(z = 0\) floor under gravity. Actions are 3-D torque vectors; translational motion arises from frictional contact with the floor (see frictional_wall_force()). A viscous drag -friction * vel and a fixed angular damping of -friction * ang_vel are applied each step.

The reward uses exponential potential-based shaping:

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]

where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.

Notes

The observation vector per agent is:

===========================  ====
Feature                      Size
===========================  ====
Unit direction to objective  2
Clamped displacement (x, y)  2
Velocity (x, y)              2
Angular velocity             3
===========================  ====

For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.

classmethod Create(min_box_size: float = 40.0, max_box_size: float = 40.0, max_steps: int = 20000, friction: float = 0.2, ke_weight: float = 0.1) → SingleRoller3D

Create a single-agent roller environment.

Parameters:
  • min_box_size (float) – Lower bound of the range for the random square domain side length.

  • max_box_size (float) – Upper bound of the range for the random square domain side length.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Damping coefficient, applied as viscous drag -friction * vel and as angular damping -friction * ang_vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SingleRoller3D

static reset(env: SingleRoller3D, key: ArrayLike) → Environment

Randomly place the agent and objective on the floor.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX PRNG key.

Returns:

Freshly initialised environment.

Return type:

Environment

static step(env: SingleRoller3D, action: Array) → Environment

Apply a torque action, advance physics by one step.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – 3-D torque vector per agent.

Returns:

Updated environment after one physics step.

Return type:

Environment

static observation(env: SingleRoller3D) → Array

Per-agent observation vector.

Contents per agent:

  • Unit displacement to objective projected to x-y (shape (2,)).

  • Clamped displacement to objective projected to x-y (shape (2,)).

  • Velocity projected to x-y (shape (2,)).

  • Angular velocity (shape (3,)).

Returns:

Shape (N, 9).

Return type:

jax.Array
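The nine features listed above split as 2 + 2 + 2 + 3. A minimal NumPy sketch of that layout follows. The function name is hypothetical, NumPy stands in for JAX, and the clamp rule is an assumption because this page does not state how the displacement is clamped.

```python
import numpy as np


def roller_observation(pos, vel, ang_vel, objective, clamp=1.0):
    """Inputs have shape (N, 3); the result has shape (N, 9)."""
    delta_xy = (objective - pos)[:, :2]                        # drop the z component
    dist = np.linalg.norm(delta_xy, axis=-1, keepdims=True)
    unit_xy = delta_xy / np.maximum(dist, 1e-12)               # safe at dist == 0
    clamped_xy = np.clip(delta_xy, -clamp, clamp)              # assumed clamp rule
    return np.concatenate([unit_xy, clamped_xy, vel[:, :2], ang_vel], axis=-1)


obs = roller_observation(
    pos=np.array([[0.0, 0.0, 0.5]]),
    vel=np.array([[1.0, 0.0, 0.0]]),
    ang_vel=np.array([[0.0, 2.0, 0.0]]),
    objective=np.array([[0.0, 4.0, 0.5]]),
)
```

Only the angular velocity keeps all three components; position-related features and the linear velocity are reduced to the floor plane.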

static reward(env: SingleRoller3D) → Array

Returns a vector of per-agent rewards.

Exponential potential-based shaping:

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: SingleRoller3D) → Array

True when step_count exceeds max_steps.

property action_space_size: int

Per-agent flattened action dimensionality (3-D torque).

property action_space_shape: tuple[int]

Per-agent action tensor shape.

property observation_space_size: int

Per-agent flattened observation dimensionality (9).

class jaxdem.rl.environments.SwarmNavigator(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)

Bases: Environment

Multi-agent navigation environment toward nearby shared targets.

Each agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. Objectives are sampled globally, and each agent observes objective LiDAR and agent LiDAR.

At reset, a small subset of agents is spawned in the central objective region while the rest are spawned in the outer padding ring.

The reward uses exponential potential-based shaping:

\[R_i = (S_i - S_i^{\mathrm{prev}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_m (S_m - S_m^{\mathrm{prev}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]

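where \(d_i\) is the distance from agent \(i\) to its closest objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, and \(S_i = \sum_{r \in \text{obj-LiDAR}} e^{-2 d_{ir}}\) sums exponential shaping over the objectives detected by the agent's objective-LiDAR rays. The index \(r\) in the sum runs over rays and is unrelated to the radius \(r_i\).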
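The quantity \(S_i\) can be sketched as a masked sum over per-ray distances. The code below is an illustration in NumPy, not the jaxdem implementation. The function name is hypothetical, and encoding a ray without a detection as a distance equal to lidar_range is an assumption.

```python
import numpy as np


def objective_shaping_sum(ray_dist, lidar_range=10.0):
    """S_i per agent from objective-LiDAR distances of shape (N, n_lidar_rays)."""
    detected = ray_dist < lidar_range                 # assumed encoding of a miss
    terms = np.where(detected, np.exp(-2.0 * ray_dist), 0.0)
    return terms.sum(axis=-1)                         # shape (N,)


no_hit = 10.0
ray_dist_prev = np.array([[2.0, no_hit, no_hit], [no_hit, no_hit, no_hit]])
ray_dist = np.array([[1.0, 3.0, no_hit], [no_hit, no_hit, no_hit]])

s_prev = objective_shaping_sum(ray_dist_prev)
s_now = objective_shaping_sum(ray_dist)
progress = s_now - s_prev                             # the (S_i - S_i^prev) term
```

Agent 0 both approaches one objective and detects a second one, so its shaping term is positive. Agent 1 detects nothing in either step and contributes zero.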

Notes

The observation vector per agent is:

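=========================  ============
Feature                    Size
=========================  ============
Velocity                   dim
Objective LiDAR proximity  n_lidar_rays
Agent LiDAR proximity      n_lidar_rays
=========================  ============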

n_lidar_rays: int

Number of angular bins for each LiDAR sensor.

classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 20.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 10.0, n_lidar_rays: int = 24) → SwarmNavigator

Create a swarm navigator environment.

Parameters:
  • N (int) – Number of agents and number of sampled objectives.

  • min_box_size (float) – Lower bound of the range from which the square domain side length is sampled at each reset().

  • max_box_size (float) – Upper bound of the range from which the square domain side length is sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius. The padding region is used as the outer spawn ring.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

  • coop_weight (float) – Weight for the shared team-progress bonus.

  • near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its closest objective.

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SwarmNavigator

static reset(env: SwarmNavigator, key: ArrayLike) → Environment

Initialize the environment with random positions and objectives.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static step(env: SwarmNavigator, action: Array) → Environment

Advance one step. Actions are forces; simple drag is applied (-friction * vel).

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent forces, shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

static observation(env: SwarmNavigator) → Array

Build per-agent observations.

Contents per agent:

  • Velocity (shape (dim,)).

  • Objective LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).

  • Agent LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).

Returns:

Array of shape (N, dim + 2 * n_lidar_rays).

Return type:

jax.Array

static reward(env: SwarmNavigator) → Array

Returns a vector of per-agent rewards.

\[\mathrm{rew}_t = (S_t - S_t^{\mathrm{prev}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}\left( (S_t - S_t^{\mathrm{prev}})\right) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]

where \(d_t\) is the distance to the closest objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), \(S_t\) is the per-agent sum of \(e^{-2d}\) over objectives detected by objective LiDAR rays, \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus, and \(w_{\text{near}}\) weights a near-goal bonus.

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: SwarmNavigator) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the environment has ended.

Return type:

jax.Array

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.SwarmRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)

Bases: Environment

Multi-agent rolling environment toward nearby shared targets.

Each agent controls a torque vector that is applied directly to a sphere on a \(z=0\) floor. Translational drag -friction * vel and angular damping -friction * ang_vel are applied each step. Objectives are sampled globally, and each agent observes objective LiDAR and agent LiDAR.

At reset, a small subset of agents is spawned in the central objective region while the rest are spawned in the outer padding ring.

The reward uses exponential potential-based shaping:

\[R_i = (S_i - S_i^{\mathrm{prev}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_m (S_m - S_m^{\mathrm{prev}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]

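where \(d_i\) is the distance from agent \(i\) to its closest objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, and \(S_i = \sum_{r \in \text{obj-LiDAR}} e^{-2 d_{ir}}\) sums exponential shaping over the objectives detected by the agent's objective-LiDAR rays. The index \(r\) in the sum runs over rays and is unrelated to the radius \(r_i\).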

Notes

The observation vector per agent is:

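=========================  ============
Feature                    Size
=========================  ============
Velocity                   dim
Objective LiDAR proximity  n_lidar_rays
Agent LiDAR proximity      n_lidar_rays
=========================  ============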

n_lidar_rays: int

Number of angular bins for each LiDAR sensor.

classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 20.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 10.0, n_lidar_rays: int = 24) → SwarmRoller

Create a swarm roller environment.

Parameters:
  • N (int) – Number of agents and number of sampled objectives.

  • min_box_size (float) – Lower bound of the range from which the square domain side length is sampled at each reset().

  • max_box_size (float) – Upper bound of the range from which the square domain side length is sampled at each reset().

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius. The padding region is used as the outer spawn ring.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Translational and angular damping coefficient.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

  • coop_weight (float) – Weight for the shared team-progress bonus.

  • near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its closest objective.

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

SwarmRoller

static reset(env: SwarmRoller, key: ArrayLike) → Environment

Initialize the environment with random positions and objectives.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static step(env: SwarmRoller, action: Array) → Environment

Advance one step. Actions are torques; simple damping is applied.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent torques, shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

static observation(env: SwarmRoller) → Array

Build per-agent observations.

Contents per agent:

  • Velocity (shape (dim,)).

  • Objective LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).

  • Agent LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).

Returns:

Array of shape (N, dim + 2 * n_lidar_rays).

Return type:

jax.Array

static reward(env: SwarmRoller) → Array

Returns a vector of per-agent rewards.

\[\mathrm{rew}_t = (S_t - S_t^{\mathrm{prev}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}\left( (S_t - S_t^{\mathrm{prev}})\right) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]

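where \(S_t\) is the per-agent sum of \(e^{-2d}\) over the objectives detected by the objective LiDAR rays, \(S_t^{\mathrm{prev}}\) is its previous value, \(K_t\) is the kinetic energy at step \(t\), and \(d_t\) is the distance to the closest objective at step \(t\). The indicator \(\mathbf{1}[d_t \le r]\) is 1 when the agent is within one radius \(r\) of its closest objective. The weights are \(w_{\text{ke}}\) for the kinetic-energy penalty (ke_weight), \(w_{\text{coop}}\) for the shared team-progress bonus (coop_weight), and \(w_{\text{near}}\) for the near-goal bonus (near_goal_bonus).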

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array
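The reward formula above can be written term by term as follows. This is a sketch under stated assumptions, not the SwarmRoller implementation: the function name swarm_reward and its argument names are hypothetical, and all per-agent quantities are assumed to be precomputed arrays of shape (N,).

```python
import jax.numpy as jnp


def swarm_reward(S, S_prev, K, K_prev, d_closest, radius,
                 ke_weight, coop_weight, near_goal_bonus):
    """Per-agent reward following the documented formula (hypothetical helper)."""
    progress = S - S_prev                          # change in objective-proximity score
    ke_penalty = ke_weight * (K - K_prev)          # differential kinetic-energy penalty
    team_bonus = coop_weight * jnp.mean(progress)  # one scalar shared by every agent
    # Indicator term: bonus only for agents within one radius of an objective.
    near_bonus = jnp.where(d_closest <= radius, near_goal_bonus, 0.0)
    return progress - ke_penalty + team_bonus + near_bonus


# Example values chosen for illustration only.
rew = swarm_reward(
    S=jnp.array([0.5, 0.2]), S_prev=jnp.array([0.4, 0.2]),
    K=jnp.array([1.0, 0.0]), K_prev=jnp.array([0.0, 0.0]),
    d_closest=jnp.array([0.5, 3.0]), radius=1.0,
    ke_weight=0.1, coop_weight=0.2, near_goal_bonus=0.1,
)
```

In this example the first agent makes progress and is near a goal, while the second agent only receives the shared team bonus.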

static done(env: SwarmRoller) Array[source]#

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the environment has ended.

Return type:

jax.Array

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.SwarmRoller3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int)#

Bases: Environment

Multi-agent 3-D rolling environment with magnetic interaction and pyramid objectives.

n_lidar_rays: int#

Number of azimuthal bins for the 3-D LiDAR sensor.

n_lidar_elevation: int#

Number of elevation bins for the 3-D LiDAR sensor.

classmethod Create(N: int = 5, min_box_size: float = 10.0, max_box_size: float = 10.0, box_padding: float = 0.0, max_steps: int = 100000, friction: float = 0.2, lidar_range: float = 10.0, n_lidar_rays: int = 8, n_lidar_elevation: int = 8, magnet_strength: float = 4.0, magnet_range: float = 3.0, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1) SwarmRoller3D[source]#

Create a swarm roller 3-D environment.

static reset(env: SwarmRoller3D, key: Array | ndarray | bool | number | bool | int | float | complex) Environment[source]#

Reset the environment to a random initial configuration.

static step(env: SwarmRoller3D, action: Array) Environment[source]#

Advance the environment by one physics step.

static observation(env: SwarmRoller3D) Array[source]#

Build per-agent observations.

static reward(env: SwarmRoller3D) Array[source]#

Return per-agent rewards.

static done(env: SwarmRoller3D) Array[source]#

Return True when the episode has exceeded max_steps.

property action_space_size: int[source]#

Number of scalar actions per agent (3-D torque).

property action_space_shape: tuple[int][source]#

Shape of a single agent’s action ((3,)).

property observation_space_size: int[source]#

Dimensionality of a single agent’s observation vector.

class jaxdem.rl.environments.SwarmStacking3D(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int, n_lidar_elevation: int)#

Bases: Environment

Multi-agent 3-D stacking environment with periodic boundaries.

n_lidar_rays: int#

Number of azimuthal bins for the 3-D LiDAR sensor.

n_lidar_elevation: int#

Number of elevation bins for the 3-D LiDAR sensor.

classmethod Create(N: int = 16, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 20.0, max_steps: int = 5760, friction: float = 0.2, lidar_range: float = 10.0, n_lidar_rays: int = 8, n_lidar_elevation: int = 8, magnet_strength: float = 40.0, magnet_range: float = 2.4, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1) SwarmStacking3D[source]#

Create a swarm stacking 3-D environment.

static reset(env: SwarmStacking3D, key: Array | ndarray | bool | number | bool | int | float | complex) Environment[source]#

static step(env: SwarmStacking3D, action: Array) Environment[source]#

static observation(env: SwarmStacking3D) Array[source]#

static reward(env: SwarmStacking3D) Array[source]#

static done(env: SwarmStacking3D) Array[source]#

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

class jaxdem.rl.environments.ThreeGears(state: State, system: System, env_params: dict[str, Any])#

Bases: Environment

Multi-agent 2-D environment with three gears.

The environment consists of three active gears composed of spheres. All gears can apply torque to themselves. The shared objective is to navigate the gears to form a triangular structure defined by a randomized target position.

Note

As in the TwoGears environment, realistic training parameters are skip_frames = 50, which gives a response rate of 200 Hz, so that num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
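The arithmetic behind the note can be checked directly. The sketch below assumes that each agent action spans skip_frames physics steps; under that assumption the physics rate and time step are implied values, not numbers stated in this documentation.

```python
response_rate_hz = 200.0  # response rate quoted in the note for skip_frames = 50
skip_frames = 50
num_steps_epoch = 100

# Horizon covered by one epoch of agent decisions.
horizon_s = num_steps_epoch / response_rate_hz

# Implied values, assuming one action per skip_frames physics steps.
physics_rate_hz = response_rate_hz * skip_frames
implied_dt = 1.0 / physics_rate_hz
```

This gives horizon_s = 0.5, matching the note, and an implied physics rate of 10 kHz.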

classmethod Create(box_size: float = 10.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.001) ThreeGears[source]#

Create a three-gears 2-D environment.

Parameters:
  • box_size (float) – Size of the square bounding box.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

ThreeGears

static reset(env: ThreeGears, key: Array) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (jax.Array) – PRNG key used to sample the initial positions and objective triangle.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: ThreeGears, action: Array) Environment[source]#

Advance the environment by one physics step.

Applies torque to the active gears, computes inter-gear forces, and applies viscous drag.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Actions for all gears.

Returns:

Updated environment after physics integration and sensor updates.

Return type:

Environment

static observation(env: ThreeGears) Array[source]#

Build the observation vector.

The observation vector contains 22 features:

==================================  ====
Feature                             Size
==================================  ====
Distance to floor                   1
Distance to left/right walls        2
Unit vector to target               2
Clamped displacement to target      2
Unit vector to neighbor j           2
Clamped displacement to neighbor j  2
\(\sin(\Delta\theta_j)\)            1
\(\cos(\Delta\theta_j)\)            1
Unit vector to neighbor k           2
Clamped displacement to neighbor k  2
\(\sin(\Delta\theta_k)\)            1
\(\cos(\Delta\theta_k)\)            1
Velocity (x, y)                     2
Angular velocity                    1
==================================  ====

Returns:

Observation vector of size 22.

Return type:

jax.Array

static reward(env: ThreeGears) Array[source]#

Compute the cooperative reward.

The shared reward is based on the differential distance to the objective minus a penalty for the change in kinetic energy:

\[R_t = (d_{t-1} - d_t) - w_{\text{ke}} \sum_i (K_t^i - K_{t-1}^i)\]

where \(d_t\) is the total distance to the objective at step \(t\), \(K_t^i\) is the kinetic energy of agent \(i\), and \(w_{\text{ke}}\) is the kinetic energy weight.

Returns:

Reward value, identical for all agents.

Return type:

jax.Array
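A minimal sketch of the shared reward, assuming the total distance and the per-agent kinetic energies are already available. The function name shared_reward and its arguments are hypothetical, and broadcasting the scalar to shape (A,) is an assumption based on the (A,) reward shape required by the Environment interface.

```python
import jax.numpy as jnp


def shared_reward(d_prev, d_curr, K_prev, K_curr, ke_weight):
    """Cooperative reward; d_* are scalars, K_* have shape (A,). Hypothetical helper."""
    progress = d_prev - d_curr                         # positive when the gears get closer
    ke_penalty = ke_weight * jnp.sum(K_curr - K_prev)  # summed over agents
    reward = progress - ke_penalty
    # Every agent receives the same value.
    return jnp.full(K_curr.shape, reward)


# Example values chosen for illustration only.
rew = shared_reward(
    d_prev=3.0, d_curr=2.5,
    K_prev=jnp.zeros(3), K_curr=jnp.array([1.0, 2.0, 3.0]),
    ke_weight=0.001,
)
```

With these numbers the progress term is 0.5 and the penalty is 0.006, so each of the three agents receives 0.494.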

static done(env: ThreeGears) Array[source]#

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

property max_num_agents: int[source]#

Maximum number of active agents in the environment.

class jaxdem.rl.environments.TwoGears(state: State, system: System, env_params: dict[str, Any])#

Bases: Environment

Two-dimensional environment with two gears.

The environment consists of two gears composed of spheres. One gear is frozen on the floor, and the other is an active agent that can apply torque to itself. The objective is to navigate the active gear to a specified target position above the frozen gear. The active gear is attracted to the frozen gear by a magnetic force.

Note

Experiments indicate that the maximum torque must be at least 4.0 * mgr for the gear to climb correctly, and the attraction at least 1 * mg. For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so that num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.

classmethod Create(box_size: float = 10.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, attraction_mag: float = 4.0) TwoGears[source]#

Create a two-gears 2-D environment.

Parameters:
  • box_size (float) – Size of the square bounding box.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

  • attraction_mag (float) – Magnitude of the attraction force between the two gears.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

TwoGears

static reset(env: TwoGears, key: Array) Environment[source]#

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (jax.Array) – PRNG key used to sample the initial positions and objective.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: TwoGears, action: Array) Environment[source]#

Advance the environment by one step.

Applies torque to the active agent, computes the attraction force between the gears, and applies viscous drag.

The attraction force is defined as:

\[\mathbf{F}_{\text{attraction}} = - \frac{C}{d^3} \hat{n},\]

when \(d < 3r\), where \(d\) is the distance between the gear centers, \(\hat{n}\) is the unit vector from the frozen gear to the active gear, \(r\) is the gear radius, and \(C = m_{\text{attr}} (2r)^3\) with \(m_{\text{attr}}\) given by attraction_mag. With this choice, the force magnitude equals attraction_mag at \(d = 2r\).

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Actions for the active gear.

Returns:

Updated environment after physics integration and sensor updates.

Return type:

Environment
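The attraction force defined above can be sketched as follows. The function name attraction_force and its arguments are hypothetical, and the force is assumed to act on the active gear and to be zero outside the cutoff; the actual TwoGears.step() implementation may differ.

```python
import jax.numpy as jnp


def attraction_force(pos_active, pos_frozen, radius, attraction_mag):
    """Force on the active gear, F = -(C / d^3) n_hat for d < 3r (hypothetical helper)."""
    delta = pos_active - pos_frozen           # points from the frozen gear to the active gear
    d = jnp.linalg.norm(delta)
    n_hat = delta / d
    C = attraction_mag * (2.0 * radius) ** 3  # so that |F| = attraction_mag at d = 2r
    force = -(C / d**3) * n_hat               # minus sign pulls toward the frozen gear
    # Assumed: no attraction at or beyond the cutoff distance 3r.
    return jnp.where(d < 3.0 * radius, force, jnp.zeros_like(force))


frozen = jnp.array([0.0, 0.0])
f_near = attraction_force(jnp.array([0.0, 2.0]), frozen, radius=1.0, attraction_mag=4.0)
f_far = attraction_force(jnp.array([0.0, 4.0]), frozen, radius=1.0, attraction_mag=4.0)
```

At \(d = 2r\) the force has magnitude attraction_mag and points toward the frozen gear, while at \(d = 4r\) it vanishes because the distance exceeds the cutoff.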

static observation(env: TwoGears) Array[source]#

Build the observation vector.

The observation vector contains 16 features:

===================================  ====
Feature                              Size
===================================  ====
Distance to floor                    1
Distance to left/right walls         2
Unit vector to target                2
Clamped displacement to target       2
Unit vector to frozen gear           2
Clamped displacement to frozen gear  2
\(\sin(\Delta\theta)\)               1
\(\cos(\Delta\theta)\)               1
Velocity (x, y)                      2
Angular velocity                     1
===================================  ====

Returns:

Observation vector of size 16.

Return type:

jax.Array
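As a rough illustration of the relative features listed for the frozen gear (unit vector, clamped displacement, and the sine and cosine of \(\Delta\theta\)), consider the sketch below. Everything in it is an assumption: the helper name relative_features, the component-wise clamp with a hypothetical bound, and the reading of \(\Delta\theta\) as the difference between the two gears' orientation angles are not confirmed by this documentation.

```python
import jax.numpy as jnp


def relative_features(pos_self, pos_other, angle_self, angle_other, clamp=1.0):
    """Six relative features: unit vector (2), clamped displacement (2), sin, cos."""
    delta = pos_other - pos_self
    dist = jnp.linalg.norm(delta)
    unit = delta / jnp.maximum(dist, 1e-8)    # guard against division by zero
    clamped = jnp.clip(delta, -clamp, clamp)  # assumed component-wise clamp
    dtheta = angle_other - angle_self         # assumed definition of the angle difference
    trig = jnp.stack([jnp.sin(dtheta), jnp.cos(dtheta)])
    return jnp.concatenate([unit, clamped, trig])


feats = relative_features(
    pos_self=jnp.array([0.0, 0.0]), pos_other=jnp.array([3.0, 4.0]),
    angle_self=0.0, angle_other=jnp.pi / 2,
)
```

Encoding the angle as a sine and cosine pair avoids the discontinuity a raw angle would have at \(\pm\pi\), and the unit vector keeps direction information available even when the displacement is clamped.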

static reward(env: TwoGears) Array[source]#

Compute the reward.

The reward is based on the differential distance to the objective minus a penalty for the change in kinetic energy:

\[R_t = (d_{t-1} - d_t) - w_{\text{ke}} (K_t - K_{t-1})\]

where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.

Returns:

Reward value for the active agent.

Return type:

jax.Array

static done(env: TwoGears) Array[source]#

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

property max_num_agents: int[source]#

Maximum number of active agents in the environment.

Modules

multi_navigator

Environment where multiple agents navigate towards assigned targets.

multi_roller

Environment where multiple rolling agents navigate towards assigned targets.

single_navigator

Environment where a single agent navigates towards a target.

single_roller

Environment where a single agent rolls towards a target on the floor.

swarm_navigator

Environment where multiple agents navigate towards nearby shared targets.

swarm_roller

Environment where multiple rolling agents navigate towards nearby shared targets.

swarm_roller_3d

Multi-agent 3-D swarm rolling environment with magnetic interaction and pyramid objectives.

swarm_stacking_3d

Multi-agent 3-D swarm stacking environment with periodic boundaries.

three_gears

Multi-agent 2-D environment with three gears and a triangle objective.

two_gears

Two-dimensional environment with two gears for RL training.