jaxdem.rl.environments#

Reinforcement-learning environment interface.

Classes

Environment(state, system, env_params)

Defines the interface for reinforcement-learning environments.

class jaxdem.rl.environments.Environment(state: State, system: System, env_params: Dict[str, Any])[source]#

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment
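
As a quick illustration of these shape conventions (the agent count and feature sizes below are arbitrary placeholders, not values taken from any concrete environment):

>>> import jax.numpy as jnp
>>> A = 4                       # hypothetical number of agents
>>> obs = jnp.zeros((A, 8))     # (A, observation_space_size)
>>> act = jnp.zeros((A, 2))     # (A, action_space_size), passed to step()
>>> rew = jnp.zeros((A,))       # (A,), one reward per agent
>>> done = jnp.asarray(False)   # scalar boolean for the whole environment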

TODO:

  • Truncated data field: per-agent termination flag

  • Render method

Example

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomEnv(Environment):
...     ...
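
The abstract methods such a subclass must provide are sketched below. This is a minimal placeholder sketch: the class name is hypothetical and the bodies (returning the environment unchanged and zero-valued arrays) stand in for a real environment's dynamics.

>>> import jax.numpy as jnp
>>> @Environment.register("SketchEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class SketchEnv(Environment):
...     @staticmethod
...     def reset(env, key):
...         # Return a freshly initialized copy of ``env``.
...         return env
...     @staticmethod
...     def step(env, action):
...         # ``action`` has shape (A, action_space_size); advance the simulation one step.
...         return env
...     @staticmethod
...     def observation(env):
...         # Shape (A, observation_space_size).
...         return jnp.zeros((1, 1))
...     @staticmethod
...     def reward(env):
...         # Shape (A,).
...         return jnp.zeros((1,))
...     @staticmethod
...     def done(env):
...         # Scalar boolean for the whole environment.
...         return jnp.asarray(False)
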
state: State#

Simulation state.

system: System#

Simulation system configuration.

env_params: Dict[str, Any]#

Environment-specific parameters.

abstractmethod static reset(env: Environment, key: Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) Environment[source]#

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) Environment[source]#

Conditionally resets the environment if the environment has reached a terminal state.

This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
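
The conditional reset described above can be pictured as a jax.lax.cond over the done flag. The sketch below is illustrative only and is not necessarily the library's actual implementation:

>>> import jax
>>> def reset_if_done_sketch(env, done, key):
...     return jax.lax.cond(
...         done,
...         lambda e: type(e).reset(e, key),  # terminal state: reinitialize
...         lambda e: e,                      # otherwise: return unchanged
...         env,
...     )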

abstractmethod static step(env: Environment, action: Array) Environment[source]#

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

abstractmethod static observation(env: Environment) Array[source]#

Returns the per-agent observation vector.

Parameters:

env (Environment) – The current environment.

Returns:

Vector corresponding to the environment observation.

Return type:

jax.Array

abstractmethod static reward(env: Environment) Array[source]#

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Vector of each agent's reward based on the current environment state.

Return type:

jax.Array

abstractmethod static done(env: Environment) Array[source]#

Returns a boolean indicating whether the environment has ended.

Parameters:

env (Environment) – The current environment.

Returns:

A boolean indicating whether the environment has ended.

Return type:

jax.Array

static info(env: Environment) Dict[str, Any][source]#

Return auxiliary diagnostic information.

By default, returns an empty dict. Subclasses may override this to provide environment-specific information.

Parameters:

env (Environment) – The current state of the environment.

Returns:

A dictionary with additional information about the environment.

Return type:

Dict

property num_envs: int[source]#

Number of batched environments.

property max_num_agents: int[source]#

Maximum number of active agents in the environment.

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

property action_space_shape: Tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).
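
A rough end-to-end usage sketch of this interface, driving the SingleNavigator environment documented below with random actions (the loop and action sampling are illustrative only, not a training recipe):

>>> import jax
>>> from jaxdem.rl.environments import SingleNavigator
>>> key = jax.random.PRNGKey(0)
>>> env = SingleNavigator.Create(dim=2)
>>> env = SingleNavigator.reset(env, key)
>>> for _ in range(10):
...     key, akey = jax.random.split(key)
...     action = jax.random.normal(akey, (env.max_num_agents, env.action_space_size))
...     env = SingleNavigator.step(env, action)
...     obs = SingleNavigator.observation(env)   # (A, observation_space_size)
...     rew = SingleNavigator.reward(env)        # (A,)
...     env = SingleNavigator.reset_if_done(env, SingleNavigator.done(env), key)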

class jaxdem.rl.environments.MultiNavigator(state: State, system: System, env_params: Dict[str, Any], n_lidar_rays: int)[source]#

Bases: Environment

Multi-agent navigation environment with collision penalties.

Agents seek fixed objectives in a 2D reflective box. Each step applies a force-like action, advances simple dynamics, updates LiDAR, and returns shaped rewards with an optional final bonus when an agent reaches its goal.

n_lidar_rays: int#

Number of LiDAR rays for the vision system.

classmethod Create(N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, final_reward: float = 1.0, shaping_factor: float = 0.005, prev_shaping_factor: float = 0.0, global_shaping_factor: float = 0.0, collision_penalty: float = -0.005, goal_threshold: float = 0.6666666666666666, lidar_range: float = 0.45, n_lidar_rays: int = 16) MultiNavigator[source]#

Custom factory method for this environment.
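
For example, overriding a couple of the defaults (the particular values here are arbitrary):

>>> from jaxdem.rl.environments import MultiNavigator
>>> env = MultiNavigator.Create(N=8, n_lidar_rays=8)
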
property action_space_shape: Tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

static done(env: Environment) Array[source]#

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

static observation(env: Environment) Array[source]#

Build per-agent observations.

Contents per agent#

  • Wrapped displacement to objective Δx (shape (2,)).

  • Velocity v (shape (2,)).

  • LiDAR proximities (shape (n_lidar_rays,)).

Returns:

Array of shape (N, 2 * dim + n_lidar_rays) scaled by the maximum box size for normalization.

Return type:

jax.Array
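
Assuming the per-agent layout listed above ([Δx | v | LiDAR] with dim = 2), the observation can be split back into its components. The helper below is illustrative only, not part of the library:

>>> import jax.numpy as jnp
>>> def split_observation(obs, dim=2):
...     # obs: (N, 2 * dim + n_lidar_rays); returns the (Δx, v, lidar) slices.
...     dx, vel, lidar = jnp.split(obs, [dim, 2 * dim], axis=-1)
...     return dx, vel, lidar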

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

static reset(env: Environment, key: Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) Environment[source]#

Initialize the environment with randomly placed particles and velocities.

Parameters:
  • env (Environment) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reward(env: Environment) Array[source]#

Per-agent reward with distance shaping, goal bonus, LiDAR collision penalty, and a global shaping term.

Equations

Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator, with shaping factors \(\alpha_{\text{prev}},\alpha\), final reward \(R_f\), collision penalty \(C\), global shaping factor \(\beta\), and radius \(r_i\). Let \(\ell_{i,k}\) be the LiDAR proximity for agent \(i\) and ray \(k\), and let \(h_i = \sum_k \mathbf{1}[\ell_{i,k} > (\text{LIDAR_range} - 2r_i)]\) be the collision count. The reward is:

\[\mathrm{rew}^{\text{shape}}_i = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\, d_i\]
\[\mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i + R_f\,\mathbf{1}[\,d_i < \text{goal_threshold}\times r_i\,] + C\, h_i - \beta\, \overline{d},\]
where
\[\overline{d} = \tfrac{1}{N}\sum_j d_j.\]
Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,). The normalized per-agent reward vector.

Return type:

jax.Array
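
A direct transcription of these equations into jax.numpy; the array and argument names below are stand-ins for the documented quantities, not the environment's internal variable names:

>>> import jax.numpy as jnp
>>> def shaped_reward(d, d_prev, h, r, alpha, alpha_prev, R_f, C, beta, goal_threshold):
...     # d, d_prev: (N,) distances to the objective; h: (N,) LiDAR collision counts; r: (N,) radii.
...     rew_shape = alpha_prev * d_prev - alpha * d          # distance shaping
...     goal_bonus = R_f * (d < goal_threshold * r)          # indicator-based goal bonus
...     return rew_shape + goal_bonus + C * h - beta * jnp.mean(d)  # C is negative by default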

static step(env: Environment, action: Array) Environment[source]#

Advance one step. Actions are forces; simple drag is applied.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

class jaxdem.rl.environments.SingleNavigator(state: State, system: System, env_params: Dict[str, Any])[source]#

Bases: Environment

Single-agent navigation environment toward a fixed target.

classmethod Create(dim: int = 2, min_box_size: float = 1.0, max_box_size: float = 1.0, max_steps: int = 2000, final_reward: float = 2.0, shaping_factor: float = 0.0, prev_shaping_factor: float = 0.0, goal_threshold: float = 0.6666666666666666) SingleNavigator[source]#

Custom factory method for this environment.

property action_space_shape: Tuple[int][source]#

Original per-agent action shape (useful for reshaping inside the environment).

property action_space_size: int[source]#

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

static done(env: Environment) Array[source]#

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

static observation(env: Environment) Array[source]#

Build per-agent observations.

Contents per agent#

  • Wrapped displacement to objective Δx (shape (2,)).

  • Velocity v (shape (2,)).

Returns:

Array of shape (N, 2 * dim) scaled by the maximum box size for normalization.

Return type:

jax.Array

property observation_space_size: int[source]#

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

static reset(env: Environment, key: Array | ndarray | bool_ | number | bool | int | float | complex | TypedNdArray) Environment[source]#

Initialize the environment with a randomly placed particle and velocity.

Parameters:
  • env (Environment) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reward(env: Environment) Array[source]#

Returns a vector of per-agent rewards.

Equation

Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factors \(\alpha_{\text{prev}},\alpha\), final reward \(R_f\), and radius \(r_i\):

\[\mathrm{rew}^{\text{shape}}_i = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\, d_i\]

Final reward:

\[\mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i + R_f\,\mathbf{1}[\,d_i < \text{goal_threshold}\times r_i\,]\]
Parameters:

env (Environment) – Current environment.

Returns:

Shape (N,). The normalized per-agent reward vector.

Return type:

jax.Array

static step(env: Environment, action: Array) Environment[source]#

Advance one step. Actions are forces; simple drag is applied.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment

Modules

multi_navigator

Multi-agent navigation task with collision penalties.

single_navigator

Environment where a single agent navigates towards a target.