jaxdem.rl.environments.three_gears
Multi-agent 2-D environment with three gears and a triangle objective.
Classes
ThreeGears | Multi-agent 2-D environment with three gears.
- class jaxdem.rl.environments.three_gears.ThreeGears(state: State, system: System, env_params: dict[str, Any])
Bases: Environment
Multi-agent 2-D environment with three gears.
The environment consists of three active gears, each built from spheres. Every gear is an agent that can apply torque to itself. The agents share one objective: moving the gears into a triangular arrangement defined by a randomized target position.
Note
As in the TwoGears environment, a realistic setting for training is skip_frames = 50, which gives a response rate of 200 Hz; with it, num_steps_epoch = 100 gives a horizon of 0.5 seconds.
- classmethod Create(box_size: float = 10.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.001) → ThreeGears
Create a three-gear 2-D environment.
- Parameters:
box_size (float) – Size of the square bounding box.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient, applied as the force -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
ThreeGears
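The figures in the note above are consistent with a physics time step of 1e-4 s (50 physics steps per decision at 200 Hz). The hypothetical helpers below make that arithmetic explicit and extend it to the default max_steps of Create. DT is inferred from the note, not read from jaxdem, so the episode duration is only valid under that assumption.

```python
# Hypothetical helpers relating the physics time step, frame skipping and
# rollout length. DT is inferred from the note (skip_frames = 50 -> 200 Hz);
# it is an assumption, not a value read from jaxdem.
DT = 1e-4  # assumed physics time step in seconds


def response_rate_hz(skip_frames: int, dt: float = DT) -> float:
    """Agent decisions per second when each action spans skip_frames physics steps."""
    return 1.0 / (skip_frames * dt)


def horizon_seconds(num_steps: int, skip_frames: int, dt: float = DT) -> float:
    """Simulated time covered by num_steps agent decisions."""
    return num_steps * skip_frames * dt


def episode_seconds(max_steps: int, dt: float = DT) -> float:
    """Simulated time of one episode; max_steps counts physics steps."""
    return max_steps * dt


print(response_rate_hz(50))      # about 200 Hz
print(horizon_seconds(100, 50))  # about 0.5 s
print(episode_seconds(100_000))  # about 10 s for the Create default, if DT holds
```

With these assumptions the default episode of 100000 physics steps spans 2000 agent decisions at skip_frames = 50.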
- static reset(env: ThreeGears, key: Array) → Environment
Reset the environment to a random initial configuration.
- Parameters:
env (Environment) – The environment instance to reset.
key (jax.Array) – PRNG key used to sample the initial positions and objective triangle.
- Returns:
The environment with a fresh episode state.
- Return type:
Environment
- static step(env: ThreeGears, action: Array) → Environment
Advance the environment by one physics step.
Applies torque to the active gears, computes inter-gear forces, and applies viscous drag.
- Parameters:
env (Environment) – Current environment.
action (jax.Array) – Actions for all gears, with shape (A, action_space_size).
- Returns:
Updated environment after physics integration and sensor updates.
- Return type:
Environment
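Because reset, step, observation, reward and done are static methods that take and return an environment value, a rollout has to thread the returned environment through each call. The sketch below shows that pattern with a plain Python loop. The function name rollout, the policy callable, and the way the reward and done arrays are reduced (first entry of the shared reward, stop when any done entry is true) are assumptions for illustration, not part of jaxdem.

```python
from typing import Any, Callable

import numpy as np


def rollout(env_cls: Any, env: Any, policy: Callable[[Any], Any], num_steps: int):
    """Hypothetical driver: call env_cls.step up to num_steps times.

    env_cls is assumed to expose the static methods documented on this page
    (observation, step, reward, done). Each step returns a new environment
    value, which replaces the previous one for the next iteration.
    """
    total_reward = 0.0
    for _ in range(num_steps):
        obs = env_cls.observation(env)   # (A, observation_space_size)
        action = policy(obs)             # (A, action_space_size)
        env = env_cls.step(env, action)  # advanced environment
        # The reward is shared, so the first entry stands for the team reward.
        total_reward += float(np.asarray(env_cls.reward(env)).reshape(-1)[0])
        # Assumption: the episode ends once any entry of done(env) is true.
        if np.any(np.asarray(env_cls.done(env))):
            break
    return env, total_reward
```

With jaxdem installed, a call might look like rollout(ThreeGears, ThreeGears.reset(ThreeGears.Create(), jax.random.PRNGKey(0)), policy, 100); this has not been checked against the library. Inside jax.jit, the Python loop and early break would instead be written with jax.lax.scan and a done mask, because traced code cannot branch on array values.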
- static observation(env: ThreeGears) → Array
Build the observation vector.
The observation vector contains 22 features per agent:
Feature | Size
Distance to floor | 1
Distance to left/right walls | 2
Unit vector to target | 2
Clamped displacement to target | 2
Unit vector to neighbor j | 2
Clamped displacement to neighbor j | 2
\(\sin(\Delta\theta_j)\) | 1
\(\cos(\Delta\theta_j)\) | 1
Unit vector to neighbor k | 2
Clamped displacement to neighbor k | 2
\(\sin(\Delta\theta_k)\) | 1
\(\cos(\Delta\theta_k)\) | 1
Velocity (x, y) | 2
Angular velocity | 1
- Returns:
Observation array of shape (A, 22): one row of 22 features per agent.
- Return type:
jax.Array
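To show how the listed sizes add up to 22, the sketch below assembles one agent's row in the order of the feature list. The geometric conventions are hypothetical: the floor is taken to be at y = 0, the walls at x = 0 and x = box_size, "clamped" is read as component-wise clipping to [-clamp, clamp], and \(\Delta\theta_j = \theta_j - \theta\). jaxdem's exact definitions may differ.

```python
import numpy as np


def build_observation(pos, theta, vel, omega, target, pos_j, theta_j, pos_k, theta_k,
                      box_size=10.0, clamp=1.0):
    """Hypothetical assembly of one agent's 22 features, in the listed order.

    Assumed conventions: floor at y = 0, walls at x = 0 and x = box_size,
    component-wise clipping to [-clamp, clamp], relative angle theta_n - theta.
    """
    def unit(d):
        # Small epsilon avoids division by zero when two points coincide.
        return d / (np.linalg.norm(d) + 1e-12)

    def clamped(d):
        return np.clip(d, -clamp, clamp)

    d_t, d_j, d_k = target - pos, pos_j - pos, pos_k - pos
    features = [
        [pos[1]],                                             # distance to floor (1)
        [pos[0], box_size - pos[0]],                          # left/right walls (2)
        unit(d_t), clamped(d_t),                              # target (2 + 2)
        unit(d_j), clamped(d_j),                              # neighbor j (2 + 2)
        [np.sin(theta_j - theta), np.cos(theta_j - theta)],   # angle to j (1 + 1)
        unit(d_k), clamped(d_k),                              # neighbor k (2 + 2)
        [np.sin(theta_k - theta), np.cos(theta_k - theta)],   # angle to k (1 + 1)
        vel,                                                  # velocity (2)
        [omega],                                              # angular velocity (1)
    ]
    return np.concatenate([np.asarray(f, dtype=float) for f in features])
```

Stacking the rows of the three agents gives the (A, observation_space_size) = (3, 22) array that observation() is documented to return.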
- static reward(env: ThreeGears) → Array
Compute the cooperative reward.
The shared reward is based on the differential distance to the objective minus a penalty for the change in kinetic energy:
\[R_t = (d_{t-1} - d_t) - w_{\text{ke}} \sum_i \left(K_t^i - K_{t-1}^i\right)\]
where \(d_t\) is the total distance to the objective at step \(t\), \(K_t^i\) is the kinetic energy of agent \(i\), and \(w_{\text{ke}}\) is the kinetic energy weight (ke_weight).
- Returns:
Reward value, identical for all agents.
- Return type:
jax.Array
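The reward formula can be evaluated directly once the distances and per-agent kinetic energies are known. The sketch below does that; the kinetic-energy form (translational plus rotational, \(\tfrac12 m v^2 + \tfrac12 I \omega^2\)), the function names, and returning the reward as one identical entry per agent are assumptions for illustration, and \(d_t\) is taken as a given scalar because the page does not define how the total distance is computed.

```python
import numpy as np


def kinetic_energy(mass, vel, inertia, omega):
    """Assumed per-agent kinetic energy: 0.5*m*|v|^2 + 0.5*I*omega^2."""
    return 0.5 * mass * np.sum(vel**2, axis=-1) + 0.5 * inertia * omega**2


def shared_reward(d_prev, d_curr, ke_prev, ke_curr, ke_weight=0.001, num_agents=3):
    """R_t = (d_{t-1} - d_t) - w_ke * sum_i (K_t^i - K_{t-1}^i), copied to every agent."""
    r = (d_prev - d_curr) - ke_weight * np.sum(ke_curr - ke_prev)
    return np.full((num_agents,), r)
```

Moving 1 unit closer to the objective while the summed kinetic energy rises by 1 gives \(1 - 0.001 \cdot 1 = 0.999\) with the default ke_weight, so the penalty only dominates when the gears speed up sharply.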
- static done(env: ThreeGears) → Array
- property action_space_size: int
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size), where A is the number of agents.
- property action_space_shape: tuple[int]
Original per-agent action shape (useful for reshaping inside the environment).
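The relation between action_space_shape and action_space_size amounts to a per-agent reshape. The hypothetical helpers below convert between the (A, *action_space_shape) layout and the flat (A, action_space_size) layout that step() receives; how jaxdem performs this reshaping internally is not documented here, so treat this as an assumption.

```python
import numpy as np


def flatten_actions(actions: np.ndarray, action_space_shape: tuple) -> np.ndarray:
    """Hypothetical: (A, *action_space_shape) -> (A, action_space_size)."""
    size = int(np.prod(action_space_shape))
    return actions.reshape(actions.shape[0], size)


def unflatten_actions(flat: np.ndarray, action_space_shape: tuple) -> np.ndarray:
    """Hypothetical: (A, action_space_size) -> (A, *action_space_shape)."""
    return flat.reshape((flat.shape[0], *action_space_shape))
```

For a single torque per gear, action_space_shape would be (1,) and both layouts coincide; the helpers only matter when the per-agent action has more than one axis.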
- property observation_space_size: int
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).