jaxdem.rl.environments#
Reinforcement-learning environment interface.
Classes

Environment: Defines the interface for reinforcement-learning environments.
- class jaxdem.rl.environments.Environment(state: State, system: System, env_params: Dict[str, Any])[source]#
Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use
action_space_shape to reshape inside the environment if needed.
Required shapes
Observation: (A, observation_space_size)
Action (input to step()): (A, action_space_size)
Reward: (A,)
Done: scalar boolean for the whole environment
TODO:
Truncated data field: per-agent termination flag
Render method
Example
To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomEnv(Environment):
...     ...
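Below is a minimal end-to-end sketch of what such a subclass might look like. It assumes the episode data is kept in env_params (keys like "pos" and "t", the 0.1 velocity scale, and the 100-step horizon are illustrative, not part of the library); construction of the initial State and System is not shown.

    from dataclasses import dataclass, replace
    from typing import Tuple

    import jax
    import jax.numpy as jnp

    from jaxdem.rl.environments import Environment


    @Environment.register("MyCustomEnv")
    @jax.tree_util.register_dataclass
    @dataclass(slots=True)
    class MyCustomEnv(Environment):
        """Toy single-agent environment (A = 1): drive a point towards the origin."""

        @staticmethod
        def reset(env: "Environment", key: jax.Array) -> "Environment":
            # Place the agent uniformly in [-1, 1]^2 and zero the step counter.
            pos = jax.random.uniform(key, (1, 2), minval=-1.0, maxval=1.0)
            return replace(env, env_params={**env.env_params, "pos": pos, "t": jnp.array(0)})

        @staticmethod
        def step(env: "Environment", action: jax.Array) -> "Environment":
            # action has shape (A, action_space_size) = (1, 2); treat it as a velocity.
            pos = env.env_params["pos"] + 0.1 * action
            t = env.env_params["t"] + 1
            return replace(env, env_params={**env.env_params, "pos": pos, "t": t})

        @staticmethod
        def observation(env: "Environment") -> jax.Array:
            # Shape (A, observation_space_size) = (1, 2).
            return env.env_params["pos"]

        @staticmethod
        def reward(env: "Environment") -> jax.Array:
            # Shape (A,): negative distance to the origin.
            return -jnp.linalg.norm(env.env_params["pos"], axis=-1)

        @staticmethod
        def done(env: "Environment") -> jax.Array:
            # Scalar boolean for the whole environment.
            return env.env_params["t"] >= 100

        @property
        def action_space_size(self) -> int:
            return 2

        @property
        def action_space_shape(self) -> Tuple[int]:
            return (2,)

        @property
        def observation_space_size(self) -> int:
            return 2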
- env_params: Dict[str, Any]#
Environment-specific parameters.
- abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source][source]#
Initialize the environment to a valid start state.
- Parameters:
env (Environment) – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source][source]#
Conditionally reset the environment if it has reached a terminal state.
This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).
- Return type:
Environment
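This conditional reset is what keeps rollouts traceable inside jax.jit and jax.lax.scan without Python-level branching. A minimal sketch of the pattern with jax.lax.cond (an illustration of how such a helper is typically written, not necessarily the library's exact implementation):

    import jax

    def reset_if_done(env, done, key):
        # Both branches must return pytrees with identical structure and shapes
        # so that lax.cond can be traced under jit/scan.
        return jax.lax.cond(
            done,
            lambda e: type(e).reset(e, key),  # terminal: start a fresh episode
            lambda e: e,                      # otherwise: pass the env through unchanged
            env,
        )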
- abstractmethod static step(env: Environment, action: Array) Environment[source][source]#
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- abstractmethod static observation(env: Environment) Array[source][source]#
Returns the per-agent observation vector.
- Parameters:
env (Environment) – The current environment.
- Returns:
Per-agent observation array of shape (A, observation_space_size).
- Return type:
jax.Array
- abstractmethod static reward(env: Environment) Array[source][source]#
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Per-agent reward array of shape (A,), based on the current environment state.
- Return type:
jax.Array
- abstractmethod static done(env: Environment) Array[source][source]#
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
A scalar boolean indicating whether the environment has ended.
- Return type:
jax.Array
- static info(env: Environment) Dict[str, Any][source][source]#
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override this to provide environment-specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
Dict
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to
step() have shape (A, action_space_size).
- property action_space_shape: Tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).
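For example, an environment whose native per-agent action is a small tensor can recover it inside step() from the flat vector the policy produces. A hypothetical helper (not part of the library):

    import jax.numpy as jnp

    def unflatten_actions(env, flat_action: jnp.ndarray) -> jnp.ndarray:
        """Reshape (A, action_space_size) actions to (A, *action_space_shape)."""
        n_agents = flat_action.shape[0]
        return flat_action.reshape((n_agents, *env.action_space_shape))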
Bases: Environment

Multi-agent navigation environment with collision penalties.
Agents seek fixed objectives in a 2D reflective box. Each step applies a force-like action, advances simple dynamics, updates LiDAR, and returns shaped rewards with an optional final bonus when the goal is reached.
- n_lidar_rays: int#
Number of lidar rays for the vision system.
Initialize the environment with randomly placed particles and velocities.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Advance one step. Actions are forces; simple drag is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
Build per-agent observations.
Contents per agent#

Wrapped displacement to objective Δx (shape (2,)).
Velocity v (shape (2,)).
LiDAR proximities (shape (n_lidar_rays,)).

- Returns:
Array of shape (N, 2 * dim + n_lidar_rays), scaled by the maximum box size for normalization.
- Return type:
jax.Array
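A sketch of how such an observation is typically assembled (function and argument names are illustrative, and periodic wrapping of the displacement is omitted):

    import jax.numpy as jnp

    def build_observation(pos, vel, objective, lidar, box_size):
        """pos, vel, objective: (N, 2); lidar: (N, n_lidar_rays); box_size: (dim,)."""
        delta = objective - pos                               # displacement to the objective
        obs = jnp.concatenate([delta, vel, lidar], axis=-1)   # (N, 2 * dim + n_lidar_rays)
        return obs / jnp.max(box_size)                        # normalize by the maximum box size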
Per-agent reward with distance shaping, goal bonus, LiDAR collision penalty, and a global shaping term.
Equations
Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator, and let \(\alpha_{\text{prev}}, \alpha\) be the shaping factors, \(R_f\) the final reward, \(C\) the collision penalty, \(\beta\) the global shaping factor, and \(r_i\) the agent radius. Let \(\ell_{i,k}\) be the LiDAR proximity for agent \(i\) and ray \(k\), and \(h_i = \sum_k \mathbf{1}[\ell_{i,k} > (\text{LIDAR_range} - 2r_i)]\) the collision count. The reward consists of:

\[\mathrm{rew}^{\text{shape}}_i = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\, d_i\]

\[\mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i + R_f\,\mathbf{1}[\,d_i < \text{goal_threshold}\times r_i\,] + C\, h_i - \beta\, \overline{d},\]

\[\overline{d} = \tfrac{1}{N}\sum_j d_j\]

- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,). The normalized per-agent reward vector.
- Return type:
jax.Array
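The formula above maps directly to array code; a sketch under the same definitions (argument names are illustrative, not the environment's env_params keys):

    import jax.numpy as jnp

    def multiagent_reward(delta, prev_d, lidar, radius, alpha, alpha_prev,
                          final_reward, collision_penalty, global_shaping,
                          goal_threshold, lidar_range):
        """delta: (N, 2); prev_d, radius: (N,); lidar: (N, n_lidar_rays)."""
        d = jnp.linalg.norm(delta, axis=-1)                        # d_i
        shaped = alpha_prev * prev_d - alpha * d                   # rew_shape_i
        goal = final_reward * (d < goal_threshold * radius)        # R_f * 1[d_i < threshold * r_i]
        hits = jnp.sum(lidar > lidar_range - 2.0 * radius[:, None], axis=-1)  # h_i
        return shaped + goal + collision_penalty * hits - global_shaping * jnp.mean(d)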
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Flattened action size per agent. Actions passed to
step() have shape (A, action_space_size).
Original per-agent action shape (useful for reshaping inside the environment).
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).
Bases: Environment

Single-agent navigation environment toward a fixed target.
Custom factory method for this environment.
Initialize the environment with a randomly placed particle and velocity.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Advance one step. Actions are forces; simple drag is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
Build per-agent observations.
Contents per agent#
Wrapped displacement to objective Δx (shape (2,)).
Velocity v (shape (2,)).

- Returns:
Array of shape (N, 2 * dim), scaled by the maximum box size for normalization.
- Return type:
jax.Array
Returns a vector of per-agent rewards.
Equation
Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factors \(\alpha_{\text{prev}},\alpha\), final reward \(R_f\), and radius \(r_i\):
\[\mathrm{rew}^{\text{shape}}_i = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\, d_i\]

Final reward:

\[\mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i + R_f\,\mathbf{1}[\,d_i < \text{goal_threshold}\times r_i\,]\]

- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,). The normalized per-agent reward vector.
- Return type:
jax.Array
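The same shaping pattern in array form for the single-agent case (a sketch; names are illustrative):

    import jax.numpy as jnp

    def navigator_reward(delta, prev_d, radius, alpha, alpha_prev,
                         final_reward, goal_threshold):
        d = jnp.linalg.norm(delta, axis=-1)           # distance to the objective
        shaped = alpha_prev * prev_d - alpha * d      # potential-style shaping term
        bonus = jnp.where(d < goal_threshold * radius, final_reward, 0.0)
        return shaped + bonus                         # shape (N,)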
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Flattened action size per agent. Actions passed to
step() have shape (A, action_space_size).
Original per-agent action shape (useful for reshaping inside the environment).
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.SingleRoller3D(state: State, system: System, env_params: Dict[str, Any])[source]#
Bases: Environment

Single-agent navigation where the agent rolls on a plane using torque control.
- classmethod Create(min_box_size: float = 1.0, max_box_size: float = 1.0, max_steps: int = 6000, final_reward: float = 2.0, shaping_factor: float = 1.0, prev_shaping_factor: float = 1.0, goal_threshold: float = 0.6666666666666666) SingleRoller3D[source][source]#
- static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source][source]#
- static step(env: Environment, action: Array) Environment[source][source]#
- static observation(env: Environment) Array[source][source]#
- static reward(env: Environment) Array[source][source]#
- static done(env: Environment) Array[source][source]#
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to
step() have shape (A, action_space_size).
- property action_space_shape: Tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).
- class jaxdem.rl.environments.MultiRoller(state: State, system: System, env_params: Dict[str, Any], n_lidar_rays: int)[source]#
Bases: Environment

Multi-agent 3D rolling environment.
Agents are spheres that roll on a floor. They are controlled via 3D torque vectors. Includes collision handling, LiDAR sensing, and distance-based reward shaping.
- n_lidar_rays: int#
- classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, final_reward: float = 1.0, shaping_factor: float = 0.005, prev_shaping_factor: float = 0.0, global_shaping_factor: float = 0.0, collision_penalty: float = -0.005, goal_threshold: float = 0.6666666666666666, lidar_range: float = 0.45, n_lidar_rays: int = 16) MultiRoller[source][source]#
- static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Environment[source][source]#
Initialize the environment with randomly placed particles.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static step(env: Environment, action: Array) Environment[source][source]#
- static observation(env: Environment) Array[source][source]#
- static reward(env: Environment) Array[source][source]#
- static done(env: Environment) Array[source][source]#
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to
step() have shape (A, action_space_size).
- property action_space_shape: Tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent.
observation() returns shape (A, observation_space_size).
Modules
Multi-agent navigation task with collision penalties.
Environment where multiple agents roll towards targets on a 3D floor.
Environment where a single agent navigates towards a target.
Environment where a single agent rolls towards a target on the floor.