# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Multi-agent navigation task with collision penalties."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from jax import ShapeDtypeStruct

from dataclasses import dataclass, replace
from functools import partial

import numpy as np
from scipy.stats import qmc

from . import Environment
from ...state import State
from ...system import System


def PoissonDisk(
    N: int,
    dim: int,
    rad: float,
    l_bounds: jax.Array,
    u_bounds: jax.Array,
    key: ArrayLike,
) -> jax.Array:
    """Sample ``N`` points whose pairwise distances are at least ``2 * rad``.

    Uses SciPy's Poisson-disk sampler within ``[l_bounds + rad, u_bounds - rad]``
    so that disks of radius ``rad`` fit inside the box, and raises
    ``RuntimeError`` if the requested number of points cannot be placed.
    """
    numpy_seed = int(jax.random.randint(key, (), 0, jnp.iinfo(jnp.int32).max))

    sampler = qmc.PoissonDisk(
        d=int(dim),
        radius=2 * float(rad),
        l_bounds=np.asarray(l_bounds, dtype=float) + float(rad),
        u_bounds=np.asarray(u_bounds, dtype=float) - float(rad),
        seed=int(numpy_seed),
        ncandidates=2000,
    )

    pts = jnp.asarray(sampler.random(N), dtype=float)
    m = int(pts.shape[0])
    if m != N:
        raise RuntimeError(
            "Could not place requested number of points without overlap: "
            f"requested N={N}, placed {m}. Try reducing the radius, increasing the box, or decreasing N."
        )

    return pts
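
# A minimal usage sketch of ``PoissonDisk`` (values are hypothetical; inside the
# environment the call goes through ``jax.pure_callback``, as in
# ``MultiNavigator.reset`` below). Kept as a comment so importing this module
# has no side effects:
#
#     key = jax.random.PRNGKey(0)
#     pts = PoissonDisk(
#         N=4, dim=2, rad=0.05,
#         l_bounds=jnp.zeros(2), u_bounds=jnp.ones(2), key=key,
#     )
#     assert pts.shape == (4, 2)  # pairwise distances are >= 2 * rad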


@Environment.register("multiNavigator")
@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class MultiNavigator(Environment):
    """Multi-agent navigation environment with collision penalties."""

    @classmethod
    def Create(
        cls,
        N: int = 2,
        min_box_size: float = 1.0,
        max_box_size: float = 2.0,
        max_steps: int = 5000,
        final_reward: float = 0.05,
        shaping_factor: float = 1.0,
        collision_penalty: float = -2.0,
        lidar_range: float = 0.35,
        n_lidar_rays: int = 12,
    ) -> "MultiNavigator":
        dim = 2
        state = State.create(pos=jnp.zeros((N, dim)))
        system = System.create(dim)
        n_lidar_rays = int(n_lidar_rays)
        env_params = dict(
            objective=jnp.zeros_like(state.pos),
            min_box_size=jnp.asarray(min_box_size, dtype=float),
            max_box_size=jnp.asarray(max_box_size, dtype=float),
            max_steps=jnp.asarray(max_steps, dtype=int),
            final_reward=jnp.asarray(final_reward, dtype=float),
            collision_penalty=jnp.asarray(collision_penalty, dtype=float),
            shaping_factor=jnp.asarray(shaping_factor, dtype=float),
            prev_rew=jnp.zeros_like(state.rad),
            lidar_range=jnp.asarray(lidar_range, dtype=float),
            n_lidar_rays=n_lidar_rays,
            lidar=jnp.zeros((N, n_lidar_rays), dtype=float),
        )
        action_space_size = dim
        action_space_shape = (dim,)
        observation_space_size = 2 * dim + n_lidar_rays
        return cls(
            state=state,
            system=system,
            env_params=env_params,
            max_num_agents=N,
            action_space_size=action_space_size,
            action_space_shape=action_space_shape,
            observation_space_size=observation_space_size,
        )
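
    # With the defaults above (``dim = 2``, ``n_lidar_rays = 12``) each agent
    # receives a 2-component acceleration action and a 16-component observation
    # laid out as ``[objective - pos (2), vel (2), lidar (12)]``; see
    # ``observation`` below.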

    @staticmethod
    @partial(jax.jit, donate_argnames=("env",))
    def reset(env: "Environment", key: ArrayLike) -> "Environment":
        """
        Initialize the environment with randomly placed particles and velocities.

        Parameters
        ----------
        env : Environment
            Current environment instance.
        key : jax.random.PRNGKey
            JAX random number generator key.

        Returns
        -------
        Environment
            Freshly initialized environment.
        """
        key, key_pos, key_vel, key_box, key_objective = jax.random.split(key, 5)
        N = env.max_num_agents
        dim = env.state.dim

        box = jax.random.uniform(
            key_box,
            (dim,),
            minval=env.env_params["min_box_size"],
            maxval=env.env_params["max_box_size"],
            dtype=float,
        )
        rad = 0.05

        result_spec = ShapeDtypeStruct((N, dim), env.state.pos.dtype)
        env.env_params["objective"] = jax.pure_callback(
            PoissonDisk,
            result_spec,
            N,
            dim,
            rad,
            jnp.zeros_like(box),
            box,
            key_objective,
            vmap_method="sequential",
        )
        pos = jax.pure_callback(
            PoissonDisk,
            result_spec,
            N,
            dim,
            rad,
            jnp.zeros_like(box),
            box,
            key_pos,
            vmap_method="sequential",
        )
        vel = jax.random.uniform(
            key_vel, (N, dim), minval=-0.1, maxval=0.1, dtype=float
        )
        rad = rad * jnp.ones(N)

        state = State.create(pos=pos, vel=vel, rad=rad)
        system = System.create(
            env.state.dim,
            domain_type="reflect",
            domain_kw=dict(box_size=box, anchor=jnp.zeros_like(box)),
        )
        env = replace(env, state=state, system=system)
        env.env_params["prev_rew"] = jnp.zeros_like(env.state.rad)
        return env

    @staticmethod
    @partial(jax.jit, donate_argnames=("env", "action"))
    def step(env: "Environment", action: jax.Array) -> "Environment":
        """
        Advance the simulation by one step.

        Actions are interpreted as accelerations.

        Parameters
        ----------
        env : Environment
            The current environment.
        action : jax.Array
            The vector of actions each agent in the environment should take.

        Returns
        -------
        Environment
            The updated environment state.
        """
        a = action.reshape(env.max_num_agents, *env.action_space_shape)
        state = replace(env.state, accel=a - jnp.sign(env.state.vel) * 0.08)
        state, system = env.system.step(state, env.system)
        env = replace(env, state=state, system=system)
        return env

    @staticmethod
    @jax.jit
    def observation(env: "Environment") -> jax.Array:
        """
        Returns the observation vector for each agent.

        LiDAR bins store proximity values as ``max(0, R - d_min)``; a value of 0
        means no detection or that an object lies beyond the LiDAR range. The
        observation concatenates the displacement to the objective, the particle
        velocity, and the LiDAR readings; the full vector is normalized by ``R``.
        """
        nbins = env.env_params["lidar"].shape[-1]
        R = env.env_params["lidar_range"]
        indices = jax.lax.iota(int, env.max_num_agents)
        pos = env.state.pos
        vel = env.state.vel
        obj = env.env_params["objective"]

        def lidar_for_i(i: jax.Array) -> jax.Array:
            # Displacements from agent i to every other agent, using the
            # domain's displacement rule.
            rij = jax.vmap(
                lambda j: env.system.domain.displacement(pos[i], pos[j], env.system)
            )(indices)
            r = jnp.linalg.norm(rij, axis=-1)
            r = r.at[i].set(jnp.inf)  # ignore self
            theta = jnp.arctan2(rij[..., 1], rij[..., 0])
            bins = jnp.floor((theta + jnp.pi) * (nbins / (2.0 * jnp.pi))).astype(int)
            d_in = jnp.where(r < R, r, jnp.inf)
            d_bins = jnp.full((nbins,), jnp.inf, dtype=pos.dtype).at[bins].min(d_in)
            proximity = jnp.where(
                jnp.isfinite(d_bins), jnp.maximum(0.0, R - d_bins), 0.0
            )
            return proximity

        lidar = jax.vmap(lidar_for_i)(indices)
        env.env_params["lidar"] = lidar
        obs = jnp.concatenate([obj - pos, vel, lidar], axis=-1)
        return obs / R
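
    # Worked binning example (hypothetical numbers, default ``R = 0.35`` and 12
    # rays): a neighbour straight ahead (``theta = 0``) lands in bin
    # ``floor((0 + pi) * 12 / (2 * pi)) = 6``; at distance ``d = 0.2 < R`` it
    # produces a proximity of ``R - d = 0.15`` in that bin, while bins without a
    # detection stay at 0.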

    @staticmethod
    @jax.jit
    def reward(env: "Environment") -> jax.Array:
        r"""
        Returns a vector of per-agent rewards.

        **Equation**

        Let :math:`\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})`,
        :math:`d_i=\lVert\delta_i\rVert_2`, and :math:`\mathbf{1}[\cdot]` the indicator.
        With shaping factor :math:`\alpha`, final reward :math:`R_f`, radius
        :math:`r_i`, previous reward :math:`\mathrm{rew}^{\text{prev}}_i`,
        collision-penalty coefficient :math:`C_\mathrm{col}\le 0`, LiDAR range
        :math:`R`, measured proximities :math:`\mathrm{prox}_{i,j}`, and safety
        factor :math:`\kappa=2.05`:

        .. math::

            \mathrm{rew}^{\text{shape}}_i \;=\; \mathrm{rew}^{\text{prev}}_i \;-\; \alpha\, d_i

        Define per-beam “too close” hits using a distance threshold
        :math:`\tau_i = \max(0,\, R - \kappa\, r_i)`:

        .. math::

            \mathrm{hit}_{i,j} \;=\; \mathbf{1}\!\left[\,\mathrm{prox}_{i,j} > \tau_i\,\right],\qquad
            n^{\text{hits}}_i \;=\; \sum_j \mathrm{hit}_{i,j}

        Total reward:

        .. math::

            \mathrm{rew}_i \;=\; \mathrm{rew}^{\text{shape}}_i \;+\;
            R_f\,\mathbf{1}[\,d_i < r_i\,] \;+\; C_\mathrm{col}\, n^{\text{hits}}_i

        The function updates
        :math:`\mathrm{rew}^{\text{prev}}_i \leftarrow \mathrm{rew}^{\text{shape}}_i`
        and returns :math:`(\mathrm{rew}_i)_{i=1}^N` reshaped to
        ``(env.max_num_agents,)``.
        """
        pos = env.state.pos
        objective = env.env_params["objective"]
        delta = env.system.domain.displacement(pos, objective, env.system)
        d = jnp.linalg.norm(delta, axis=-1)
        on_goal = d < env.state.rad

        rew = env.env_params["prev_rew"] - d * env.env_params["shaping_factor"]
        env.env_params["prev_rew"] = rew

        prox = env.env_params["lidar"]
        R = env.env_params["lidar_range"]
        two_r = 2.05 * env.state.rad[:, None]
        closeness_thresh = jnp.maximum(0.0, R - two_r)
        hits = prox > closeness_thresh
        n_hits = hits.sum(axis=-1).astype(rew.dtype)

        reward = (
            rew
            + env.env_params["final_reward"] * on_goal
            + env.env_params["collision_penalty"] * n_hits
        )
        return reward.reshape(env.max_num_agents)
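
    # Worked reward example with the default coefficients (``alpha = 1.0``,
    # ``R_f = 0.05``, ``C_col = -2.0``, ``r_i = 0.05``, ``R = 0.35``) and
    # hypothetical state values: the "too close" threshold is
    # ``tau = 0.35 - 2.05 * 0.05 = 0.2475``. With ``prev_rew = -0.20`` and
    # ``d = 0.30`` the shaped reward is ``-0.20 - 0.30 = -0.50``; the agent is
    # not on the goal (``d > r_i``), and a single beam with proximity
    # ``0.30 > tau`` adds ``-2.0``, giving a total of ``-2.50``.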

    @staticmethod
    @jax.jit
    def done(env: "Environment") -> jax.Array:
        """
        Returns a boolean indicating whether the environment has ended.

        The episode terminates when the maximum number of steps is reached.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            Boolean array indicating whether the episode has ended.
        """
        return jnp.asarray(env.system.step_count > env.env_params["max_steps"])


__all__ = ["MultiNavigator"]
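

# Minimal rollout sketch: not part of the library API; the hyper-parameters and
# the zero-acceleration "policy" are illustrative placeholders only.
if __name__ == "__main__":
    demo_key = jax.random.PRNGKey(0)
    env = MultiNavigator.Create(N=3)
    env = MultiNavigator.reset(env, demo_key)

    for _ in range(5):
        obs = MultiNavigator.observation(env)  # shape (3, 16) with the defaults
        action = jnp.zeros((3, 2))  # placeholder: zero acceleration for all agents
        env = MultiNavigator.step(env, action)
        rew = MultiNavigator.reward(env)  # shape (3,)
        if bool(MultiNavigator.done(env)):
            break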