jaxdem.rl.environments#
Reinforcement-learning environment interface.
Classes

| Environment | Defines the interface for reinforcement-learning environments. |
- class jaxdem.rl.environments.Environment(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)[source]#
Bases: Factory, ABC
Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.
Required shapes

- Observation: (A, observation_space_size)
- Action (input to step()): (A, action_space_size)
- Reward: (A,)
- Done: scalar boolean for the whole environment
TODO:

- Truncated data field: per-agent termination flag
- Render method
Example

To define a custom environment, inherit from Environment and implement the abstract methods:

>>> @Environment.register("MyCustomEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True, frozen=True)
>>> class MyCustomEnv(Environment):
...     ...
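A slightly fuller skeleton of the same pattern, showing the abstract methods a subclass has to provide; the stub bodies and anything not shown in the class signature above are placeholders, not library code:

from dataclasses import dataclass

import jax


@Environment.register("MyCustomEnv")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class MyCustomEnv(Environment):
    @staticmethod
    def reset(env, key):
        # Build a valid start state from the PRNG key and return a new environment.
        ...

    @staticmethod
    def step(env, action):
        # Apply the (A, action_space_size) actions and return the updated environment.
        ...

    @staticmethod
    def observation(env):
        # Return an array of shape (A, observation_space_size).
        ...

    @staticmethod
    def reward(env):
        # Return an array of shape (A,).
        ...

    @staticmethod
    def done(env):
        # Return a scalar boolean for the whole environment.
        ...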
- env_params: Dict[str, Any]#
Environment-specific parameters.
- max_num_agents: int#
Maximum number of active agents in the environment.
- action_space_size: int#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- action_space_shape: Tuple[int, ...]#
Original per-agent action shape (useful for reshaping inside the environment).
- observation_space_size: int#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
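For illustration, suppose a hypothetical environment whose per-agent action is a 2×3 array; action_space_size would then be 6, and the flat actions received by step() could be reshaped back as follows (the concrete numbers are assumptions):

import jax.numpy as jnp

# Hypothetical example: 4 agents, each with a 2x3 action flattened to size 6.
A = 4
action_space_shape = (2, 3)
action_space_size = 6

flat_action = jnp.zeros((A, action_space_size))                     # what step() receives
structured_action = flat_action.reshape((A, *action_space_shape))   # back to (A, 2, 3)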
- abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment [source]#
Initialize the environment to a valid start state.
- Parameters:
env (Environment) – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment [source]#
Conditionally resets the environment if it has reached a terminal state.
This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).
- Return type:
Environment
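A minimal sketch of how such a conditional reset can be expressed in JAX, assuming the environment is a registered pytree so both branches return the same structure; this illustrates the pattern, not the library's exact implementation:

import jax

def reset_if_done_sketch(env, done, key):
    # Select between a fresh reset and the unchanged environment inside traced
    # code; both branches must return pytrees with identical structure.
    return jax.lax.cond(
        done,
        lambda: type(env).reset(env, key),
        lambda: env,
    )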
- abstractmethod static step(env: Environment, action: Array) → Environment [source]#
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- abstractmethod static observation(env: Environment) → Array [source]#
Returns the per-agent observation vector.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector corresponding to the environment observation.
- Return type:
jax.Array
- abstractmethod static reward(env: Environment) → Array [source]#
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector of per-agent rewards based on the current environment state.
- Return type:
jax.Array
- abstractmethod static done(env: Environment) → Array [source]#
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
A boolean indicating whether the environment has ended.
- Return type:
jax.Array
- static info(env: Environment) → Dict[str, Any] [source]#
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override to provide environment-specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
Dict
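Taken together, these static methods support a fully traced rollout. A minimal sketch, assuming a user-supplied policy_fn mapping observations of shape (A, observation_space_size) to actions of shape (A, action_space_size); rollout, policy_fn, and num_steps are illustrative names, not part of the API:

import jax

def rollout(env, key, policy_fn, num_steps):
    cls = type(env)  # concrete Environment subclass; its methods are static

    def body(carry, _):
        env, key = carry
        key, reset_key = jax.random.split(key)
        obs = cls.observation(env)        # (A, observation_space_size)
        action = policy_fn(obs)           # (A, action_space_size)
        env = cls.step(env, action)
        rew = cls.reward(env)             # (A,)
        done = cls.done(env)              # scalar bool for the whole environment
        env = cls.reset_if_done(env, done, reset_key)
        return (env, key), (obs, action, rew, done)

    (env, key), trajectory = jax.lax.scan(body, (env, key), None, length=num_steps)
    return env, trajectory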
Bases: Environment
Multi-agent navigation environment with collision penalties.
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Returns the observation vector for each agent.
LiDAR bins store proximity values as max(0, R - d_min); a value of 0 means no detection or that an object lies beyond the LiDAR range. The observation concatenates the displacement to the objective, the particle velocity, and the LiDAR readings normalized by R.
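A sketch of how such an observation could be assembled, assuming hypothetical per-agent arrays delta (displacement to the objective), vel (velocity), and d_min (closest detected distance per LiDAR bin) together with the LiDAR range R; none of these names are part of the documented API:

import jax.numpy as jnp

def lidar_observation_sketch(delta, vel, d_min, R):
    # delta: (A, dim) displacements, vel: (A, dim) velocities,
    # d_min: (A, num_bins) closest detected distance per LiDAR bin, R: LiDAR range.
    prox = jnp.maximum(0.0, R - d_min)                        # 0 => nothing within range
    return jnp.concatenate([delta, vel, prox / R], axis=-1)   # (A, observation_space_size)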
Initialize the environment with randomly placed particles and velocities.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Returns a vector of per-agent rewards.
Equation
Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factor \(\alpha\), final reward \(R_f\), radius \(r_i\), previous reward \(\mathrm{rew}^{\text{prev}}_i\), collision-penalty coefficient \(C_\mathrm{col}\le 0\), LiDAR range \(R\), measured proximities \(\mathrm{prox}_{i,j}\), and safety factor \(\kappa=2.05\):
\[\mathrm{rew}^{\text{shape}}_i \;=\; \mathrm{rew}^{\text{prev}}_i \;-\; \alpha\, d_i\]

Define per-beam “too close” hits using a distance threshold \(\tau_i = \max(0,\, R - \kappa\, r_i)\):

\[\mathrm{hit}_{i,j} \;=\; \mathbf{1}\!\left[\,\mathrm{prox}_{i,j} > \tau_i\,\right],\qquad n^{\text{hits}}_i \;=\; \sum_j \mathrm{hit}_{i,j}\]

Total reward:

\[\mathrm{rew}_i \;=\; \mathrm{rew}^{\text{shape}}_i \;+\; R_f\,\mathbf{1}[\,d_i < r_i\,] \;+\; C_\mathrm{col}\, n^{\text{hits}}_i\]

The function updates \(\mathrm{rew}^{\text{prev}}_i \leftarrow \mathrm{rew}^{\text{shape}}_i\) and returns \((\mathrm{rew}_i)_{i=1}^N\) reshaped to (env.max_num_agents,).
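A direct transcription of these equations in JAX might look as follows; the array names and the way the previous shaped reward is threaded through are assumptions, since the environment stores those quantities in its own fields:

import jax.numpy as jnp

def navigator_reward_sketch(pos, objective, rad, prox, rew_prev,
                            alpha, R_f, C_col, R, kappa=2.05):
    # pos: (N, dim) positions, objective: (dim,), rad: (N,) radii,
    # prox: (N, num_bins) LiDAR proximities, rew_prev: (N,) previous shaped reward.
    delta = objective - pos                          # displacement to the objective
    d = jnp.linalg.norm(delta, axis=-1)              # d_i
    rew_shape = rew_prev - alpha * d                 # shaping term
    tau = jnp.maximum(0.0, R - kappa * rad)          # per-agent "too close" threshold
    n_hits = jnp.sum(prox > tau[:, None], axis=-1)   # count of too-close LiDAR beams
    rew = rew_shape + R_f * (d < rad) + C_col * n_hits
    return rew, rew_shape                            # rew_shape becomes the next rew_prev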
Advance the simulation by one step. Actions are interpreted as accelerations.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
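The “actions are accelerations” convention corresponds to an update of roughly the following form, sketched here with explicit position and velocity arrays and a time step dt; the environment presumably advances its State through its System rather than a standalone function like this:

import jax.numpy as jnp

def acceleration_step_sketch(pos, vel, action, dt):
    # Semi-implicit Euler: accelerate first, then move. Illustration only.
    vel = vel + dt * action
    pos = pos + dt * vel
    return pos, vel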
Bases: Environment
Single-agent navigation environment toward a fixed target.
Custom factory method for this environment.
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Returns the observation vector, which concatenates the displacement between the particle and the objective with the particle’s velocity.
- Parameters:
env (Environment) – The current environment.
- Returns:
Observation vector for the environment.
- Return type:
jax.Array
Initialize the environment with a randomly placed particle and velocity.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Returns a vector of per-agent rewards.
Equation
Let \(\delta_i = \mathrm{displacement}(\mathbf{x}_i, \mathbf{objective})\), \(d_i = \lVert \delta_i \rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factor \(\alpha\), final reward \(R_f\), radius \(r_i\), and previous reward \(\mathrm{rew}^{\text{prev}}_i\):

\[\mathrm{rew}^{\text{shape}}_i \;=\; \mathrm{rew}^{\text{prev}}_i \;-\; \alpha\, d_i\]

\[\mathrm{rew}_i \;=\; \mathrm{rew}^{\text{shape}}_i \;+\; R_f \,\mathbf{1}[\,d_i < r_i\,]\]

The function updates \(\mathrm{rew}^{\text{prev}}_i \leftarrow \mathrm{rew}^{\text{shape}}_i\).
- Parameters:
env (Environment) – Current environment.
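The same shaping rule, sketched in JAX for a single agent; the argument names are illustrative and the previous shaped reward is assumed to be stored by the environment:

import jax.numpy as jnp

def single_navigator_reward_sketch(pos, objective, rad, rew_prev, alpha, R_f):
    d = jnp.linalg.norm(objective - pos)            # distance to the objective
    rew_shape = rew_prev - alpha * d                # shaping term
    rew = rew_shape + R_f * (d < rad)               # bonus on reaching the target
    return rew, rew_shape                           # rew_shape becomes the next rew_prev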
Advance the simulation by one step. Actions are interpreted as accelerations.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
Modules

| Multi-agent navigation task with collision penalties. |
| Environment where a single agent navigates towards a target. |