jaxdem.rl.environments#
Reinforcement-learning environment interface.
Classes
- Environment: Defines the interface for reinforcement-learning environments.
- class jaxdem.rl.environments.Environment(state: State, system: System, env_params: Dict[str, Any])[source]#
Bases: Factory, ABC
Defines the interface for reinforcement-learning environments.
Let A be the number of agents (A ≥ 1). Single-agent environments still use A=1.
Observations and actions are flattened per agent to fixed sizes. Use action_space_shape to reshape inside the environment if needed.
Required shapes
- Observation: (A, observation_space_size)
- Action (input to step()): (A, action_space_size)
- Reward: (A,)
- Done: scalar boolean for the whole environment
TODO:
- Truncated data field: per-agent termination flag
- Render method
Example
To define a custom environment, inherit from Environment and implement the abstract methods:
>>> @Environment.register("MyCustomEnv")
>>> @jax.tree_util.register_dataclass
>>> @dataclass(slots=True)
>>> class MyCustomEnv(Environment):
...     ...
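A slightly fuller sketch of the same pattern is shown below. It is illustrative only: the method bodies are placeholders and no State or System fields are touched, so nothing beyond the documented interface is assumed.

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxdem.rl.environments import Environment


@Environment.register("MyCustomEnv")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class MyCustomEnv(Environment):
    @staticmethod
    def reset(env, key):
        # Build a valid start state here; returning env unchanged is a placeholder.
        return env

    @staticmethod
    def step(env, action):
        # action has shape (A, action_space_size); apply it and return the new environment.
        return env

    @staticmethod
    def observation(env):
        # Must return shape (A, observation_space_size).
        return jnp.zeros((1, 1))

    @staticmethod
    def reward(env):
        # Must return shape (A,).
        return jnp.zeros((1,))

    @staticmethod
    def done(env):
        # Scalar boolean for the whole environment.
        return jnp.asarray(False)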
- env_params: Dict[str, Any]#
Environment-specific parameters.
- abstractmethod static reset(env: Environment, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) → Environment[source]#
Initialize the environment to a valid start state.
- Parameters:
env (Environment) – Instance of the environment.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static reset_if_done(env: Environment, done: Array, key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) → Environment[source]#
Conditionally resets the environment if it has reached a terminal state.
This method checks the done flag and, if True, calls the environment’s reset method to reinitialize the state. Otherwise, it returns the current environment unchanged.
- Parameters:
env (Environment) – The current environment instance.
done (jax.Array) – A boolean flag indicating whether the environment has reached a terminal state.
key (jax.random.PRNGKey) – JAX random number generator key used for reinitialization.
- Returns:
Either the freshly reset environment (if done is True) or the unchanged environment (if done is False).
- Return type:
Environment
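Since both branches must return an Environment pytree with identical structure, the conditional reset is naturally expressed with jax.lax.cond. A minimal sketch of that pattern (not necessarily jaxdem's exact implementation):

import jax

def reset_if_done(env, done, key):
    # Both branches return an Environment with the same pytree structure,
    # so reset() must not change array shapes or dtypes.
    return jax.lax.cond(
        done,
        lambda: type(env).reset(env, key),
        lambda: env,
    )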
- abstractmethod static step(env: Environment, action: Array) → Environment[source]#
Advance the simulation by one step using per-agent actions.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- abstractmethod static observation(env: Environment) → Array[source]#
Returns the per-agent observation vector.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector corresponding to the environment observation.
- Return type:
jax.Array
- abstractmethod static reward(env: Environment) → Array[source]#
Returns the per-agent immediate rewards.
- Parameters:
env (Environment) – The current environment.
- Returns:
Vector of all agents' rewards based on the current environment state.
- Return type:
jax.Array
- abstractmethod static done(env: Environment) → Array[source]#
Returns a boolean indicating whether the environment has ended.
- Parameters:
env (Environment) – The current environment.
- Returns:
A boolean indicating whether the environment has ended.
- Return type:
jax.Array
- static info(env: Environment) → Dict[str, Any][source]#
Return auxiliary diagnostic information.
By default, returns an empty dict. Subclasses may override to provide environment-specific information.
- Parameters:
env (Environment) – The current state of the environment.
- Returns:
A dictionary with additional information about the environment.
- Return type:
Dict
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: Tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
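Because every method is a pure function of the environment pytree, the interface composes directly with jax.lax.scan. The sketch below rolls out a fixed number of steps with a hypothetical policy function that maps (A, observation_space_size) observations to (A, action_space_size) actions; it assumes the concrete environment's leaves are scan-compatible arrays or static metadata.

import jax

def rollout(env, key, policy, num_steps):
    def body(carry, _):
        env, key = carry
        key, reset_key = jax.random.split(key)
        obs = type(env).observation(env)
        env = type(env).step(env, policy(obs))
        rew = type(env).reward(env)
        done = type(env).done(env)
        env = type(env).reset_if_done(env, done, reset_key)
        return (env, key), (obs, rew, done)

    (env, _), (obs, rew, done) = jax.lax.scan(body, (env, key), None, length=num_steps)
    return env, obs, rew, done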
Bases: Environment
Multi-agent navigation environment with collision penalties.
Agents seek fixed objectives in a 2D reflective box. Each step applies a force-like action, advances simple dynamics, updates LiDAR, and returns shaped rewards with an optional final bonus on goal.
Number of LiDAR rays for the vision system.
Original per-agent action shape (useful for reshaping inside the environment).
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
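A fixed-horizon termination like this reduces to a single comparison. The sketch below is an assumption about where the step counter and maximum episode length are stored; the actual field names may differ.

import jax.numpy as jnp

def done(env):
    # Hypothetical fields: a per-episode step counter and a maximum episode length.
    return jnp.asarray(env.env_params["step"] >= env.env_params["max_steps"])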
Build per-agent observations.
Contents per agent#
- Wrapped displacement to objective Δx (shape (2,)).
- Velocity v (shape (2,)).
- LiDAR proximities (shape (n_lidar_rays,)).
- Returns:
Array of shape (N, 2 * dim + n_lidar_rays), scaled by the maximum box size for normalization.
- Return type:
jax.Array
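The observation is a concatenation followed by a global normalization. The sketch below shows the shape bookkeeping; the attribute names (state.pos, state.vel and the env_params keys) are illustrative assumptions, and the wrapped displacement is simplified to a plain difference.

import jax.numpy as jnp

def observation(env):
    delta = env.env_params["objective"] - env.state.pos  # displacement to objective, (N, 2)
    vel = env.state.vel                                   # velocities, (N, 2)
    lidar = env.env_params["lidar"]                       # LiDAR proximities, (N, n_lidar_rays)
    scale = jnp.max(env.env_params["box_size"])           # maximum box size for normalization
    return jnp.concatenate([delta, vel, lidar], axis=-1) / scale  # (N, 2 * dim + n_lidar_rays)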
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
Initialize the environment with randomly placed particles and velocities.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Per-agent reward with distance shaping, goal bonus, LiDAR collision penalty, and a global shaping term.
Equations
Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factors \(\alpha_{\text{prev}},\alpha\), final reward \(R_f\), collision penalty \(C\), global shaping factor \(\beta\), and radius \(r_i\), let \(\ell_{i,k}\) be the LiDAR proximities for agent \(i\) and ray \(k\), and \(h_i = \sum_k \mathbf{1}[\ell_{i,k} > (\text{LIDAR_range} - 2r_i)]\) the collision count. The reward consists of:
\[\mathrm{rew}^{\text{shape}}_i = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\, d_i\]
\[\mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i + R_f\,\mathbf{1}[\,d_i < \text{goal_threshold}\times r_i\,] + C\, h_i - \beta\, \overline{d},\]
\[\overline{d} = \tfrac{1}{N}\sum_j d_j\]
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,). The normalized per-agent reward vector.
- Return type:
jax.Array
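Translated into JAX, the equations above amount to a few vectorized terms. The sketch below assumes a dict of scalar factors p and explicit array arguments purely for illustration; only the arithmetic mirrors the documented reward.

import jax.numpy as jnp

def reward(d, d_prev, lidar, radius, p):
    # d, d_prev, radius: (N,); lidar: (N, n_lidar_rays); p: scalar reward factors.
    shaping = p["alpha_prev"] * d_prev - p["alpha"] * d
    goal_bonus = p["final_reward"] * (d < p["goal_threshold"] * radius)
    hits = jnp.sum(lidar > (p["lidar_range"] - 2.0 * radius)[:, None], axis=-1)
    collision = p["collision_penalty"] * hits   # the penalty factor C is typically negative
    global_term = p["beta"] * jnp.mean(d)
    return shaping + goal_bonus + collision - global_term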
Advance one step. Actions are forces; simple drag is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
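The force-plus-drag update can be pictured as a semi-implicit Euler step. The time step and drag coefficient below are illustrative defaults, not jaxdem's actual values.

def integrate(pos, vel, force, dt=0.01, drag=0.1):
    # Treat the per-agent action as a force, damp the velocity, then advance positions.
    vel = vel + dt * (force - drag * vel)
    pos = pos + dt * vel
    return pos, vel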
Bases: Environment
Single-agent navigation environment toward a fixed target.
Custom factory method for this environment.
Original per-agent action shape (useful for reshaping inside the environment).
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
Build per-agent observations.
Contents per agent#
- Wrapped displacement to objective Δx (shape (2,)).
- Velocity v (shape (2,)).
- Returns:
Array of shape (N, 2 * dim), scaled by the maximum box size for normalization.
- Return type:
jax.Array
Flattened observation size per agent. observation() returns shape (A, observation_space_size).
Initialize the environment with a randomly placed particle and velocity.
- Parameters:
env (Environment) – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
Returns a vector of per-agent rewards.
Equation
Let \(\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})\), \(d_i=\lVert\delta_i\rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factors \(\alpha_{\text{prev}},\alpha\), final reward \(R_f\), and radius \(r_i\):
\[\mathrm{rew}^{\text{shape}}_i = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\, d_i\]
Final reward:
\[\mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i + R_f\,\mathbf{1}[\,d_i < \text{goal_threshold}\times r_i\,]\]
- Parameters:
env (Environment) – Current environment.
- Returns:
Shape (N,). The normalized per-agent reward vector.
- Return type:
jax.Array
Advance one step. Actions are forces; simple drag is applied.
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
Modules
- Multi-agent navigation task with collision penalties.
- Environment where a single agent navigates towards a target.