jaxdem.rl.environments.single_navigator#

Environment where a single agent navigates towards a target.

Classes

SingleNavigator(state, system, env_params[, ...])

Single-agent navigation environment toward a fixed target.

class jaxdem.rl.environments.single_navigator.SingleNavigator(state: State, system: System, env_params: Dict[str, Any], max_num_agents: int = 0, action_space_size: int = 0, action_space_shape: Tuple[int, ...] = (), observation_space_size: int = 0)[source]#

Bases: Environment

Single-agent navigation environment toward a fixed target.

classmethod Create(dim: int = 2, min_box_size: float = 1.0, max_box_size: float = 2.0, max_steps: int = 2000, final_reward: float = 0.05, shaping_factor: float = 1.0) → SingleNavigator[source]#

Custom factory method for this environment.
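
A minimal usage sketch of the factory; the argument values here are illustrative, and the defaults are those listed in the signature above:

>>> from jaxdem.rl.environments.single_navigator import SingleNavigator
>>> # Build a 2D navigation task with a slightly larger box and a longer episode.
>>> env = SingleNavigator.Create(dim=2, max_box_size=3.0, max_steps=5000)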

static reset(env: Environment, key: Array | ndarray | bool_ | number | bool | int | float | complex) → Environment[source]#

Initialize the environment with a randomly placed particle and a random initial velocity.

Parameters:
  • env (Environment) – Current environment instance.

  • key (jax.random.PRNGKey) – JAX random number generator key.

Returns:

Freshly initialized environment.

Return type:

Environment
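
For example, resetting with an explicit PRNG key:

>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> env = SingleNavigator.reset(env, key)  # freshly initialized episode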

static step(env: Environment, action: Array) → Environment[source]#

Advance the simulation by one step. Actions are interpreted as accelerations.

Parameters:
  • env (Environment) – The current environment.

  • action (jax.Array) – The vector of actions each agent in the environment should take.

Returns:

The updated environment state.

Return type:

Environment
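
A single-step sketch. The zero action (no acceleration) is built from the environment's own action_space_shape field; the exact per-agent layout of that shape is an assumption:

>>> import jax.numpy as jnp
>>> action = jnp.zeros(env.action_space_shape)  # assumed: one acceleration vector per agent
>>> env = SingleNavigator.step(env, action)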

static observation(env: Environment) → Array[source]#

Returns the observation vector, which concatenates the displacement between the particle and the objective with the particle’s velocity.

Parameters:

env (Environment) – The current environment.

Returns:

Observation vector for the environment.

Return type:

jax.Array
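
To illustrate the documented layout (displacement first, then velocity), assuming dim = 2:

>>> obs = SingleNavigator.observation(env)
>>> # First 2 entries: displacement to the objective; last 2: the particle's velocity.
>>> displacement, velocity = obs[..., :2], obs[..., 2:]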

static reward(env: Environment) → Array[source]#

Returns a vector of per-agent rewards.

Equation

Let \(\delta_i = \mathrm{displacement}(\mathbf{x}_i, \mathbf{objective})\), \(d_i = \lVert \delta_i \rVert_2\), and \(\mathbf{1}[\cdot]\) the indicator. With shaping factor \(\alpha\), final reward \(R_f\), radius \(r_i\), and previous reward \(rew^{\text{prev}}_i\):

\[rew^{\text{shape}}_i \;=\; rew^{\text{prev}}_i \;-\; \alpha\, d_i\]
\[rew_i \;=\; rew^{\text{shape}}_i \;+\; R_f \,\mathbf{1}[\,d_i < r_i\,]\]

The function updates \(rew^{\text{prev}}_i \leftarrow rew^{\text{shape}}_i\).

Parameters:

env (Environment) – Current environment.

Returns:

Vector of per-agent rewards.

Return type:

jax.Array
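
A standalone sketch of the shaping rule above; shaped_reward and its arguments are hypothetical names for illustration, not part of the environment's API:

>>> import jax.numpy as jnp
>>> def shaped_reward(displacement, radius, prev_reward, alpha, R_f):
...     d = jnp.linalg.norm(displacement, axis=-1)  # d_i = ||delta_i||_2
...     rew_shape = prev_reward - alpha * d         # rew_shape_i = rew_prev_i - alpha * d_i
...     rew = rew_shape + R_f * (d < radius)        # add R_f when the agent is inside radius r_i
...     return rew, rew_shape                       # rew_shape becomes the next rew_prev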

static done(env: Environment) → Array[source]#

Returns a boolean indicating whether the environment has ended. The episode terminates when the maximum number of steps is reached.

Parameters:

env (Environment) – The current environment.

Returns:

Boolean array indicating whether the episode has ended.

Return type:

jax.Array
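
Putting the pieces together, a plain-Python rollout sketch (assuming done yields a boolean that can be reduced with .any(); under jax.jit one would use jax.lax.while_loop instead):

>>> import jax
>>> import jax.numpy as jnp
>>> env = SingleNavigator.reset(env, jax.random.PRNGKey(1))
>>> while not bool(SingleNavigator.done(env).any()):
...     env = SingleNavigator.step(env, jnp.zeros(env.action_space_shape))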

classmethod registry_name() → str[source]#
property type_name: str[source]#