jaxdem.rl.environments.multi_roller
Environment where multiple rolling agents navigate towards assigned targets.
Functions
Normal, frictional, and restitution forces for spheres on a \(z = 0\) plane.
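The floor-contact function listed above combines three effects. The sketch below shows one hedged way to model them, using a common linear spring-dashpot model that is not necessarily the one jaxdem implements:

- a normal spring force from the overlap with the plane;
- a normal dashpot whose energy loss sets the restitution;
- a viscous tangential friction force, capped by the Coulomb limit, acting on the slip velocity of the contact point.

The function name floor_contact_force and the parameters k_n, gamma_n, mu, and gamma_t are hypothetical.

```python
import jax.numpy as jnp


def floor_contact_force(pos, vel, ang_vel, radius, k_n=1e4, gamma_n=10.0, mu=0.5, gamma_t=10.0):
    """Hypothetical spring-dashpot contact of spheres with the plane z = 0 (not jaxdem's API).

    pos, vel, ang_vel: (N, 3) arrays; radius: (N,). Returns (force, torque), each (N, 3).
    """
    # Overlap with the floor; positive only while the sphere touches it.
    overlap = jnp.maximum(radius - pos[:, 2], 0.0)
    in_contact = overlap > 0.0
    # Normal force: linear spring plus dashpot (the dashpot's energy loss sets the restitution).
    fn = jnp.where(in_contact, k_n * overlap - gamma_n * vel[:, 2], 0.0)
    fn = jnp.maximum(fn, 0.0)  # the floor can push but never pull

    # Tangential velocity of the contact point located at -radius * z_hat from the centre.
    slip = vel[:, :2] + radius[:, None] * jnp.stack([-ang_vel[:, 1], ang_vel[:, 0]], axis=-1)
    # Viscous friction opposing the slip, capped by the Coulomb limit mu * fn.
    ft = -gamma_t * slip
    ft_norm = jnp.linalg.norm(ft, axis=-1, keepdims=True)
    scale = jnp.minimum(1.0, mu * fn[:, None] / jnp.maximum(ft_norm, 1e-12))
    ft = jnp.where(in_contact[:, None], ft * scale, 0.0)

    force = jnp.concatenate([ft, fn[:, None]], axis=-1)
    # Torque about the centre from the force applied at r_c = (0, 0, -radius).
    zeros = jnp.zeros_like(radius)
    r_c = jnp.stack([zeros, zeros, -radius], axis=-1)
    torque = jnp.cross(r_c, force)
    return force, torque
```

For a sphere sliding in +x without spin, friction points in -x and the torque is along +y. This spins the sphere up toward rolling without slipping (v_x = r ω_y).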
Classes
MultiRoller: Multi-agent rolling environment toward assigned targets.
- class jaxdem.rl.environments.multi_roller.MultiRoller(state: State, system: System, env_params: dict[str, Any], n_lidar_rays: int)
Bases: Environment
Multi-agent rolling environment toward assigned targets.
Each agent controls a torque vector that is applied directly to a sphere on a \(z=0\) floor. Translational drag -friction * vel and angular damping -friction * ang_vel are applied each step. Objectives are sampled and assigned one-to-one via a random permutation. The reward uses exponential potential-based shaping:
\[R_i = (e^{-2d_i} - e^{-2d_i^{\mathrm{prev}}}) - w_{\mathrm{ke}}(K_i - K_i^{\mathrm{prev}}) + w_{\mathrm{coop}} \cdot \frac{1}{N}\sum_j (e^{-2d_j} - e^{-2d_j^{\mathrm{prev}}}) + w_{\mathrm{near}}\,\mathbf{1}[d_i \le r_i]\]
where \(d_i\) is the distance from agent \(i\) to its assigned objective in the \(xy\) plane, \(K_i\) is the translational kinetic energy of agent \(i\), \(r_i\) is the radius of agent \(i\), and the superscript \(\mathrm{prev}\) denotes the value at the previous step.
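The reward above can be sketched in JAX as follows. This is a minimal illustration, not jaxdem's implementation, and it rests on the following assumptions:

- The helper shaped_reward is hypothetical.
- Row i of goal is agent i's already-assigned objective.
- All agents' previous positions and velocities are available.
- K is taken as 0.5 m |v|^2.

```python
import jax.numpy as jnp


def shaped_reward(pos, prev_pos, goal, vel, prev_vel, mass, radius,
                  ke_weight=0.1, coop_weight=0.2, near_goal_bonus=0.1):
    """Hypothetical per-agent shaped reward (not part of jaxdem).

    pos, prev_pos, goal, vel, prev_vel: (N, 3) arrays; mass, radius: (N,).
    Row i of goal is assumed to be agent i's assigned objective.
    """
    # Distance to the assigned objective in the xy plane, now and at the previous step.
    d = jnp.linalg.norm(pos[:, :2] - goal[:, :2], axis=-1)
    d_prev = jnp.linalg.norm(prev_pos[:, :2] - goal[:, :2], axis=-1)

    # Potential-based progress: change of the potential exp(-2 d).
    progress = jnp.exp(-2.0 * d) - jnp.exp(-2.0 * d_prev)

    # Translational kinetic energy K = 0.5 m |v|^2.
    ke = 0.5 * mass * jnp.sum(vel**2, axis=-1)
    ke_prev = 0.5 * mass * jnp.sum(prev_vel**2, axis=-1)

    # Team term: mean progress over all agents, shared by everyone.
    coop = coop_weight * jnp.mean(progress)

    # Bonus when the agent's centre is within one radius of its objective.
    near = near_goal_bonus * (d <= radius).astype(pos.dtype)

    return progress - ke_weight * (ke - ke_prev) + coop + near
```

Because the cooperative term is a mean, each agent's reward also depends on its teammates' progress. This couples the agents without any explicit communication.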
Notes
The observation vector per agent is:
| Feature | Size |
|---|---|
| Unit direction to objective | 2 |
| Clamped displacement | 2 |
| Velocity | 2 |
| LiDAR proximity (normalized) | n_lidar_rays |
For realistic training parameters, skip_frames = 50 gives a response rate of 200 Hz, so num_steps_epoch = 100 gives a horizon of 0.5 seconds.
- n_lidar_rays: int
Number of angular bins for each LiDAR sensor.
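The timing figures in the Notes above follow from simple arithmetic. The sketch below rests on two assumptions:

- skip_frames counts the physics steps each agent action is held for.
- One physics step lasts dt = 1e-4 s, which is the value implied by 50 frames per action at 200 Hz.

The helper control_timing and the parameter dt are hypothetical.

```python
def control_timing(dt: float, skip_frames: int, num_steps_epoch: int):
    """Hypothetical helper: return (control rate in Hz, horizon in seconds).

    Assumes each agent action is held for `skip_frames` physics steps of length `dt`.
    """
    # Simulated time between consecutive agent decisions.
    control_period = dt * skip_frames
    control_rate_hz = 1.0 / control_period
    # Simulated time covered by num_steps_epoch agent decisions.
    horizon_s = num_steps_epoch * control_period
    return control_rate_hz, horizon_s


rate, horizon = control_timing(dt=1e-4, skip_frames=50, num_steps_epoch=100)
print(f"{rate:.0f} Hz, {horizon:.2f} s")  # 200 Hz, 0.50 s
```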
- classmethod Create(N: int = 64, min_box_size: float = 20.0, max_box_size: float = 20.0, box_padding: float = 5.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.1, coop_weight: float = 0.2, near_goal_bonus: float = 0.1, lidar_range: float = 6.0, n_lidar_rays: int = 16) → MultiRoller
Create a multi-agent roller environment.
- Parameters:
N (int) – Number of agents.
min_box_size (float) – Lower bound of the range from which the square domain side length is sampled at each reset().
max_box_size (float) – Upper bound of the range from which the square domain side length is sampled at each reset().
box_padding (float) – Extra padding around the domain in multiples of the particle radius.
max_steps (int) – Episode length in physics steps.
friction (float) – Translational and angular damping coefficient.
ke_weight (float) – Weight for the differential kinetic energy penalty.
coop_weight (float) – Weight for the shared team-progress bonus.
near_goal_bonus (float) – Reward bonus applied when an agent is within one radius of its objective.
lidar_range (float) – Maximum detection range for the LiDAR sensor.
n_lidar_rays (int) – Number of angular LiDAR bins spanning \([-\pi, \pi)\).
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
MultiRoller
- static reset(env: MultiRoller, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment
Initialize the environment with random positions and objectives.
- Parameters:
env (Environment) – Current environment instance.
key (ArrayLike) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
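The random initialization performed by reset() can be sketched as follows. This is a minimal illustration under stated assumptions:

- The side length is drawn uniformly from [min_box_size, max_box_size].
- Agents and objectives are placed uniformly in the square, with overlaps and box_padding ignored.
- Only xy coordinates are produced.

The helper sample_layout is hypothetical. The one-to-one assignment through jax.random.permutation mirrors the random permutation mentioned in the class description.

```python
import jax
import jax.numpy as jnp


def sample_layout(key, n_agents, min_box_size=20.0, max_box_size=20.0):
    """Hypothetical sketch: sample agent positions and one-to-one objective assignment."""
    k_box, k_pos, k_obj, k_perm = jax.random.split(key, 4)

    # Square side length drawn uniformly from [min_box_size, max_box_size].
    side = jax.random.uniform(k_box, (), minval=min_box_size, maxval=max_box_size)

    # Agent and objective positions drawn uniformly in the xy square (overlaps ignored).
    pos = jax.random.uniform(k_pos, (n_agents, 2), maxval=side)
    objectives = jax.random.uniform(k_obj, (n_agents, 2), maxval=side)

    # One-to-one assignment: agent i is paired with objectives[perm[i]].
    perm = jax.random.permutation(k_perm, n_agents)
    return side, pos, objectives[perm], perm
```

Using a permutation rather than independent random choices guarantees that no two agents share an objective.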
- static step(env: MultiRoller, action: Array) → Environment
Advance one step. Actions are torques; simple damping is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
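A minimal sketch of the per-agent loads described for step(). It assumes each action is a 3D torque vector and that the drag and damping terms are simply added to the applied loads. The helper name and the way the terms are combined are assumptions, and the integration of the equations of motion is not shown.

```python
import jax.numpy as jnp


def damped_force_and_torque(action, vel, ang_vel, friction=0.2):
    """Hypothetical sketch of the loads applied to each agent during one step.

    action, vel, ang_vel: (N, 3) arrays. The action is used directly as a torque.
    """
    # Translational drag opposes the linear velocity.
    force = -friction * vel
    # Applied torque plus angular damping opposing the spin.
    torque = action - friction * ang_vel
    return force, torque
```

With a constant torque τ, the spin damping alone would settle at ω = τ / friction. In practice the floor contact also resists the spin, so the actual steady state depends on the contact model.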
- static observation(env: MultiRoller) → Array
Build per-agent observations.
Contents per agent
Unit vector to objective in the \(xy\) plane (shape (2,)).
Clamped objective delta in the \(xy\) plane (shape (2,)).
Velocity in the \(xy\) plane (shape (2,)).
LiDAR proximity, normalized by lidar_range (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, 6 + n_lidar_rays).
- Return type:
jax.Array
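The observation layout can be sketched in JAX as below. Several details are assumptions rather than documented behavior:

- The displacement is clipped component-wise to a hypothetical bound, clamp.
- LiDAR proximity is taken as 1 - distance / lidar_range, using centre-to-centre distances to the other agents only (walls ignored).
- Each angular bin over [-π, π) keeps its closest return.

The helper build_observation is hypothetical.

```python
import jax.numpy as jnp


def build_observation(pos, vel, goal, lidar_range=6.0, n_lidar_rays=16, clamp=1.0):
    """Hypothetical sketch: per-agent observation of shape (N, 6 + n_lidar_rays).

    pos, vel, goal: (N, 2) arrays in the xy plane. `clamp` is an assumed bound.
    """
    n = pos.shape[0]
    delta = goal - pos
    dist = jnp.linalg.norm(delta, axis=-1, keepdims=True)
    # Unit direction to the objective (zero when the agent sits on it).
    unit = delta / jnp.maximum(dist, 1e-12)
    # Displacement clipped component-wise to [-clamp, clamp].
    clamped = jnp.clip(delta, -clamp, clamp)

    # rel[i, j] is the offset from agent i to agent j.
    rel = pos[None, :, :] - pos[:, None, :]                  # (N, N, 2)
    r = jnp.linalg.norm(rel, axis=-1)                        # (N, N)
    angle = jnp.arctan2(rel[..., 1], rel[..., 0])            # in [-pi, pi]
    # Angular bin index over [-pi, pi); angle == pi wraps to bin 0.
    bins = jnp.floor((angle + jnp.pi) / (2 * jnp.pi) * n_lidar_rays).astype(jnp.int32)
    bins = bins % n_lidar_rays
    # Proximity: 1 at zero distance, 0 at or beyond lidar_range; self-pairs excluded.
    prox = jnp.clip(1.0 - r / lidar_range, 0.0, 1.0)
    prox = jnp.where(jnp.eye(n, dtype=bool), 0.0, prox)
    # Keep the strongest (closest) return in each angular bin.
    one_hot = bins[..., None] == jnp.arange(n_lidar_rays)    # (N, N, n_lidar_rays)
    lidar = jnp.max(jnp.where(one_hot, prox[..., None], 0.0), axis=1)

    return jnp.concatenate([unit, clamped, vel, lidar], axis=-1)
```

The dense (N, N) pairwise computation keeps the sketch short. For many agents, a neighbor list would scale better.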
- static reward(env: MultiRoller) → Array
Returns a vector of per-agent rewards.
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_t^{\mathrm{prev}}) + w_{\text{coop}} \cdot \mathrm{mean}(e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) + w_{\text{near}} \cdot \mathbf{1}[d_t \le r]\]
where \(d_t\) is the agent's distance to its objective in the \(xy\) plane at step \(t\), \(K_t\) is its translational kinetic energy at step \(t\), the superscript \(\mathrm{prev}\) denotes the value at the previous step, and \(r\) is the agent's radius. \(w_{\text{ke}}\) is the kinetic-energy penalty weight, \(w_{\text{coop}}\) weights a shared team-progress bonus (the mean is taken over all agents), and \(w_{\text{near}}\) weights a near-goal bonus.
- Parameters:
env (Environment) – Current environment.
- Returns:
Per-agent rewards of shape (N,).
- Return type:
jax.Array
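The progress term is a difference of the potential \(\Phi(d) = e^{-2d}\) between consecutive steps, so its undiscounted sum over an episode telescopes. Assume \(d_t^{\mathrm{prev}} = d_{t-1}\) and \(K_t^{\mathrm{prev}} = K_{t-1}\) at every step \(t = 1, \dots, T\) of an episode with no intermediate reset. Then the two difference terms add up to

\[\sum_{t=1}^{T} (e^{-2 d_t} - e^{-2 d_t^{\mathrm{prev}}}) = e^{-2 d_T} - e^{-2 d_0}, \qquad -w_{\text{ke}} \sum_{t=1}^{T} (K_t - K_t^{\mathrm{prev}}) = -w_{\text{ke}} (K_T - K_0).\]

The cooperative term telescopes in the same way because the mean of differences equals the difference of means. Only the near-goal bonus therefore accumulates with the number of steps an agent spends within one radius of its objective.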
- static done(env: MultiRoller) → Array
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the environment has ended.
- Return type:
jax.Array
- property action_space_size: int
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size), where A is the number of agents.
- property action_space_shape: tuple[int]
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int
Flattened observation size per agent.
observation() returns an array of shape (A, observation_space_size), where A is the number of agents.