jaxdem.rl.environments.multi_roller

Environment where multiple agents roll toward targets on a floor in 3D.

Functions

frictional_wall_force(pos, state, system)

Calculates normal and frictional forces for spheres on the y = 0 plane.

Classes

MultiRoller(state, system, env_params, ...)

Multi-agent 3D rolling environment.

class jaxdem.rl.environments.multi_roller.MultiRoller(state: State, system: System, env_params: Dict[str, Any], n_lidar_rays: int)[source]

Bases: Environment

Multi-agent 3D rolling environment.

Agents are spheres that roll on a floor and are controlled via 3D torque vectors. The environment includes collision handling, LiDAR sensing, and distance-based reward shaping.

n_lidar_rays: int
classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, final_reward: float = 1.0, shaping_factor: float = 0.005, prev_shaping_factor: float = 0.0, global_shaping_factor: float = 0.0, collision_penalty: float = -0.005, goal_threshold: float = 0.6666666666666666, lidar_range: float = 0.45, n_lidar_rays: int = 16) → MultiRoller[source]
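
A minimal construction sketch, assuming N is the number of agents; the specific values below are arbitrary illustration choices, and every argument not listed keeps its default from the Create() signature above:

>>> from jaxdem.rl.environments.multi_roller import MultiRoller
>>> # 8 rolling agents, each sensing with a 32-ray LiDAR
>>> env = MultiRoller.Create(N=8, n_lidar_rays=32)
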
static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) → Environment[source]

Initialize the environment with randomly placed particles.

Parameters:
  • env (Environment) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:
  Freshly initialized environment.

Return type:
  Environment
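
A short usage sketch for reset(), assuming the environment was built with Create() as above:

>>> import jax
>>> from jaxdem.rl.environments.multi_roller import MultiRoller
>>> env = MultiRoller.Create(N=4)
>>> key = jax.random.PRNGKey(0)
>>> env = MultiRoller.reset(env, key)  # freshly initialized environment with randomly placed particles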

static step(env: Environment, action: Array) → Environment[source]
static observation(env: Environment) → Array[source]
static reward(env: Environment) → Array[source]
static done(env: Environment) → Array[source]
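
The static methods above form the usual transition loop. A hedged sketch of a single transition, assuming N is the number of agents A and that a zero action applies no torque:

>>> import jax
>>> import jax.numpy as jnp
>>> from jaxdem.rl.environments.multi_roller import MultiRoller
>>> env = MultiRoller.Create(N=4, n_lidar_rays=16)
>>> env = MultiRoller.reset(env, jax.random.PRNGKey(0))
>>> actions = jnp.zeros((4, env.action_space_size))  # shape (A, action_space_size), zero torques
>>> env = MultiRoller.step(env, actions)
>>> obs = MultiRoller.observation(env)   # shape (A, observation_space_size)
>>> rew = MultiRoller.reward(env)        # presumably one reward per agent
>>> finished = MultiRoller.done(env)
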
property action_space_size: int[source]

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size), where A is the number of agents.

property action_space_shape: Tuple[int][source]

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]

Flattened observation size per agent. observation() returns shape (A, observation_space_size).
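
A sketch of how these properties can be used together, e.g. to turn a flat per-agent action batch back into the environment's native layout (for torque control this is presumably a 3-vector per agent):

>>> import jax.numpy as jnp
>>> from jaxdem.rl.environments.multi_roller import MultiRoller
>>> env = MultiRoller.Create(N=4)
>>> flat = jnp.zeros((4, env.action_space_size))           # (A, flattened action)
>>> native = flat.reshape((4, *env.action_space_shape))    # e.g. (A, 3) torque vectors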