jaxdem.rl.environments.multi_navigator

Environment where multiple agents navigate towards assigned targets.

Classes

MultiNavigator(state, system, env_params, ...)

Multi-agent navigation environment toward assigned targets.

class jaxdem.rl.environments.multi_navigator.MultiNavigator(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)

Bases: Environment

Multi-agent navigation environment toward assigned targets.

Each agent controls a force vector that is applied directly to its own sphere inside a reflective box. A viscous drag force, -friction * vel, is added at every step. Objectives are sampled and assigned to agents one-to-one via a random permutation.

The reward uses exponential potential-based shaping:

\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]

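where \(d_i\) is the distance from agent \(i\) to its assigned objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is the radius of agent \(i\), and the superscript \(\mathrm{prev}\) denotes the value at the previous step. The weights \(w_{\mathrm{ke}}\), \(w_{\mathrm{coop}}\), and \(w_{\mathrm{near}}\) correspond to the ke_weight, coop_weight, and near_goal_bonus arguments of Create().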
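The reward formula can be written out term by term in JAX. This is a minimal sketch, not the library's implementation: the function name shaped_reward and its arguments (current and previous distances, kinetic energies, and radii, all precomputed with shape (N,)) are hypothetical, and the weights default to the Create() defaults.

```python
import jax.numpy as jnp


def shaped_reward(d, d_prev, ke, ke_prev, radius,
                  ke_weight=0.1, coop_weight=0.2, near_goal_bonus=0.1):
    """Hypothetical per-agent reward following the formula above.

    All inputs have shape (N,): d and d_prev are distances to each agent's
    assigned objective at the current and previous step, ke and ke_prev are
    translational kinetic energies, and radius holds each agent's radius.
    """
    # Potential-based progress term: change in exp(-2 d) for each agent.
    progress = jnp.exp(-2.0 * d) - jnp.exp(-2.0 * d_prev)
    # Penalize increases in kinetic energy between steps.
    ke_penalty = ke_weight * (ke - ke_prev)
    # Shared team bonus: mean progress over all N agents (same for everyone).
    coop = coop_weight * jnp.mean(progress)
    # Indicator bonus when an agent is within one radius of its objective.
    near = near_goal_bonus * (d <= radius).astype(d.dtype)
    return progress - ke_penalty + coop + near
```

Because the cooperative term is a scalar mean, it is broadcast to every agent, so each agent is rewarded partly for the whole team's progress.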

Notes

The observation vector per agent is:

Feature                          Size
Unit direction to objective      dim
Clamped displacement             dim
Velocity                         dim
LiDAR proximity (normalised)     n_lidar_rays

For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz (one action every 5 ms), so num_steps_epoch = 100 covers a horizon of 0.5 seconds. Note that skip_frames and num_steps_epoch are not arguments of Create().
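The arithmetic behind these figures can be checked in a few lines of Python. The physics time step dt is not stated in the text; it is inferred here from the quoted 200 Hz rate (5 ms per action spread over 50 physics steps), so treat it as an assumption.

```python
# Timing figures quoted above. dt is inferred from the 200 Hz rate,
# not read from any documented parameter.
skip_frames = 50          # physics steps per agent action
response_rate_hz = 200.0  # agent decisions per second (as stated)
num_steps_epoch = 100     # agent decisions per rollout

# One decision lasts 1 / 200 s = 5 ms, split across 50 physics steps.
dt = 1.0 / (response_rate_hz * skip_frames)      # 1e-4 s per physics step
horizon_s = num_steps_epoch / response_rate_hz   # 100 * 5 ms = 0.5 s
```

Conversely, a different horizon at the same response rate only requires changing num_steps_epoch, since the horizon scales linearly with it.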

n_lidar_rays: int

Number of angular bins for each LiDAR sensor.
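The angular binning can be pictured with a small JAX sketch. It is a hypothetical illustration, not jaxdem's sensor: lidar_proximity is an invented name, it works in 2D only, uses center-to-center distances to other agents (walls are ignored), and assumes the normalized proximity is 1 - distance / lidar_range, keeping the closest neighbour in each bin.

```python
import jax.numpy as jnp


def lidar_proximity(pos, lidar_range=6.0, n_lidar_rays=16):
    """Hypothetical 2D LiDAR: closest-neighbour proximity per angular bin.

    pos has shape (N, 2). Returns shape (N, n_lidar_rays) with values in
    [0, 1]: 0 means nothing detected in that bin, values near 1 mean a
    neighbour very close to the agent.
    """
    n = pos.shape[0]
    # Displacement from agent i to agent j: shape (N, N, 2).
    delta = pos[None, :, :] - pos[:, None, :]
    dist = jnp.linalg.norm(delta, axis=-1)
    # Bearing in [-pi, pi]; map to a bin in [0, n_lidar_rays).
    # An angle of exactly pi wraps to bin 0, the same direction as -pi.
    angle = jnp.arctan2(delta[..., 1], delta[..., 0])
    bins = jnp.floor((angle + jnp.pi) / (2 * jnp.pi) * n_lidar_rays).astype(jnp.int32)
    bins = jnp.mod(bins, n_lidar_rays)
    # Ignore self-pairs and neighbours beyond the sensor range.
    visible = (dist < lidar_range) & ~jnp.eye(n, dtype=bool)
    prox = jnp.where(visible, 1.0 - dist / lidar_range, 0.0)
    # Scatter-max: keep the largest proximity (closest neighbour) per bin.
    rows = jnp.broadcast_to(jnp.arange(n)[:, None], (n, n))
    return jnp.zeros((n, n_lidar_rays)).at[rows, bins].max(prox)
```

Using a scatter-max rather than a Python loop keeps the function jit-compatible and vectorized over all agent pairs.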

classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 5.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 6.0, n_lidar_rays: int = 16) → MultiNavigator

Create a multi-agent navigator environment.

Parameters:
  • N (int) – Number of agents.

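  • min_box_size (float) – Lower bound of the range for the random square domain side length sampled at each reset().

  • max_box_size (float) – Upper bound of the range for the random square domain side length sampled at each reset(). With the defaults (both 20.0), the side length is fixed at 20.0.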

  • box_padding (float) – Extra padding around the domain in multiples of the particle radius.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

  • coop_weight (float) – Weight for the shared team-progress bonus.

  • near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.

  • lidar_range (float) – Maximum detection range for the LiDAR sensor.

  • n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).

Returns:

A freshly constructed environment (call reset() before use).

Return type:

MultiNavigator

static reset(env: MultiNavigator, key: ArrayLike) → Environment

Initialize the environment with random positions and objectives.

Parameters:
  • env (Environment) – Current environment instance.

  • key (ArrayLike) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment
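A minimal sketch of the sampling a reset involves, under stated assumptions: sample_layout is a hypothetical helper, positions and objectives are drawn uniformly inside the box with no overlap check, all agents share one radius, and box_padding is not modeled.

```python
import jax
import jax.numpy as jnp


def sample_layout(key, n_agents=64, min_box_size=20.0, max_box_size=20.0,
                  radius=0.5, dim=2):
    """Hypothetical reset: box side, agent positions, permuted objectives."""
    k_box, k_pos, k_obj, k_perm = jax.random.split(key, 4)
    # Square domain side length drawn from [min_box_size, max_box_size].
    side = jax.random.uniform(k_box, (), minval=min_box_size, maxval=max_box_size)
    # Agents and objectives placed at least one radius away from the walls.
    pos = jax.random.uniform(k_pos, (n_agents, dim), minval=radius, maxval=side - radius)
    objectives = jax.random.uniform(k_obj, (n_agents, dim), minval=radius, maxval=side - radius)
    # One-to-one assignment: agent i is sent to objectives[perm[i]].
    perm = jax.random.permutation(k_perm, n_agents)
    return side, pos, objectives[perm]
```

With objectives drawn i.i.d., the permutation does not change their distribution; it matters when objectives are generated with some structure (for example, on a grid), where it decides which agent gets which target.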

static step(env: MultiNavigator, action: Array) → Environment

Advance one step. Actions are forces; simple drag is applied (-friction * vel).

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent force actions, one row per agent, with shape (N, action_space_size).

Returns:

The updated environment state.

Return type:

Environment
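The force-plus-drag dynamics described for step() can be sketched as one explicit update. This is an illustration under assumptions, not jaxdem's integrator: step_dynamics is hypothetical, it uses semi-implicit Euler with a given time step dt, mirrors a sphere's center back inside the walls (offset by its radius) when it crosses one, and ignores agent-agent contacts.

```python
import jax.numpy as jnp


def step_dynamics(pos, vel, force, mass, side, radius, dt, friction=0.2):
    """Hypothetical physics step: action force plus viscous drag,
    semi-implicit Euler, and reflection at the box walls.

    pos, vel, force: (N, dim); mass, radius: (N,); side, dt: scalars.
    """
    # Total force = agent action minus viscous drag (-friction * vel).
    total_force = force - friction * vel
    vel = vel + dt * total_force / mass[:, None]
    pos = pos + dt * vel
    lo = radius[:, None]
    hi = side - radius[:, None]
    below, above = pos < lo, pos > hi
    # Mirror positions that crossed a wall (assumes at most one crossing).
    pos = jnp.where(below, 2 * lo - pos, jnp.where(above, 2 * hi - pos, pos))
    # Point the normal velocity component back into the box.
    vel = jnp.where(below, jnp.abs(vel), jnp.where(above, -jnp.abs(vel), vel))
    return pos, vel
```

Setting the velocity sign explicitly (rather than negating it) keeps the sphere moving inward even if it was already heading back when the reflection was detected.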

static observation(env: MultiNavigator) → Array

Build per-agent observations.

Contents per agent:

  • Unit vector to objective (shape (dim,)) – direction

  • Clamped delta to objective (shape (dim,)) – local precision

  • Velocity (shape (dim,))

  • LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,))

Returns:

Array of shape (N, 3 * dim + n_lidar_rays)

Return type:

jax.Array
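The observation layout can be assembled by concatenating the four blocks along the last axis. build_observation is a hypothetical helper; the clamp bound for the displacement (clamp=1.0, applied per component) is an assumption, since the text does not give the clamping rule, and the LiDAR input is taken as already normalized.

```python
import jax.numpy as jnp


def build_observation(pos, objective, vel, lidar, clamp=1.0):
    """Hypothetical per-agent observation of shape (N, 3 * dim + n_lidar_rays).

    pos, objective, vel: (N, dim); lidar: (N, n_lidar_rays), already
    normalized by lidar_range. `clamp` is an assumed bound for the delta.
    """
    delta = objective - pos
    dist = jnp.linalg.norm(delta, axis=-1, keepdims=True)
    # Unit direction; guard against division by zero when on the objective.
    direction = delta / jnp.maximum(dist, 1e-8)
    # Clamped displacement: exact near the goal, saturated far away.
    clamped = jnp.clip(delta, -clamp, clamp)
    return jnp.concatenate([direction, clamped, vel, lidar], axis=-1)
```

The unit vector carries heading at any distance, while the clamped delta only varies once the agent is within the clamp bound, which is what gives the policy fine resolution close to its objective.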

static reward(env: MultiNavigator) → Array

Returns a vector of per-agent rewards.

\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]

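where \(d_t\) is an agent's distance to its objective at step \(t\), \(d_t^{\mathrm{prev}}\) is that distance at the previous step, \(K_t\) is the agent's kinetic energy at step \(t\), \(r\) is the agent's radius, and the mean is taken over all agents. The weight \(w_{\text{ke}}\) scales the kinetic-energy penalty, \(w_{\text{coop}}\) scales the shared team-progress bonus, and \(w_{\text{near}}\) scales the near-goal bonus.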
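The progress term is a difference of the potential \(e^{-2d}\), so its sum over an episode telescopes. The short derivation below illustrates this, assuming undiscounted returns and writing \(d_t^{\mathrm{prev}} = d_{t-1}\):

\[\sum_{t=1}^{T} \left(e^{-2 d_t} - e^{-2 d_{t-1}}\right) = e^{-2 d_T} - e^{-2 d_0}, \qquad \sum_{t=1}^{T} \left(K_t - K_{t-1}\right) = K_T - K_0\]

The cooperative term is a mean of such differences and telescopes in the same way, so the undiscounted return depends only on the initial and final distances and kinetic energies, plus \(w_{\text{near}}\) times the number of steps spent within one radius of the objective. Moving back and forth therefore earns no net shaping reward.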

Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,).

Return type:

jax.Array

static done(env: MultiNavigator) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size), where A is the number of agents.

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size), where A is the number of agents.