jaxdem.rl.environments.single_navigator#
Environment where a single agent navigates towards a target.
Classes
| SingleNavigator | Single-agent navigation environment toward a fixed target. |
- class jaxdem.rl.environments.single_navigator.SingleNavigator(state: State, system: System, env_params: dict[str, Any])#
Bases: Environment
Single-agent navigation environment toward a fixed target.
The agent controls a force vector that is applied directly to a sphere inside a reflective box. Viscous drag -friction * vel is added each step. The reward uses exponential potential-based shaping:
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]
where \(d_t\) is the distance to the objective at step \(t\), \(d_t^{\mathrm{prev}}\) is that distance at the previous step, \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
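As a minimal sketch of the force model described above, the snippet below adds the viscous drag term -friction * vel to the agent's action. The function name total_force and the (N, dim) array shapes are illustrative assumptions and not part of jaxdem; the actual time integration is handled by the simulator and is not reproduced here.

```python
import jax.numpy as jnp


def total_force(action, vel, friction=0.2):
    """Agent force plus viscous drag: action - friction * vel."""
    return action - friction * vel


action = jnp.array([[1.0, 0.0]])   # (N, dim) force chosen by the agent
vel = jnp.array([[2.0, -1.0]])     # (N, dim) current velocity
force = total_force(action, vel)   # approximately [[0.6, 0.2]]

# Under a constant action, drag cancels the force at vel = action / friction.
terminal_vel = action / 0.2
balanced = total_force(action, terminal_vel)  # approximately zero
```

With the default friction = 0.2, a constant unit force therefore saturates at a speed of about 5 in the simulation's units.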
Notes
The observation vector per agent is:
| Feature | Size |
| Unit direction to objective | dim |
| Clamped displacement | dim |
| Velocity | dim |
If one wants realistic parameters for training, skip_frames = 50 will give a response rate of 200 Hz, meaning that num_steps_epoch = 100 gives a horizon of 0.5 seconds.
- classmethod Create(dim: int = 2, min_box_size: float = 40.0, max_box_size: float = 40.0, max_steps: int = 20000, friction: float = 0.2, ke_weight: float = 0.1) → SingleNavigator[source]#
Create a single-agent navigator environment.
- Parameters:
dim (int) – Spatial dimensionality (2 or 3).
min_box_size (float) – Lower bound of the range for the random square domain side length.
max_box_size (float) – Upper bound of the range for the random square domain side length.
max_steps (int) – Episode length in physics steps.
friction (float) – Viscous drag coefficient applied as -friction * vel.
ke_weight (float) – Weight for the differential kinetic energy penalty.
- Returns:
A freshly constructed environment (call reset() before use).
- Return type:
SingleNavigator
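The timing note in the class description can be checked with a little arithmetic. The sketch below assumes that skip_frames is the number of physics steps per agent action, which would imply a physics rate of 10,000 steps per second (a time step of 1e-4 s). That time step is inferred from the quoted 200 Hz figure, not stated by the documentation. It also uses the default max_steps from Create(), which is counted in physics steps.

```python
# Values quoted in the class description.
skip_frames = 50          # physics steps per agent action (assumed meaning)
response_rate_hz = 200    # agent decisions per second, as stated
num_steps_epoch = 100     # agent steps per rollout

# Implied physics rate: 50 physics steps per decision * 200 decisions per second.
physics_rate_hz = skip_frames * response_rate_hz

# Rollout horizon in seconds.
horizon_s = num_steps_epoch / response_rate_hz

# Default max_steps from Create(), counted in physics steps.
max_steps = 20000
episode_s = max_steps / physics_rate_hz      # episode length in seconds
agent_decisions = max_steps // skip_frames   # agent actions per episode

print(physics_rate_hz, horizon_s, episode_s, agent_decisions)
```

Under these assumptions the default episode lasts 2 seconds of simulated time and spans 400 agent decisions.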
- static reset(env: SingleNavigator, key: Array | ndarray | bool | number | bool | int | float | complex) → Environment[source]#
Initialize the environment with a randomly placed particle and velocity.
- Parameters:
env ('SingleNavigator') – Current environment instance.
key (jax.random.PRNGKey) – JAX random number generator key.
- Returns:
Freshly initialized environment.
- Return type:
Environment
- static step(env: SingleNavigator, action: Array) → Environment[source]#
Advance one step. Actions are forces; simple drag is applied (-friction * vel).
- Parameters:
env (Environment) – The current environment.
action (jax.Array) – The vector of actions each agent in the environment should take.
- Returns:
The updated environment state.
- Return type:
Environment
- static observation(env: SingleNavigator) → Array[source]#
Build per-agent observations.
Contents per agent#
Unit vector to objective (shape (dim,)) –> Direction
Clamped delta to objective (shape (dim,)) –> Local precision
Velocity (shape (dim,))
- Returns:
Array of shape (N, 3 * dim).
- Return type:
jax.Array
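The observation layout above can be sketched in plain JAX as follows. This is not the library's implementation. The function name build_observation, the component-wise clamp to [-1, 1], and the eps guard against division by zero are all assumptions made for illustration, since the documentation does not state the clamp range.

```python
import jax.numpy as jnp


def build_observation(pos, vel, objective, clamp=1.0, eps=1e-8):
    """Concatenate unit direction, clamped displacement, and velocity."""
    delta = objective - pos                                  # (N, dim)
    dist = jnp.linalg.norm(delta, axis=-1, keepdims=True)    # (N, 1)
    unit = delta / jnp.maximum(dist, eps)                    # direction to objective
    clamped = jnp.clip(delta, -clamp, clamp)                 # assumed clamp range
    return jnp.concatenate([unit, clamped, vel], axis=-1)    # (N, 3 * dim)


pos = jnp.array([[0.0, 0.0]])
vel = jnp.array([[0.5, -0.5]])
objective = jnp.array([[3.0, 4.0]])
obs = build_observation(pos, vel, objective)
print(obs.shape)  # (1, 6) for N = 1, dim = 2
```

The unit vector gives the direction at any range, while the clamped displacement only becomes informative close to the objective, which matches the "Direction" and "Local precision" roles listed above.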
- static reward(env: SingleNavigator) → Array[source]#
Returns a vector of per-agent rewards.
Reward:
\[\mathrm{rew}_t = (e^{-2 \cdot d_t} - e^{-2 \cdot d_t^{\mathrm{prev}}}) - w_{\text{ke}} (K_t - K_{t-1})\]
where \(d_t\) is the distance to the objective at step \(t\), \(d_t^{\mathrm{prev}}\) is that distance at the previous step, \(K_t\) is the kinetic energy at step \(t\), and \(w_{\text{ke}}\) is the weight for the kinetic energy penalty.
- Parameters:
env (Environment) – Current environment.
- Returns:
Array of shape (N,).
- Return type:
jax.Array
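A minimal sketch of the reward formula above, written as a standalone JAX function. The name shaped_reward and its arguments are hypothetical. How the environment stores the previous distance and kinetic energy is not documented here, so they are passed in explicitly.

```python
import jax.numpy as jnp


def shaped_reward(dist, prev_dist, ke, prev_ke, ke_weight=0.1):
    """Exponential potential-based shaping minus a kinetic energy change penalty."""
    progress = jnp.exp(-2.0 * dist) - jnp.exp(-2.0 * prev_dist)
    return progress - ke_weight * (ke - prev_ke)


# One agent moving from distance 1.5 to 1.0 at constant kinetic energy.
dist = jnp.array([1.0])
prev_dist = jnp.array([1.5])
ke = jnp.array([0.5])
prev_ke = jnp.array([0.5])

rew_closer = shaped_reward(dist, prev_dist, ke, prev_ke)    # positive
rew_farther = shaped_reward(prev_dist, dist, ke, prev_ke)   # negative
```

Because the shaping term is a difference of potentials, approaching the objective yields a positive reward and retreating by the same amount yields the exact negative. The kinetic energy term penalizes speeding up and rewards slowing down.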
- static done(env: SingleNavigator) → Array[source]#
Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.
- Parameters:
env (Environment) – The current environment.
- Returns:
Boolean array indicating whether the episode has ended.
- Return type:
jax.Array
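The termination rule described above amounts to comparing a step counter with max_steps. The sketch below is an assumption-laden illustration: the name is_done, the step_count argument, and the use of >= rather than > are hypothetical, since the documentation only says the episode ends when the maximum number of steps is reached.

```python
import jax.numpy as jnp


def is_done(step_count, max_steps=20000):
    """True once the step counter has reached max_steps (assumed >= comparison)."""
    return jnp.asarray(step_count) >= max_steps


before_limit = is_done(19999)
at_limit = is_done(20000)
batched = is_done(jnp.array([0, 20000]))  # works element-wise on arrays too
```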
- property action_space_size: int[source]#
Flattened action size per agent. Actions passed to step() have shape (A, action_space_size).
- property action_space_shape: tuple[int][source]#
Original per-agent action shape (useful for reshaping inside the environment).
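To illustrate how the two action properties relate, the sketch below reshapes a flattened action batch back to its per-agent shape. The value action_space_shape = (2,) is a hypothetical example (one force component per axis in 2D) rather than a value read from the library.

```python
import math

import jax.numpy as jnp

action_space_shape = (2,)                           # hypothetical per-agent shape
action_space_size = math.prod(action_space_shape)   # flattened size per agent

num_agents = 1
flat_actions = jnp.zeros((num_agents, action_space_size))   # shape (A, action_space_size)

# Restore the per-agent shape, e.g. inside the environment.
shaped_actions = flat_actions.reshape((num_agents, *action_space_shape))
```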
- property observation_space_size: int[source]#
Flattened observation size per agent. observation() returns shape (A, observation_space_size).