jaxdem.rl.environments

Reinforcement-learning environment interface.

Classes

Environment(state, system, env_params[, ...])

Defines the interface for reinforcement-learning environments.

class jaxdem.rl.environments.Environment(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)

Bases: Factory, ABC

Defines the interface for reinforcement-learning environments.

  • Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.

  • Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.
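A minimal sketch of the flattening convention. It assumes a per-agent action shape of (2, 3) and A = 4 agents; both values are illustrative and not taken from a shipped environment. The policy emits one flat row per agent, and the environment can recover the structured shape inside step() with a plain reshape.

```python
import math

import jax.numpy as jnp

# Illustrative values (assumptions, not taken from a shipped environment).
A = 4                                              # number of agents
action_space_shape = (2, 3)                        # original per-agent action shape
action_space_size = math.prod(action_space_shape)  # flattened size: 6

# What a policy passes to step(): one flat row per agent.
flat_action = jnp.arange(A * action_space_size, dtype=jnp.float32).reshape(
    A, action_space_size
)

# Inside step(): recover the structured per-agent shape.
structured = flat_action.reshape((A, *action_space_shape))

# Flattening again restores the original layout (row-major order is kept).
roundtrip = structured.reshape(A, action_space_size)
```

Because both reshapes use row-major order, flattening and unflattening are exact inverses, so no information is lost by the fixed-size convention.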

Required shapes

  • Observation: (A, observation_space_size)

  • Action (input to step()): (A, action_space_size)

  • Reward: (A,)

  • Done: scalar boolean for the whole environment
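The shape contract can be checked mechanically. The helper below, check_env_outputs, is hypothetical (it is not part of jaxdem) and simply asserts the four required shapes for arrays an environment might produce.

```python
import jax.numpy as jnp


def check_env_outputs(obs, action, reward, done, *, observation_space_size, action_space_size):
    """Hypothetical helper: assert the shape contract listed above."""
    A = obs.shape[0]  # number of agents, read from the observation
    assert A >= 1, "single-agent environments still use A = 1"
    assert obs.shape == (A, observation_space_size)
    assert action.shape == (A, action_space_size)
    assert reward.shape == (A,)
    assert done.shape == () and done.dtype == bool  # one flag for the whole environment
    return A


# Single-agent case: the leading agent axis is kept even though A = 1.
A = check_env_outputs(
    obs=jnp.zeros((1, 4)),
    action=jnp.zeros((1, 2)),
    reward=jnp.zeros((1,)),
    done=jnp.asarray(False),
    observation_space_size=4,
    action_space_size=2,
)
```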

TODO:

  • Truncated data field: per-agent termination flag

  • Render method

Example

To define a custom environment, inherit from Environment and implement the abstract methods:

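>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.rl.environments import Environment
>>> @Environment.register("MyCustomEnv")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True, frozen=True)
... class MyCustomEnv(Environment):
...     ...  # implement reset, step, observation, reward, and done here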
state: State

Simulation state.

system: System

Simulation system configuration.

env_params: Dict[str, Any]

Environment-specific parameters.

max_num_agents: int

Maximum number of active agents in the environment.

action_space_size: int

Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).

action_space_shape: Tuple[int, ...]

Original per-agent action shape (useful for reshaping inside the environment).

observation_space_size: int

Flattened observation size per agent. observation() returns shape (A, observation_space_size).

abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment

Initialize the environment to a valid start state.

Parameters:
  • env (Environment) – Instance of the environment.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment

Resets the environment if it has reached a terminal state.

This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.

Parameters:
  • env (Environment) – The current environment instance.

  • done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.

  • key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.

Returns:

Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).

Return type:

Environment
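Because done is a traced array under jax.jit, a Python if cannot branch on it; a data-dependent branch such as jax.lax.cond is needed. The sketch below shows one way to express reset_if_done under stated assumptions: a NamedTuple (ToyEnv) stands in for the environment pytree and toy_reset is a hypothetical reset function. It is not necessarily how jaxdem implements the method.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnv(NamedTuple):
    """Hypothetical stand-in for an Environment pytree."""
    pos: jax.Array         # (A, dim) particle positions
    step_count: jax.Array  # scalar step counter


def toy_reset(env: ToyEnv, key: jax.Array) -> ToyEnv:
    # Draw fresh positions in [0, 1) and zero the step counter.
    pos = jax.random.uniform(key, env.pos.shape)
    return ToyEnv(pos=pos, step_count=jnp.zeros_like(env.step_count))


@jax.jit
def reset_if_done(env: ToyEnv, done: jax.Array, key: jax.Array) -> ToyEnv:
    # lax.cond picks a branch from the traced boolean. Both branches must
    # return pytrees with identical structure, shapes, and dtypes.
    return jax.lax.cond(
        done,
        lambda e: toy_reset(e, key),  # terminal: reinitialize
        lambda e: e,                  # not terminal: return unchanged
        env,
    )


env = ToyEnv(
    pos=jnp.full((3, 2), 5.0, dtype=jnp.float32),
    step_count=jnp.asarray(7, dtype=jnp.int32),
)
key = jax.random.PRNGKey(0)

kept = reset_if_done(env, jnp.asarray(False), key)
fresh = reset_if_done(env, jnp.asarray(True), key)
```

When this function is mapped over a batch of environments with jax.vmap, a batched predicate turns lax.cond into a select, so both branches are evaluated for every environment.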

abstractmethod static step(env: Environment, action: Array) → Environment

Advance the simulation by one step using per-agent actions.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Actions for all agents, with shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

abstractmethod static observation(env: Environment) → Array

Returns the per-agent observation vector.

Parameters:

env (Environment) – The current environment.

Returns:

Array of shape (A, observation_space_size) with one observation row per agent.

Return type:

jax.Array

abstractmethod static reward(env: Environment) → Array

Returns the per-agent immediate rewards.

Parameters:

env (Environment) – The current environment.

Returns:

Array of shape (A,) with each agent’s reward for the current environment state.

Return type:

jax.Array

abstractmethod static done(env: Environment) → Array

Returns a boolean indicating whether the environment has ended.

Parameters:

env (Environment) – The current environment.

Returns:

Scalar boolean array that is True once the episode has ended.

Return type:

jax.Array

static info(env: Environment) → Dict[str, Any]

Return auxiliary diagnostic information.

By default, returns an empty dict. Subclasses may override it to provide environment-specific information.

Parameters:

env (Environment) – The current state of the environment.

Returns:

A dictionary with additional information about the environment.

Return type:

Dict

class jaxdem.rl.environments.MultiNavigator(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)

Bases: Environment

Multi-agent navigation environment with collision penalties.

classmethod Create(N: int = 2, min_box_size: float = 1.0, max_box_size: float = 2.0, max_steps: int = 5000, final_reward: float = 0.05, shaping_factor: float = 1.0, collision_penalty: float = -2.0, lidar_range: float = 0.35, n_lidar_rays: int = 12) → MultiNavigator

static done(env: Environment) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

static observation(env: Environment) → Array

Returns the observation vector for each agent.

LiDAR bins store proximity values as max(0, R - d_min), where R is the LiDAR range and d_min is the distance to the closest object detected in that bin; a value of 0 means that nothing was detected within range. The observation concatenates the displacement to the objective, the particle velocity, and the LiDAR readings normalized by R.
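A minimal JAX sketch of the LiDAR binning and the observation layout described above. It assumes 2-D free space without periodic boundaries, point-like neighbours, equal-width angular bins starting at -pi, and the sign convention x_i - objective for the displacement. The names lidar_proximity and navigator_observation are hypothetical, and jaxdem's own implementation may differ (for instance in how distances or bins are measured).

```python
import jax.numpy as jnp


def lidar_proximity(pos, lidar_range, n_rays):
    """Per-agent LiDAR proximities max(0, R - d_min) in n_rays angular bins.

    Sketch assumptions: 2-D free space (no periodic wrap), point-like
    neighbours, and equal-width bins starting at angle -pi.
    """
    n = pos.shape[0]
    rel = pos[None, :, :] - pos[:, None, :]         # rel[i, j] = x_j - x_i
    dist = jnp.linalg.norm(rel, axis=-1)             # (N, N)
    angle = jnp.arctan2(rel[..., 1], rel[..., 0])    # in [-pi, pi]
    bins = jnp.floor((angle + jnp.pi) / (2 * jnp.pi) * n_rays).astype(jnp.int32)
    bins = jnp.clip(bins, 0, n_rays - 1)             # angle == pi joins the last bin
    # Drop self-pairs and neighbours at or beyond the LiDAR range.
    valid = (dist < lidar_range) & ~jnp.eye(n, dtype=bool)
    dist = jnp.where(valid, dist, jnp.inf)
    # d_min for every (agent, bin): smallest distance among neighbours in that bin.
    in_bin = bins[..., None] == jnp.arange(n_rays)   # (N, N, n_rays)
    d_min = jnp.min(jnp.where(in_bin, dist[..., None], jnp.inf), axis=1)
    return jnp.maximum(0.0, lidar_range - d_min)     # empty bins give 0


def navigator_observation(pos, vel, objective, lidar_range, n_rays):
    """Concatenate displacement, velocity, and LiDAR readings normalized by R."""
    displacement = pos - objective                   # assumed sign: x_i - objective
    prox = lidar_proximity(pos, lidar_range, n_rays)
    return jnp.concatenate([displacement, vel, prox / lidar_range], axis=-1)


pos = jnp.array([[0.0, 0.0], [0.2, 0.0], [1.0, 1.0]])
obs = navigator_observation(pos, jnp.zeros_like(pos), jnp.zeros_like(pos), 0.35, 4)
```

In this layout each row has 2 * dim + n_rays entries, and the normalized LiDAR part lies in [0, 1].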

classmethod registry_name() → str

static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment

Initialize the environment with randomly placed particles and velocities.

Parameters:
  • env (Environment) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reward(env: Environment) → Array

Returns a vector of per-agent rewards.

Equation

Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factor \(\alpha\), final reward \(R_f\), radius \(r_i\), previous reward \(\mathrm{rew}^{\text{prev}}_i\), collision-penalty coefficient \(C_\mathrm{col}\le 0\), LiDAR range \(R\), measured proximities \(\mathrm{prox}_{i,j}\), and safety factor \(\kappa=2.05\):

\[\mathrm{rew}^{\text{shape}}_i \;=\; \mathrm{rew}^{\text{prev}}_i \;-\; \alpha\, d_i\]

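Define per-beam “too close” hits using the proximity threshold \(\tau_i = \max(0,\, R - \kappa\, r_i)\); since a proximity is \(R\) minus the closest detected distance, beam \(j\) counts as a hit exactly when it detects an object closer than \(\min(R,\, \kappa\, r_i)\):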

\[\mathrm{hit}_{i,j} \;=\; \mathbf{1}\!\left[\,\mathrm{prox}_{i,j} > \tau_i\,\right],\qquad n^{\text{hits}}_i \;=\; \sum_j \mathrm{hit}_{i,j}\]

Total reward:

\[\mathrm{rew}_i \;=\; \mathrm{rew}^{\text{shape}}_i \;+\; R_f\,\mathbf{1}[\,d_i < r_i\,] \;+\; C_\mathrm{col}\, n^{\text{hits}}_i\]

The function updates \(\mathrm{rew}^{\text{prev}}_i \leftarrow \mathrm{rew}^{\text{shape}}_i\) and returns \((\mathrm{rew}_i)_{i=1}^N\) reshaped to (env.max_num_agents,).
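The equations above can be transcribed term by term into JAX, which is handy for checking the formula on small inputs. In this sketch the function name, its argument layout, and the choice to return the updated \(\mathrm{rew}^{\text{prev}}\) instead of storing it on the environment are assumptions; alpha, final_reward, collision_penalty, and lidar_range play the roles of \(\alpha\), \(R_f\), \(C_\mathrm{col}\), and \(R\).

```python
import jax.numpy as jnp

KAPPA = 2.05  # safety factor kappa from the equation above


def multi_navigator_reward(delta, rad, prox, rew_prev, *, alpha, final_reward,
                           collision_penalty, lidar_range):
    """Transcription of the documented reward; returns (rew, new_rew_prev).

    delta:    (N, dim) displacement to the objective
    rad:      (N,) particle radii r_i
    prox:     (N, n_rays) un-normalized LiDAR proximities prox_{i,j}
    rew_prev: (N,) previous shaping reward
    """
    d = jnp.linalg.norm(delta, axis=-1)                # d_i
    rew_shape = rew_prev - alpha * d                   # rew^shape_i
    tau = jnp.maximum(0.0, lidar_range - KAPPA * rad)  # tau_i
    n_hits = jnp.sum(prox > tau[:, None], axis=-1)     # n^hits_i
    on_goal = (d < rad).astype(rew_shape.dtype)        # 1[d_i < r_i]
    rew = rew_shape + final_reward * on_goal + collision_penalty * n_hits
    return rew, rew_shape                              # rew^prev <- rew^shape


rew, rew_prev = multi_navigator_reward(
    delta=jnp.array([[0.03, 0.04], [0.3, 0.4]]),  # d = [0.05, 0.5]
    rad=jnp.array([0.1, 0.1]),                    # tau = 0.35 - 2.05 * 0.1 = 0.145
    prox=jnp.array([[0.0, 0.2, 0.1], [0.15, 0.0, 0.3]]),
    rew_prev=jnp.zeros(2),
    alpha=1.0, final_reward=0.05, collision_penalty=-2.0, lidar_range=0.35,
)
# rew is approximately [-2.0, -4.5]: agent 0 reaches the goal but has one hit,
# agent 1 is 0.5 away and has two hits.
```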

static step(env: Environment, action: Array) → Environment

Advance the simulation by one step. Actions are interpreted as accelerations.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent accelerations, with shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment
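To make “actions are interpreted as accelerations” concrete, the sketch below advances positions and velocities with a semi-implicit (symplectic) Euler update. The integrator choice, the time step dt, and the helper name accel_step are assumptions; jaxdem's actual step may use a different integrator and also adds contact forces, which are ignored here.

```python
import jax
import jax.numpy as jnp


@jax.jit
def accel_step(pos, vel, action, dt):
    """Hypothetical step: treat each agent's action row as an acceleration.

    Semi-implicit Euler: update the velocity first, then move with the new velocity.
    """
    accel = action.reshape(vel.shape)  # (A, action_space_size) -> (A, dim)
    vel = vel + dt * accel
    pos = pos + dt * vel
    return pos, vel


pos = jnp.zeros((2, 2))
vel = jnp.array([[1.0, 0.0], [0.0, 0.0]])
action = jnp.array([[0.0, 0.0], [2.0, -2.0]])
new_pos, new_vel = accel_step(pos, vel, action, 0.5)
# new_vel == [[1.0, 0.0], [1.0, -1.0]], new_pos == [[0.5, 0.0], [0.5, -0.5]]
```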

property type_name: str

class jaxdem.rl.environments.SingleNavigator(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)

Bases: Environment

Single-agent navigation environment toward a fixed target.

classmethod Create(dim: int = 2, min_box_size: float = 1.0, max_box_size: float = 2.0, max_steps: int = 2000, final_reward: float = 0.05, shaping_factor: float = 1.0) → SingleNavigator

Custom factory method for this environment.

static done(env: Environment) → Array

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array

static observation(env: Environment) → Array

Returns the observation vector, which concatenates the displacement between the particle and the objective with the particle’s velocity.

Parameters:

env (Environment) – The current environment.

Returns:

Observation vector for the environment.

Return type:

jax.Array
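A minimal sketch of this observation layout for the single-agent case (A = 1). It assumes free space without periodic wrapping and the sign convention pos - objective for the displacement; single_navigator_observation is a hypothetical name. Since both parts have dim entries, each row here has 2 * dim entries.

```python
import jax.numpy as jnp


def single_navigator_observation(pos, vel, objective):
    """Concatenate displacement to the objective with velocity, per agent.

    Assumes free space (no periodic wrap) and the sign convention pos - objective.
    """
    displacement = pos - objective                        # (A, dim)
    return jnp.concatenate([displacement, vel], axis=-1)  # (A, 2 * dim)


# Single agent (A = 1) in 2-D, so each observation row has 4 entries.
obs = single_navigator_observation(
    pos=jnp.array([[0.5, 1.0]]),
    vel=jnp.array([[0.1, -0.2]]),
    objective=jnp.array([[1.5, 1.0]]),
)
# obs == [[-1.0, 0.0, 0.1, -0.2]]
```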

classmethod registry_name() → str

static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment

Initialize the environment with a randomly placed particle and velocity.

Parameters:
  • env (Environment) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment

static reward(env: Environment) → Array

Returns a vector of per-agent rewards.

Equation

Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factor \(\alpha\), final reward \(R_f\), radius \(r_i\), and previous reward \(\mathrm{rew}^{\text{prev}}_i\):

\[\mathrm{rew}^{\text{shape}}_i \;=\; \mathrm{rew}^{\text{prev}}_i \;-\; \alpha\, d_i\]
\[\mathrm{rew}_i \;=\; \mathrm{rew}^{\text{shape}}_i \;+\; R_f\,\mathbf{1}[\,d_i < r_i\,]\]

The function updates \(\mathrm{rew}^{\text{prev}}_i \leftarrow \mathrm{rew}^{\text{shape}}_i\).

Parameters:

env (Environment) – Current environment.
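A short two-step rollout of the documented recursion, written as a plain JAX transcription. The function name and the choice to pass \(d_i\) directly and return the updated previous reward are assumptions made for illustration; the parameter values (\(\alpha = 1\), \(R_f = 0.05\), \(r = 0.1\)) are only examples.

```python
import jax.numpy as jnp


def single_navigator_reward(d, rad, rew_prev, *, alpha, final_reward):
    """Transcription of the documented reward; returns (rew, new_rew_prev)."""
    rew_shape = rew_prev - alpha * d                  # rew^shape_i
    on_goal = (d < rad).astype(rew_shape.dtype)       # 1[d_i < r_i]
    rew = rew_shape + final_reward * on_goal
    return rew, rew_shape                             # rew^prev <- rew^shape


rad = jnp.array([0.1])
rew_prev = jnp.zeros(1)
# Step 1: agent is 0.5 away from the objective.
rew1, rew_prev = single_navigator_reward(jnp.array([0.5]), rad, rew_prev,
                                         alpha=1.0, final_reward=0.05)
# Step 2: agent is 0.05 away, inside its radius, so the final reward applies.
rew2, rew_prev = single_navigator_reward(jnp.array([0.05]), rad, rew_prev,
                                         alpha=1.0, final_reward=0.05)
# rew1 == [-0.5], rew2 is approximately [-0.5], rew_prev is approximately [-0.55]
```

Read literally, the update \(\mathrm{rew}^{\text{prev}}_i \leftarrow \mathrm{rew}^{\text{shape}}_i\) makes the shaping term a running sum of \(-\alpha\, d_i\) over the episode, as the second step shows.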

static step(env: Environment, action: Array) → Environment

Advance the simulation by one step. Actions are interpreted as accelerations.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – Per-agent accelerations, with shape (A, action_space_size).

Returns:

The updated environment state.

Return type:

Environment

property type_name: str

Modules

multi_navigator

Multi-agent navigation task with collision penalties.

single_navigator

Environment where a single agent navigates towards a target.