Source code for jaxdem.rl.environments.single_navigator

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Environment where a single agent navigates towards a target."""

from __future__ import annotations

from dataclasses import dataclass, replace
from functools import partial
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from . import Environment
from ...state import State
from ...system import System


@Environment.register("singleNavigator")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class SingleNavigator(Environment):
    """Single-agent navigation environment toward a fixed target."""

    @classmethod
    def Create(
        cls,
        dim: int = 2,
        min_box_size: float = 1.0,
        max_box_size: float = 2.0,
        max_steps: int = 2000,
        final_reward: float = 0.05,
        shaping_factor: float = 1.0,
    ) -> "SingleNavigator":
        """
        Custom factory method for this environment.

        Parameters
        ----------
        dim : int
            Spatial dimension of the simulation.
        min_box_size, max_box_size : float
            Bounds for the randomly sampled domain size at reset.
        max_steps : int
            Number of steps after which the episode terminates.
        final_reward : float
            Bonus added to the reward while the agent is on the objective.
        shaping_factor : float
            Coefficient of the distance-based reward shaping term.

        Returns
        -------
        SingleNavigator
            A new, unreset environment instance.
        """
        N = 1
        state = State.create(pos=jnp.zeros((N, dim)))
        system = System.create(dim)
        env_params = dict(
            objective=jnp.zeros_like(state.pos),
            min_box_size=jnp.asarray(min_box_size, dtype=float),
            max_box_size=jnp.asarray(max_box_size, dtype=float),
            max_steps=jnp.asarray(max_steps, dtype=int),
            final_reward=jnp.asarray(final_reward, dtype=float),
            shaping_factor=jnp.asarray(shaping_factor, dtype=float),
            prev_rew=jnp.zeros_like(state.rad),
        )
        action_space_size = dim
        action_space_shape = (dim,)
        observation_space_size = 2 * dim
        return cls(
            state=state,
            system=system,
            env_params=env_params,
            max_num_agents=N,
            action_space_size=action_space_size,
            action_space_shape=action_space_shape,
            observation_space_size=observation_space_size,
        )

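    # Sizing example (illustrative sketch, not part of the API; it assumes
    # only the factory defined above):
    #
    #     env = SingleNavigator.Create(dim=2)
    #     env.action_space_shape      # (2,)  -- one acceleration per axis
    #     env.observation_space_size  # 4    -- displacement + velocity
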
    @staticmethod
    @partial(jax.jit, donate_argnames=("env",))
    def reset(env: "Environment", key: ArrayLike) -> "Environment":
        """
        Initialize the environment with a randomly sized box, a randomly
        placed particle with a random velocity, and a random objective.

        Parameters
        ----------
        env : Environment
            Current environment instance.
        key : jax.random.PRNGKey
            JAX random number generator key.

        Returns
        -------
        Environment
            Freshly initialized environment.
        """
        key, key_pos, key_vel, key_box, key_objective = jax.random.split(key, 5)
        N = env.max_num_agents
        dim = env.state.dim

        # Sample the domain size, then keep positions and objectives at
        # least one particle radius away from the walls.
        box = jax.random.uniform(
            key_box,
            (dim,),
            minval=env.env_params["min_box_size"],
            maxval=env.env_params["max_box_size"],
            dtype=float,
        )
        rad = 0.05
        min_pos = rad * jnp.ones_like(box)
        pos = jax.random.uniform(
            key_pos,
            (N, dim),
            minval=min_pos,
            maxval=box - min_pos,
            dtype=float,
        )
        objective = jax.random.uniform(
            key_objective,
            (N, dim),
            minval=min_pos,
            maxval=box - min_pos,
            dtype=float,
        )
        # env_params is a plain dict, so it is updated in place and carried
        # through the dataclass ``replace`` below.
        env.env_params["objective"] = objective

        vel = jax.random.uniform(
            key_vel, (N, dim), minval=-0.1, maxval=0.1, dtype=float
        )
        rad = rad * jnp.ones(N)
        state = State.create(pos=pos, vel=vel, rad=rad)
        system = System.create(
            env.state.dim,
            domain_type="reflect",
            domain_kw=dict(box_size=box, anchor=jnp.zeros_like(box)),
        )
        env = replace(env, state=state, system=system)
        env.env_params["prev_rew"] = jnp.zeros_like(env.state.rad)
        return env

    @staticmethod
    @partial(jax.jit, donate_argnames=("env", "action"))
    def step(env: "Environment", action: jax.Array) -> "Environment":
        """
        Advance the simulation by one step. Actions are interpreted as
        accelerations.

        Parameters
        ----------
        env : Environment
            The current environment.
        action : jax.Array
            The vector of actions each agent in the environment should take.

        Returns
        -------
        Environment
            The updated environment state.
        """
        a = action.reshape(env.max_num_agents, *env.action_space_shape)
        # Apply the commanded acceleration plus a constant-magnitude drag
        # term opposing the current velocity.
        state = replace(env.state, accel=a - jnp.sign(env.state.vel) * 0.08)
        state, system = env.system.step(state, env.system)
        return replace(env, state=state, system=system)

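    # Worked example (illustrative): with action = [0.1, 0.0] and
    # vel = [1.0, -1.0], the applied acceleration is
    # [0.1, 0.0] - 0.08 * sign([1.0, -1.0]) = [0.02, 0.08].
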
    @staticmethod
    @jax.jit
    def observation(env: "Environment") -> jax.Array:
        """
        Returns the observation vector, which concatenates the displacement
        between the particle and the objective with the particle's velocity.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            Observation vector for the environment.
        """
        return jnp.concatenate(
            [
                env.system.domain.displacement(
                    env.state.pos, env.env_params["objective"], env.system
                ),
                env.state.vel,
            ],
            axis=-1,
        )

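    # Layout note (illustrative): for dim = 2 the observation per agent is
    # [dx, dy, vx, vy] -- the displacement between particle and objective,
    # followed by the particle's velocity.
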
    @staticmethod
    @jax.jit
    def reward(env: "Environment") -> jax.Array:
        r"""
        Returns a vector of per-agent rewards.

        **Equation**

        Let :math:`\delta_i = \mathrm{displacement}(\mathbf{x}_i, \mathbf{objective})`,
        :math:`d_i = \lVert \delta_i \rVert_2`, and :math:`\mathbf{1}[\cdot]` the
        indicator. With shaping factor :math:`\alpha`, final reward :math:`R_f`,
        radius :math:`r_i`, and previous reward :math:`rew^{\text{prev}}_i`:

        .. math::
            rew^{\text{shape}}_i \;=\; rew^{\text{prev}}_i \;-\; \alpha\, d_i

        .. math::
            rew_i \;=\; rew^{\text{shape}}_i \;+\; R_f \,\mathbf{1}[\,d_i < r_i\,]

        The function updates :math:`rew^{\text{prev}}_i \leftarrow rew^{\text{shape}}_i`.

        Parameters
        ----------
        env : Environment
            Current environment.

        Returns
        -------
        jax.Array
            Per-agent reward vector of shape ``(max_num_agents,)``.
        """
        pos = env.state.pos
        objective = env.env_params["objective"]
        delta = env.system.domain.displacement(pos, objective, env.system)
        d = jnp.linalg.norm(delta, axis=-1)
        on_goal = d < env.state.rad
        rew = env.env_params["prev_rew"] - d * env.env_params["shaping_factor"]
        env.env_params["prev_rew"] = rew
        reward = rew + env.env_params["final_reward"] * on_goal
        return jnp.asarray(reward).reshape(env.max_num_agents)

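    # Worked example (illustrative): with alpha = 1.0, R_f = 0.05,
    # r = 0.05, and prev_rew starting at 0, two consecutive calls with
    # d = 1.0 and d = 0.8 yield rewards -1.0 and -1.8; the shaping term
    # accumulates in prev_rew across calls.
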
    @staticmethod
    @jax.jit
    def done(env: "Environment") -> jax.Array:
        """
        Returns a boolean indicating whether the environment has ended.

        The episode terminates when the maximum number of steps is reached.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            Boolean array indicating whether the episode has ended.
        """
        return jnp.asarray(env.system.step_count > env.env_params["max_steps"])


__all__ = ["SingleNavigator"]
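

# Minimal rollout sketch (illustrative, not part of the module's public API;
# it assumes only the methods defined above and that the module is importable
# as ``jaxdem.rl.environments.single_navigator``). Because ``reset`` and
# ``step`` donate their ``env`` argument, the environment is rebound on each
# call.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    env = SingleNavigator.Create(dim=2)
    env = SingleNavigator.reset(env, key)
    for _ in range(5):
        # A zero action: the particle only feels the drag term in ``step``.
        action = jnp.zeros((env.max_num_agents, *env.action_space_shape))
        env = SingleNavigator.step(env, action)
        obs = SingleNavigator.observation(env)  # shape (1, 2 * dim)
        rew = SingleNavigator.reward(env)       # shape (1,)
        print(obs, rew, SingleNavigator.done(env))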