jaxdem.rl.environments.three_gears

Multi-agent 2-D environment with three gears and a triangle objective.

Classes

ThreeGears(state, system, env_params)

Multi-agent 2-D environment with three gears.

class jaxdem.rl.environments.three_gears.ThreeGears(state: State, system: System, env_params: dict[str, Any])

Bases: Environment

Multi-agent 2-D environment with three gears.

The environment consists of three active gears, each composed of spheres and able to apply torque to itself. The shared objective is to move the gears into a triangular formation defined by a randomized target position.

Note

As with the TwoGears environment, a realistic training setup uses skip_frames = 50, which gives a response rate of 200 Hz. With that setting, num_steps_epoch = 100 corresponds to a horizon of 0.5 seconds.
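The arithmetic behind this note can be checked with a few lines of Python. This is a minimal sketch that assumes skip_frames counts physics steps per agent decision. The physics time step it derives is only inferred from the two quoted numbers and is not a documented default.

```python
# Values quoted in the note above.
skip_frames = 50          # ASSUMPTION: physics steps per agent decision
response_rate_hz = 200.0  # agent decisions per second
num_steps_epoch = 100     # agent decisions per training epoch

# Physics time step implied by the two values above
# (an inference from the note, not a documented default).
implied_dt = 1.0 / (response_rate_hz * skip_frames)

# Simulated time covered by one epoch.
horizon_s = num_steps_epoch / response_rate_hz

print(f"implied dt = {implied_dt:.1e} s, horizon = {horizon_s} s")
```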

classmethod Create(box_size: float = 10.0, max_steps: int = 100000, friction: float = 0.2, ke_weight: float = 0.001) → ThreeGears

Create a three-gears 2-D environment.

Parameters:
  • box_size (float) – Size of the square bounding box.

  • max_steps (int) – Episode length in physics steps.

  • friction (float) – Viscous drag coefficient applied as -friction * vel.

  • ke_weight (float) – Weight for the differential kinetic energy penalty.

Returns:

A freshly constructed environment (call reset() before use).

Return type:

ThreeGears

static reset(env: ThreeGears, key: Array) → Environment

Reset the environment to a random initial configuration.

Parameters:
  • env (Environment) – The environment instance to reset.

  • key (jax.Array) – PRNG key used to sample the initial positions and objective triangle.

Returns:

The environment with a fresh episode state.

Return type:

Environment

static step(env: ThreeGears, action: Array) → Environment

Advance the environment by one physics step.

Applies torque to the active gears, computes inter-gear forces, and applies viscous drag.

Parameters:
  • env (Environment) – Current environment.

  • action (jax.Array) – Actions for all gears.

Returns:

Updated environment after physics integration and sensor updates.

Return type:

Environment
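The viscous drag term mentioned above, documented in Create() as -friction * vel, can be illustrated in isolation. This is a hedged sketch and not the library's code: the helper name viscous_drag and the sample velocities are hypothetical, and whether the drag acts on each sphere or on each gear as a whole is not stated in this page.

```python
import jax.numpy as jnp


def viscous_drag(vel: jnp.ndarray, friction: float = 0.2) -> jnp.ndarray:
    """Drag force opposing the velocity: -friction * vel."""
    return -friction * vel


# Hypothetical 2-D velocities for three bodies, shape (3, 2).
vel = jnp.array([[1.0, 0.0], [0.0, -2.0], [0.5, 0.5]])
drag = viscous_drag(vel)  # same shape as vel, pointing against the motion
```

Because the force is linear in the velocity, a body at rest feels no drag and faster bodies are damped proportionally more.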

static observation(env: ThreeGears) → Array

Build the observation vector.

The observation vector contains 22 features:

Feature                               Size
Distance to floor                     1
Distance to left/right walls          2
Unit vector to target                 2
Clamped displacement to target        2
Unit vector to neighbor j             2
Clamped displacement to neighbor j    2
\(\sin(\Delta\theta_j)\)              1
\(\cos(\Delta\theta_j)\)              1
Unit vector to neighbor k             2
Clamped displacement to neighbor k    2
\(\sin(\Delta\theta_k)\)              1
\(\cos(\Delta\theta_k)\)              1
Velocity (x, y)                       2
Angular velocity                      1

Returns:

Observation vector of size 22.

Return type:

jax.Array
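To make the 22-feature layout concrete, the sketch below assembles one agent's observation in the order listed above. It is an illustration under stated assumptions, not the library's implementation: the helper names (relative_features, neighbor_features, build_observation), the max_dist parameter, and the componentwise clipping rule used for the "clamped displacement" are all hypothetical.

```python
import jax.numpy as jnp


def relative_features(delta, max_dist):
    """Unit vector along `delta` and `delta` clipped componentwise.

    ASSUMPTION: the clamp is a componentwise clip to [-max_dist, max_dist].
    """
    dist = jnp.linalg.norm(delta)
    unit = delta / jnp.maximum(dist, 1e-8)  # guard against division by zero
    clamped = jnp.clip(delta, -max_dist, max_dist)
    return unit, clamped


def neighbor_features(delta, dtheta, max_dist):
    """Six features per neighbor: unit (2), clamped (2), sin and cos of dtheta."""
    unit, clamped = relative_features(delta, max_dist)
    angles = jnp.array([jnp.sin(dtheta), jnp.cos(dtheta)])
    return jnp.concatenate([unit, clamped, angles])


def build_observation(floor_dist, wall_dists, to_target, to_j, dtheta_j,
                      to_k, dtheta_k, vel, ang_vel, max_dist=1.0):
    """Concatenate the features in the documented order (22 values)."""
    unit_t, clamped_t = relative_features(to_target, max_dist)
    return jnp.concatenate([
        jnp.atleast_1d(floor_dist),                   # 1
        wall_dists,                                   # 2
        unit_t,                                       # 2
        clamped_t,                                    # 2
        neighbor_features(to_j, dtheta_j, max_dist),  # 6
        neighbor_features(to_k, dtheta_k, max_dist),  # 6
        vel,                                          # 2
        jnp.atleast_1d(ang_vel),                      # 1
    ])


# Hypothetical input values for a single agent.
obs = build_observation(
    floor_dist=1.0,
    wall_dists=jnp.array([2.0, 8.0]),
    to_target=jnp.array([3.0, 4.0]),
    to_j=jnp.array([1.0, 0.0]),
    dtheta_j=0.0,
    to_k=jnp.array([0.0, -2.0]),
    dtheta_k=jnp.pi / 2,
    vel=jnp.array([0.1, 0.0]),
    ang_vel=0.5,
)
```

Encoding each relative angle as a sine and cosine pair avoids the discontinuity a raw angle would have when it wraps around.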

static reward(env: ThreeGears) → Array

Compute the cooperative reward.

The shared reward is based on the differential distance to the objective minus a penalty for the change in kinetic energy:

\[R_t = (d_{t-1} - d_t) - w_{\text{ke}} \sum_i (K_t^i - K_{t-1}^i)\]

where \(d_t\) is the total distance to the objective at step \(t\), \(K_t^i\) is the kinetic energy of agent \(i\), and \(w_{\text{ke}}\) is the kinetic energy weight.

Returns:

Reward value, identical for all agents.

Return type:

jax.Array
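The reward formula above can be written out directly. This is a minimal sketch with hypothetical names (cooperative_reward and its arguments) and made-up sample values. How the library stores the previous distance and kinetic energies, and the exact shape of the returned array, are not specified here.

```python
import jax.numpy as jnp


def cooperative_reward(d_prev, d_curr, ke_prev, ke_curr, ke_weight=0.001):
    """R_t = (d_{t-1} - d_t) - w_ke * sum_i (K_t^i - K_{t-1}^i)."""
    progress = d_prev - d_curr                           # positive when the gears get closer
    ke_penalty = ke_weight * jnp.sum(ke_curr - ke_prev)  # positive when total kinetic energy grows
    return progress - ke_penalty


# Hypothetical values: total distance drops by 0.5, total kinetic energy rises by 1.0.
ke_prev = jnp.array([1.0, 2.0, 0.5])
ke_curr = jnp.array([1.5, 2.0, 1.0])
reward = cooperative_reward(5.0, 4.5, ke_prev, ke_curr)

# ASSUMPTION: the shared scalar is broadcast so every agent receives the same value.
per_agent_reward = jnp.full((3,), reward)
```

With these numbers the progress term is 0.5 and the penalty is 0.001 * 1.0, so the shared reward is 0.499.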

static done(env: ThreeGears) → Array

property action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: tuple[int]

Original per-agent action shape (useful for reshaping inside the environment).
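The relationship between the flattened size and the per-agent shape can be sketched as follows. The concrete shape (1,), standing for one torque command per gear, is an assumption made for illustration. The actual values should be read from the environment's properties.

```python
import math

import jax.numpy as jnp

num_agents = 3             # three gears
action_space_shape = (1,)  # ASSUMPTION: one torque command per gear
action_space_size = math.prod(action_space_shape)  # flattened size per agent

# Actions as passed to step(): shape (A, action_space_size).
flat_actions = jnp.zeros((num_agents, action_space_size))

# Restoring the original per-agent shape inside the environment.
shaped_actions = flat_actions.reshape((num_agents, *action_space_shape))
```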

property observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

property max_num_agents: int

Maximum number of active agents in the environment.