# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Multi-agent navigation task with collision penalties."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from dataclasses import dataclass, field
from functools import partial
from typing import Tuple

from . import Environment
from ...state import State
from ...system import System
from ...utils import lidar
from ...materials import MaterialTable, Material
from ...material_matchmakers import MaterialMatchmaker


@partial(jax.jit, static_argnames=("N",))
@partial(jax.named_call, name="multi_navigator._sample_objectives")
def _sample_objectives(key: ArrayLike, N: int, box: jax.Array, rad: float) -> jax.Array:
    """Sample ``N`` points on a regular grid covering ``box``, jittered
    uniformly within each cell while keeping ``rad`` clearance from the
    cell boundaries."""
    i = jax.lax.iota(int, N)  # agent indices 0..N-1
    Lx, Ly = box.astype(float)

    nx = jnp.ceil(jnp.sqrt(N * Lx / Ly)).astype(int)
    ny = jnp.ceil(N / nx).astype(int)

    ix = jnp.mod(i, nx)
    iy = i // nx

    dx = Lx / nx
    dy = Ly / ny

    xs = (ix + 0.5) * dx
    ys = (iy + 0.5) * dy
    base = jnp.stack([xs, ys], axis=1)

    noise = jax.random.uniform(key, (N, 2), minval=-1.0, maxval=1.0) * jnp.asarray(
        [jnp.maximum(0.0, dx / 2 - rad), jnp.maximum(0.0, dy / 2 - rad)]
    )
    return base + noise

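# Illustrative sketch (comments only, not executed at import): with a unit
# box and the 0.05 agent radius used in ``reset`` below, 16 samples land on
# a jittered 4x4 grid, one point per cell, each at least ``rad`` away from
# its cell boundary:
#
#     pts = _sample_objectives(jax.random.PRNGKey(0), 16, jnp.ones(2), 0.05)
#     pts.shape  # -> (16, 2)
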

@Environment.register("multiNavigator")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class MultiNavigator(Environment):
    """
    Multi-agent navigation environment with collision penalties.

    Agents seek fixed objectives in a 2D reflective box. Each step applies a
    force-like action, advances simple dynamics, updates LiDAR, and returns
    shaped rewards with an optional final bonus on goal.
    """

    n_lidar_rays: int = field(metadata={"static": True})
    """
    Number of LiDAR rays for the vision system.
    """

    @classmethod
    @partial(jax.named_call, name="MultiNavigator.Create")
    def Create(
        cls,
        N: int = 64,
        min_box_size: float = 1.0,
        max_box_size: float = 1.0,
        box_padding: float = 5.0,
        max_steps: int = 5760,
        final_reward: float = 1.0,
        shaping_factor: float = 0.005,
        prev_shaping_factor: float = 0.0,
        global_shaping_factor: float = 0.0,
        collision_penalty: float = -0.005,
        goal_threshold: float = 2 / 3,
        lidar_range: float = 0.45,
        n_lidar_rays: int = 16,
    ) -> "MultiNavigator":
        """
        Construct a :class:`MultiNavigator` with ``N`` agents and the given
        task parameters. Array-valued settings are stored in ``env_params``.
        """
        dim = 2
        state = State.create(pos=jnp.zeros((N, dim)))
        system = System.create(state.shape)
        env_params = dict(
            objective=jnp.zeros_like(state.pos),
            min_box_size=jnp.asarray(min_box_size, dtype=float),
            max_box_size=jnp.asarray(max_box_size, dtype=float),
            box_padding=jnp.asarray(box_padding, dtype=float),
            max_steps=jnp.asarray(max_steps, dtype=int),
            final_reward=jnp.asarray(final_reward, dtype=float),
            collision_penalty=jnp.asarray(collision_penalty, dtype=float),
            shaping_factor=jnp.asarray(shaping_factor, dtype=float),
            prev_shaping_factor=jnp.asarray(prev_shaping_factor, dtype=float),
            global_shaping_factor=jnp.asarray(global_shaping_factor, dtype=float),
            goal_threshold=jnp.asarray(goal_threshold, dtype=float),
            prev_rew=jnp.zeros_like(state.rad),
            lidar_range=jnp.asarray(lidar_range, dtype=float),
            lidar=jnp.zeros((state.N, int(n_lidar_rays)), dtype=float),
            goal_scale=jnp.asarray(1.0, dtype=float),
            objective_index=jnp.zeros_like(state.rad, dtype=int),
        )
        return cls(
            state=state,
            system=system,
            env_params=env_params,
            n_lidar_rays=int(n_lidar_rays),
        )

    @staticmethod
    @partial(jax.jit, donate_argnames=("env",))
    @partial(jax.named_call, name="MultiNavigator.reset")
    def reset(env: "Environment", key: ArrayLike) -> "Environment":
        """
        Initialize the environment with randomly placed particles and velocities.

        Parameters
        ----------
        env : Environment
            Current environment instance.
        key : jax.random.PRNGKey
            JAX random number generator key.

        Returns
        -------
        Environment
            Freshly initialized environment.
        """
        # Derive an independent subkey for each randomized quantity.
        root = key
        key_box = jax.random.fold_in(root, jnp.uint32(0))
        key_pos = jax.random.fold_in(root, jnp.uint32(1))
        key_objective = jax.random.fold_in(root, jnp.uint32(2))
        key_shuffle = jax.random.fold_in(root, jnp.uint32(3))
        key_vel = jax.random.fold_in(root, jnp.uint32(4))

        N = env.max_num_agents
        dim = env.state.dim
        box = jax.random.uniform(
            key_box,
            (dim,),
            minval=env.env_params["min_box_size"],
            maxval=env.env_params["max_box_size"],
            dtype=float,
        )
        rad = 0.05

        # Start positions on a padded, jittered grid so agents do not overlap.
        pos = (
            _sample_objectives(
                key_pos, int(N), box + env.env_params["box_padding"] * rad, rad
            )
            - env.env_params["box_padding"] * rad / 2
        )
        objective = _sample_objectives(key_objective, int(N), box, rad)
        env.env_params["goal_scale"] = jnp.max(box)

        # Randomly permute which objective each agent is assigned.
        base_idx = jnp.arange(N, dtype=int)
        perm = jax.random.permutation(key_shuffle, base_idx)
        env.env_params["objective"] = objective[perm]
        env.env_params["objective_index"] = perm

        vel = jax.random.uniform(
            key_vel, (N, dim), minval=-0.1, maxval=0.1, dtype=float
        )
        Rad = rad * jnp.ones(N)
        env.state = State.create(pos=pos, vel=vel, rad=Rad)

        matcher = MaterialMatchmaker.create("harmonic")
        mat_table = MaterialTable.from_materials(
            [Material.create("elastic", density=0.27, young=6e3, poisson=0.3)],
            matcher=matcher,
        )
        env.system = System.create(
            env.state.shape,
            dt=0.004,
            domain_type="reflect",
            domain_kw=dict(
                box_size=box + env.env_params["box_padding"] * rad,
                anchor=jnp.zeros_like(box) - env.env_params["box_padding"] * rad / 2,
            ),
            mat_table=mat_table,
        )

        # Cache the initial normalized goal distance for the shaping term.
        delta = env.system.domain.displacement(
            env.state.pos, env.env_params["objective"], env.system
        )
        d = jnp.vecdot(delta, delta)
        env.env_params["prev_rew"] = jnp.sqrt(d) / env.env_params["goal_scale"]
        env.env_params["lidar"] = lidar(env)
        return env

    @staticmethod
    @partial(jax.jit, donate_argnames=("env",))
    @partial(jax.named_call, name="MultiNavigator.step")
    def step(env: "Environment", action: jax.Array) -> "Environment":
        """
        Advance one step. Actions are forces; a simple sign-based drag is
        applied.

        Parameters
        ----------
        env : Environment
            The current environment.
        action : jax.Array
            The vector of actions each agent in the environment should take.

        Returns
        -------
        Environment
            The updated environment state.
        """
        force = (
            action.reshape(env.max_num_agents, *env.action_space_shape)
            - jnp.sign(env.state.vel) * 0.08
        )
        env.system = env.system.force_manager.add_force(env.state, env.system, force)

        # Record the current normalized goal distance so the next reward call
        # can use it as the previous-step shaping term.
        delta = env.system.domain.displacement(
            env.state.pos, env.env_params["objective"], env.system
        )
        d = jnp.vecdot(delta, delta)
        env.env_params["prev_rew"] = jnp.sqrt(d) / env.env_params["goal_scale"]

        env.state, env.system = env.system.step(env.state, env.system)
        env.env_params["lidar"] = lidar(env)
        return env

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="MultiNavigator.observation")
    def observation(env: "Environment") -> jax.Array:
        """
        Build per-agent observations.

        Contents per agent
        ------------------
        - Wrapped displacement to objective ``Δx`` (shape ``(2,)``),
          normalized by the goal scale (the maximum box size).
        - Velocity ``v`` (shape ``(2,)``).
        - LiDAR proximities (shape ``(n_lidar_rays,)``), normalized by the
          LiDAR range.

        Returns
        -------
        jax.Array
            Array of shape ``(N, 2 * dim + n_lidar_rays)``.
        """
        return jnp.concatenate(
            [
                env.system.domain.displacement(
                    env.env_params["objective"], env.state.pos, env.system
                )
                / env.env_params["goal_scale"],
                env.state.vel,
                env.env_params["lidar"] / env.env_params["lidar_range"],
            ],
            axis=-1,
        )

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="MultiNavigator.reward")
    def reward(env: "Environment") -> jax.Array:
        r"""
        Per-agent reward with distance shaping, goal bonus, LiDAR collision
        penalty, and a global shaping term.

        **Equations**

        Let :math:`\delta_i=\operatorname{displacement}(\mathbf{x}_i,\mathbf{objective})`
        and :math:`d_i=\lVert\delta_i\rVert_2 / s`, where :math:`s` is the
        goal scale (the maximum box size) and :math:`\mathbf{1}[\cdot]` is the
        indicator. With shaping factors :math:`\alpha_{\text{prev}},\alpha`,
        final reward :math:`R_f`, collision penalty :math:`C`, global shaping
        factor :math:`\beta`, and radius :math:`r_i`, let :math:`\ell_{i,k}`
        be the LiDAR proximity for agent :math:`i` and ray :math:`k`, and let

        .. math::
            h_i = \sum_k \mathbf{1}\left[\ell_{i,k} >
            \max\left(0,\ \text{lidar\_range} - 2r_i\right)\right]

        be the collision count. The reward consists of a shaping term,

        .. math::
            \mathrm{rew}^{\text{shape}}_i
            = \alpha_{\text{prev}}\,d^{\text{prev}}_i - \alpha\,d_i,

        where :math:`d^{\text{prev}}_i` is the distance recorded by
        :meth:`step`, combined with the goal bonus and collision penalty,

        .. math::
            \mathrm{rew}_i = \mathrm{rew}^{\text{shape}}_i
            + R_f\,\mathbf{1}\left[d_i <
            \text{goal\_threshold}\times r_i\right]
            + C\,h_i,

        and finally shifted by the global shaping term

        .. math::
            \mathrm{rew}_i \leftarrow \mathrm{rew}_i
            + \beta\,\overline{\mathrm{rew}},
            \qquad
            \overline{\mathrm{rew}} = \tfrac{1}{N}\sum_j \mathrm{rew}_j.

        Parameters
        ----------
        env : Environment
            Current environment.

        Returns
        -------
        jax.Array
            Shape ``(N,)``. The per-agent reward vector.
        """
        delta = env.system.domain.displacement(
            env.state.pos, env.env_params["objective"], env.system
        )
        d = jnp.vecdot(delta, delta)
        d = jnp.sqrt(d) / env.env_params["goal_scale"]

        rew = (
            env.env_params["prev_shaping_factor"] * env.env_params["prev_rew"]
            - env.env_params["shaping_factor"] * d
        )

        closeness_thresh = jnp.maximum(
            0.0, env.env_params["lidar_range"] - 2.0 * env.state.rad[:, None]
        )
        n_hits = (
            (env.env_params["lidar"] > closeness_thresh).sum(axis=-1).astype(rew.dtype)
        )

        on_goal = d < env.env_params["goal_threshold"] * env.state.rad
        reward = (
            rew
            + env.env_params["final_reward"] * on_goal
            + env.env_params["collision_penalty"] * n_hits
        )
        reward += jnp.mean(reward) * env.env_params["global_shaping_factor"]
        return reward.reshape(env.max_num_agents)
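
    # Worked example with illustrative numbers (not tuned defaults): with
    # shaping_factor = 0.005, prev_shaping_factor = 0, normalized distance
    # d_i = 0.5, one LiDAR hit, collision_penalty = -0.005, and the agent not
    # on goal, the pre-global reward is
    #     rew_i = -0.005 * 0.5 + (-0.005) * 1 = -0.0075,
    # after which beta * mean(rew) is added uniformly across agents.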

    @staticmethod
    @partial(jax.jit, inline=True)
    @partial(jax.named_call, name="MultiNavigator.done")
    def done(env: "Environment") -> jax.Array:
        """
        Return a boolean indicating whether the episode has ended.

        The episode terminates when the maximum number of steps is reached.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            Boolean array indicating whether the episode has ended.
        """
        return jnp.asarray(env.system.step_count > env.env_params["max_steps"])

    @property
    def action_space_size(self) -> int:
        """
        Flattened action size per agent. Actions passed to :meth:`step` have
        shape ``(A, action_space_size)``.
        """
        return self.state.dim

    @property
    def action_space_shape(self) -> Tuple[int]:
        """
        Original per-agent action shape (useful for reshaping inside the
        environment).
        """
        return (self.state.dim,)

    @property
    def observation_space_size(self) -> int:
        """
        Flattened observation size per agent. :meth:`observation` returns
        shape ``(A, observation_space_size)``.
        """
        return 2 * self.state.dim + self.n_lidar_rays

__all__ = ["MultiNavigator"]
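

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API: build a small
    # environment, reset it, take one zero-force step, and inspect shapes.
    # Assumes the surrounding jaxdem package is installed and importable.
    demo_env = MultiNavigator.Create(N=4, n_lidar_rays=8)
    demo_env = MultiNavigator.reset(demo_env, jax.random.PRNGKey(0))
    obs = MultiNavigator.observation(demo_env)  # shape (4, 2 * dim + 8)
    zero_action = jnp.zeros((4, demo_env.action_space_size))
    demo_env = MultiNavigator.step(demo_env, zero_action)
    rew = MultiNavigator.reward(demo_env)  # shape (4,)
    print(obs.shape, rew.shape, bool(MultiNavigator.done(demo_env)))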