jaxdem.rl.environments.multi_navigator
Environment where multiple agents navigate towards assigned targets.
Classes
MultiNavigator | Multi-agent navigation environment toward assigned targets.
- class jaxdem.rl.environments.multi_navigator.MultiNavigator(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)
Bases: Environment
Multi-agent navigation environment toward assigned targets.
Each agent controls a force vector that is applied directly to a sphere inside a reflective box. A viscous drag force, -friction * vel, is added each step. Objectives are sampled and assigned one-to-one to agents via a random permutation. The reward uses exponential potential-based shaping:
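\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]
where \(d_i\) is the distance from agent \(i\) to its assigned objective, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is its radius, \(N\) is the number of agents, and \(w_{\mathrm{ke}}\), \(w_{\mathrm{coop}}\), \(w_{\mathrm{near}}\) correspond to the ke_weight, coop_weight, and near_goal_bonus arguments of Create().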
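A minimal NumPy sketch of the per-agent reward above, written directly from the formula rather than from the jaxdem source. The function name `shaped_reward` and its argument layout are hypothetical; the default weights mirror the Create() defaults for ke_weight, coop_weight, and near_goal_bonus, and the agent radius \(r_i\) is passed in explicitly as an assumption.

```python
import numpy as np


def shaped_reward(d, d_prev, ke, ke_prev, radius,
                  ke_weight=0.1, coop_weight=0.2, near_goal_bonus=0.1):
    """Per-agent reward R_i (hypothetical NumPy sketch of the formula).

    All array arguments have shape (N,): current and previous distance to the
    assigned objective, current and previous translational kinetic energy,
    and the agent radius r_i.
    """
    # Potential-based progress term with potential phi(d) = exp(-2 d).
    progress = np.exp(-2.0 * d) - np.exp(-2.0 * d_prev)
    # Penalize the change in kinetic energy, not its absolute value.
    ke_penalty = ke_weight * (ke - ke_prev)
    # Shared team-progress bonus: mean progress over all N agents.
    coop = coop_weight * progress.mean()
    # Indicator bonus when the agent is within one radius of its objective.
    near = near_goal_bonus * (d <= radius).astype(float)
    return progress - ke_penalty + coop + near


# Example: agent 0 moved from d = 2 to d = 1; agent 1 stayed at d = 1.
rewards = shaped_reward(np.array([1.0, 1.0]), np.array([2.0, 1.0]),
                        np.zeros(2), np.zeros(2), np.full(2, 0.5))
```

Agent 1 made no progress but still receives a share of agent 0's progress through the cooperative term.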
Notes
The observation vector per agent is:
Feature | Size
Unit direction to objective | dim
Clamped displacement | dim
Velocity | dim
LiDAR proximity (normalized) | n_lidar_rays
If one wants realistic parameters for training, skip_frames = 50 gives a response rate of 200 Hz, meaning that num_steps_epoch = 100 gives a horizon of 0.5 seconds.
- n_lidar_rays: int
Number of angular bins for each LiDAR sensor.
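The timing figures in the training tip in the Notes follow from simple arithmetic. The Notes do not state the physics time step, so the sketch below assumes dt = 1e-4 s, the value implied by 50 skipped frames producing a 200 Hz response rate, and reads num_steps_epoch as the number of agent decisions per rollout.

```python
# Assumption: physics time step implied by the training tip (not stated explicitly).
dt = 1e-4              # seconds per physics step
skip_frames = 50       # physics steps between consecutive agent decisions
num_steps_epoch = 100  # agent decisions per rollout

decision_period = skip_frames * dt           # seconds between actions
response_rate = 1.0 / decision_period        # decisions per second (Hz)
horizon = num_steps_epoch * decision_period  # seconds covered by one rollout

print(f"{response_rate:.0f} Hz, horizon {horizon:.2f} s")
```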
- classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 5.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 6.0, n_lidar_rays: int = 16) → MultiNavigator
Create a multi-agent navigator environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound of the range from which the square domain side length is sampled at each reset().
max_box_size (float) – Upper bound of the range from which the square domain side length is sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
coop_weight (float) – Weight for the shared team-progress bonus.
near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
MultiNavigator
- static reset(env: MultiNavigator, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Initialize the environment with random positions and objectives.
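As an illustration of this random initialization and of the one-to-one objective assignment mentioned in the class description, a NumPy sketch follows. It is not the jaxdem implementation: `sample_reset`, the uniform draws, and the wall margin of one radius are all assumptions, and box_padding is ignored.

```python
import numpy as np


def sample_reset(rng, n_agents, dim=2, min_box_size=20.0, max_box_size=20.0,
                 radius=0.5):
    """Hypothetical reset: box size, agent positions, permuted objectives."""
    # Assumption: side length drawn uniformly from [min_box_size, max_box_size].
    side = rng.uniform(min_box_size, max_box_size)
    # Assumption: positions and objectives uniform, kept one radius from walls.
    low, high = radius, side - radius
    pos = rng.uniform(low, high, size=(n_agents, dim))
    objectives = rng.uniform(low, high, size=(n_agents, dim))
    # One-to-one assignment: agent i targets objectives[perm[i]].
    perm = rng.permutation(n_agents)
    targets = objectives[perm]
    return side, pos, targets, perm


rng = np.random.default_rng(0)
side, pos, targets, perm = sample_reset(rng, n_agents=8)
```

Because the objectives in this sketch are already drawn independently, the permutation changes nothing statistically; it only shows how a permutation yields a one-to-one agent-to-objective assignment.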
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static step(env: MultiNavigator, action: Array) → Environment
Advance one step. Actions are forces; simple drag is applied (-friction * vel).
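A NumPy sketch of one physics update with the documented drag term, assuming unit mass, a semi-implicit Euler integrator, and mirror reflection at walls placed one radius inside the box. `step_sketch`, `mass`, `radius`, and the default `dt` are hypothetical, and the sketch omits the sphere-sphere contacts that the real simulation also resolves.

```python
import numpy as np


def step_sketch(pos, vel, action, side, dt=1e-4, friction=0.2, mass=1.0,
                radius=0.5):
    """Hypothetical single update: action force plus viscous drag."""
    # Net force: agent action plus viscous drag -friction * vel.
    force = action - friction * vel
    # Assumption: semi-implicit Euler (velocity first, then position).
    vel = vel + dt * force / mass
    pos = pos + dt * vel
    # Assumption: reflective walls one radius inside the box boundaries.
    low, high = radius, side - radius
    below, above = pos < low, pos > high
    pos = np.where(below, 2.0 * low - pos, pos)
    pos = np.where(above, 2.0 * high - pos, pos)
    # Flip only the velocity components that crossed a wall.
    vel = np.where(below | above, -vel, vel)
    return pos, vel


pos, vel = step_sketch(np.array([[10.0, 10.0]]), np.array([[1.0, 0.0]]),
                       np.zeros((1, 2)), side=20.0, dt=0.1)
```

With zero action, drag alone reduces the speed from 1.0 to 0.98 over this (deliberately large) time step.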
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- static observation(env: MultiNavigator) → Array
Build per-agent observations.
Contents per agent
Unit vector to objective (shape (dim,)), giving the direction.
Clamped delta to objective (shape (dim,)), giving local precision.
Velocity (shape (dim,)).
LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, 3 * dim + n_lidar_rays).
- Return type:
jax.Array
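To make the layout concrete, here is a NumPy sketch that assembles observations with the documented shape (N, 3 * dim + n_lidar_rays) in two dimensions. The clamp bound (`clamp`), the proximity formula (lidar_range - distance) / lidar_range, and the choice to detect only other agents (not walls) are assumptions; `observation_sketch` is a hypothetical name.

```python
import numpy as np


def observation_sketch(pos, vel, targets, lidar_range=6.0, n_lidar_rays=16,
                       clamp=1.0):
    """Hypothetical (N, 3 * dim + n_lidar_rays) observation for dim == 2."""
    n, dim = pos.shape
    assert dim == 2, "bearing computation below assumes 2D"
    delta = targets - pos
    dist = np.linalg.norm(delta, axis=1, keepdims=True)
    # Unit direction to the objective; guard against division by zero.
    unit = delta / np.maximum(dist, 1e-8)
    # Assumption: element-wise clamp of the displacement to [-clamp, clamp].
    clamped = np.clip(delta, -clamp, clamp)
    # Simplified LiDAR: bin every other agent by bearing in [-pi, pi) and keep,
    # per bin, the largest normalized proximity.
    lidar = np.zeros((n, n_lidar_rays))
    for i in range(n):
        rel = pos - pos[i]
        r = np.linalg.norm(rel, axis=1)
        angle = np.arctan2(rel[:, 1], rel[:, 0])
        # A bearing of exactly +pi wraps into bin 0 via the modulo.
        bins = ((angle + np.pi) / (2 * np.pi) * n_lidar_rays).astype(int) % n_lidar_rays
        for j in range(n):
            if j != i and r[j] < lidar_range:
                prox = (lidar_range - r[j]) / lidar_range
                lidar[i, bins[j]] = max(lidar[i, bins[j]], prox)
    return np.concatenate([unit, clamped, vel, lidar], axis=1)


obs = observation_sketch(np.array([[0.0, 0.0], [3.0, 0.0]]), np.zeros((2, 2)),
                         np.array([[1.0, 0.0], [3.0, 4.0]]))
```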
- static reward(env: MultiNavigator) → Array
Returns a vector of per-agent rewards.
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]
where \(d_t\) is the distance to the objective at step \(t\), \(K_t\) is the kinetic energy at step \(t\), \(r\) is the agent radius, \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus (the mean is taken over all agents), and \(w_{\text{near}}\) weights a near-goal bonus.
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,).
- Return type:
jax.Array
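The distance, kinetic-energy, and cooperative terms are each a difference of one quantity between consecutive steps, so they telescope over an episode. Assuming no discounting and writing \(d_t^{\mathrm{prev}} = d_{t-1}\), summing from \(t = 1\) to \(T\) gives:
\[\begin{aligned}
\sum_{t=1}^{T} \left(e^{-2 \cdot d_t} - e^{-2 \cdot d_{t-1}}\right) &= e^{-2 \cdot d_T} - e^{-2 \cdot d_0}, \\
-w_{\text{ke}} \sum_{t=1}^{T} \left(K_t - K_{t-1}\right) &= -w_{\text{ke}} \left(K_T - K_0\right), \\
w_{\text{coop}} \sum_{t=1}^{T} \mathrm{mean}\left(e^{-2 \cdot d_t} - e^{-2 \cdot d_{t-1}}\right) &= w_{\text{coop}} \cdot \mathrm{mean}\left(e^{-2 \cdot d_T} - e^{-2 \cdot d_0}\right).
\end{aligned}\]
The undiscounted return from these terms therefore depends only on the initial and final states. The near-goal bonus does not telescope: it accumulates once per step spent within radius \(r\). With a discount factor below 1 the cancellation is no longer exact.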
- static done(env: MultiNavigator) → Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
- property action_space_size: int
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size), where A is the number of agents.
- property action_space_shape: tuple[int]
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
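A small sketch of how these sizes fit together, assuming action_space_shape == (dim,) because each agent controls one force vector; the observation size follows from the (N, 3 * dim + n_lidar_rays) shape documented for observation(). The helper `space_sizes` is hypothetical.

```python
import numpy as np


def space_sizes(dim, n_lidar_rays):
    """Hypothetical helper mirroring the three space properties."""
    # Assumption: one force component per spatial dimension.
    action_space_shape = (dim,)
    action_space_size = int(np.prod(action_space_shape))
    # Direction + clamped delta + velocity (dim each) + LiDAR bins.
    observation_space_size = 3 * dim + n_lidar_rays
    return action_space_shape, action_space_size, observation_space_size


shape, a_size, o_size = space_sizes(dim=2, n_lidar_rays=16)
flat_actions = np.zeros((64, a_size))           # (A, action_space_size), A = 64
per_agent = flat_actions.reshape((-1, *shape))  # back to (A, *action_space_shape)
```

Keeping a flat per-agent size alongside the original shape lets a policy emit a plain matrix while the environment reshapes it internally.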