Source code for jaxdem.rl.environments.multi_roller

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""Environment where multiple agents roll towards targets on a 3D floor."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from dataclasses import dataclass, field
from functools import partial
from typing import Tuple

from . import Environment
from ...state import State
from ...system import System
from ...utils import lidar
from ...utils.linalg import cross, unit
from ...materials import MaterialTable, Material
from ...material_matchmakers import MaterialMatchmaker


@partial(jax.jit, static_argnames=("N",))
@partial(jax.named_call, name="multi_roller._sample_objectives_3d")
def _sample_objectives_3d(
    key: ArrayLike, N: int, box: jax.Array, rad: float
) -> jax.Array:
    """Samples points on the X-Z plane (floor) for 3D positioning."""
    i = jax.lax.iota(int, N)
    # Use X and Z dimensions for the floor grid
    Lx, Lz = box[0], box[2]

    nx = jnp.ceil(jnp.sqrt(N * Lx / Lz)).astype(int)
    nz = jnp.ceil(N / nx).astype(int)

    ix = jnp.mod(i, nx)
    iz = i // nx

    dx = Lx / nx
    dz = Lz / nz

    xs = (ix + 0.5) * dx
    zs = (iz + 0.5) * dz

    # Y is fixed at rad (on the floor)
    ys = jnp.full_like(xs, rad)

    base = jnp.stack([xs, ys, zs], axis=1)

    noise = jax.random.uniform(key, (N, 3), minval=-1.0, maxval=1.0)
    noise_scale = jnp.asarray(
        [
            jnp.maximum(0.0, dx / 2 - rad),
            0.0,  # No noise in Y
            jnp.maximum(0.0, dz / 2 - rad),
        ]
    )

    return base + noise * noise_scale

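# Illustrative sketch (not part of the module's public API): sampling four
# objectives on the floor of a unit box with radius 0.05 yields one jittered
# point per grid cell, with Y pinned to the sphere radius:
#
#     key = jax.random.PRNGKey(0)
#     pts = _sample_objectives_3d(key, 4, jnp.array([1.0, 1.0, 1.0]), 0.05)
#     pts.shape   # (4, 3)
#     pts[:, 1]   # all equal to 0.05 (spheres resting on the floor)
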

@partial(jax.named_call, name="multi_roller.frictional_wall_force")
def frictional_wall_force(
    pos: jax.Array, state: State, system: System
) -> Tuple[jax.Array, jax.Array]:
    """Calculates normal and frictional forces for spheres on a y=0 plane."""
    k = 1e5  # Normal stiffness
    mu = 0.4  # Friction coefficient
    n = jnp.array([0.0, 1.0, 0.0])
    p = jnp.array([0.0, 0.0, 0.0])

    # 1. Normal Force Calculation
    dist = jnp.dot(pos - p, n) - state.rad
    penetration = jnp.minimum(0.0, dist)
    force_n = -k * penetration[..., None] * n

    # 2. Velocity at the contact point
    # radius_vector points from center to contact point: -rad * n
    radius_vec = -state.rad[..., None] * n
    v_at_contact = state.vel + cross(state.angVel, radius_vec)

    # Tangential velocity component
    v_n = jnp.sum(v_at_contact * n, axis=-1, keepdims=True) * n
    v_t = v_at_contact - v_n

    # 3. Friction Force (Coulomb approximation)
    f_t_mag = mu * jnp.linalg.norm(force_n, axis=-1, keepdims=True)
    t_dir = unit(v_t)
    force_t = -f_t_mag * t_dir

    # 4. Total Force and Torque
    total_force = force_n + force_t
    total_torque = cross(radius_vec, force_t)

    return total_force, total_torque

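# The (pos, state, system) -> (force, torque) signature matches the custom
# force hook passed via ``force_functions`` in ``MultiRoller.reset`` below,
# so the same callable can also be evaluated directly, e.g.
#
#     f, tau = frictional_wall_force(state.pos, state, system)  # both (N, 3)
#
# Worked numbers (illustrative, not from a real run): with k = 1e5 and a
# penetration of 1e-4, |F_n| = k * 1e-4 = 10, and the Coulomb friction
# magnitude is mu * |F_n| = 4, directed against the tangential velocity at
# the contact point. Both contributions vanish when the sphere does not
# touch the y = 0 plane.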

[docs] @Environment.register("multiRoller") @jax.tree_util.register_dataclass @dataclass(slots=True) class MultiRoller(Environment): """ Multi-agent 3D rolling environment. Agents are spheres that roll on a floor. They are controlled via 3D torque vectors. Includes collision handling, LiDAR sensing, and distance-based reward shaping. """ n_lidar_rays: int = field(metadata={"static": True})
[docs] @classmethod @partial(jax.named_call, name="MultiRoller.Create") def Create( cls, N: int = 64, min_box_size: float = 1.0, max_box_size: float = 1.0, box_padding: float = 5.0, max_steps: int = 5760, final_reward: float = 1.0, # 1.0 shaping_factor: float = 0.005, prev_shaping_factor: float = 0.0, global_shaping_factor: float = 0.0, collision_penalty: float = -0.005, goal_threshold: float = 2 / 3, lidar_range: float = 0.45, n_lidar_rays: int = 16, ) -> MultiRoller: dim = 3 state = State.create(pos=jnp.zeros((N, dim))) system = System.create(state.shape) env_params = dict( objective=jnp.zeros_like(state.pos), min_box_size=jnp.asarray(min_box_size, dtype=float), max_box_size=jnp.asarray(max_box_size, dtype=float), box_padding=jnp.asarray(box_padding, dtype=float), max_steps=jnp.asarray(max_steps, dtype=int), final_reward=jnp.asarray(final_reward, dtype=float), collision_penalty=jnp.asarray(collision_penalty, dtype=float), shaping_factor=jnp.asarray(shaping_factor, dtype=float), prev_shaping_factor=jnp.asarray(prev_shaping_factor, dtype=float), goal_threshold=jnp.asarray(goal_threshold, dtype=float), prev_rew=jnp.zeros_like(state.rad), lidar_range=jnp.asarray(lidar_range, dtype=float), lidar=jnp.zeros((state.N, int(n_lidar_rays)), dtype=float), goal_scale=jnp.asarray(1.0, dtype=float), ) return cls( state=state, system=system, env_params=env_params, n_lidar_rays=int(n_lidar_rays), )
[docs] @staticmethod @partial(jax.jit, donate_argnames=("env",)) @partial(jax.named_call, name="MultiRoller.reset") def reset(env: Environment, key: ArrayLike) -> Environment: """ Initialize the environment with randomly placed particles. Parameters ---------- env: Environment Current environment instance. key : jax.random.PRNGKey JAX random number generator key. Returns ------- Environment Freshly initialized environment. """ key_box, key_pos, key_objective, key_shuffle, key_vel = jax.random.split(key, 5) N = env.max_num_agents dim = 3 box = jax.random.uniform( key_box, (dim,), minval=env.env_params["min_box_size"], maxval=env.env_params["max_box_size"], dtype=float, ) rad_val = 0.05 padding = env.env_params["box_padding"] * rad_val pos = ( _sample_objectives_3d(key_pos, int(N), box + padding, rad_val) - padding / 2 ) pos = pos.at[:, 1].set(rad_val) objective = _sample_objectives_3d(key_objective, int(N), box, rad_val) objective = objective.at[:, 1].set(rad_val) env.env_params["goal_scale"] = jnp.max(box) perm = jax.random.permutation(key_shuffle, jnp.arange(N)) env.env_params["objective"] = objective[perm] env.env_params["objective_index"] = perm rads = rad_val * jnp.ones(N) env.state = State.create(pos=pos, rad=rads) matcher = MaterialMatchmaker.create("harmonic") mat_table = MaterialTable.from_materials( [Material.create("elastic", density=0.27, young=6e3, poisson=0.3)], matcher=matcher, ) env.system = System.create( env.state.shape, dt=0.004, domain_type="reflect", domain_kw=dict( box_size=box + padding, anchor=jnp.zeros_like(box) - padding / 2, ), force_manager_kw=dict( gravity=[0.0, -10.0, 0.0], force_functions=(frictional_wall_force,), ), mat_table=mat_table, ) delta = env.system.domain.displacement( env.state.pos, env.env_params["objective"], env.system ) env.env_params["prev_rew"] = jnp.linalg.norm(delta, axis=-1) env.env_params["lidar"] = lidar(env) return env
[docs] @staticmethod @partial(jax.jit, donate_argnames=("env",)) @partial(jax.named_call, name="MultiRoller.step") def step(env: Environment, action: jax.Array) -> Environment: torque = action.reshape(env.max_num_agents, 3) force_drag = -0.08 * env.state.vel torque_drag = -0.05 * env.state.angVel env.system = env.system.force_manager.add_force( env.state, env.system, force_drag ) env.system = env.system.force_manager.add_torque( env.state, env.system, torque + torque_drag ) delta = env.system.domain.displacement( env.state.pos, env.env_params["objective"], env.system ) env.env_params["prev_rew"] = jnp.linalg.norm(delta, axis=-1) env.state, env.system = env.system.step(env.state, env.system) env.env_params["lidar"] = lidar(env) return env
[docs] @staticmethod @jax.jit @partial(jax.named_call, name="MultiRoller.observation") def observation(env: Environment) -> jax.Array: disp = ( env.system.domain.displacement( env.env_params["objective"], env.state.pos, env.system ) / env.env_params["goal_scale"] ) return jnp.concatenate( [ disp, env.state.vel, env.state.angVel, env.env_params["lidar"] / env.env_params["lidar_range"], ], axis=-1, )
[docs] @staticmethod @jax.jit @partial(jax.named_call, name="MultiRoller.reward") def reward(env: Environment) -> jax.Array: delta = env.system.domain.displacement( env.state.pos, env.env_params["objective"], env.system ) d = jnp.linalg.norm(delta, axis=-1) rew_shape = ( env.env_params["prev_shaping_factor"] * env.env_params["prev_rew"] - env.env_params["shaping_factor"] * d ) closeness_thresh = jnp.maximum( 0.0, env.env_params["lidar_range"] - 2.0 * env.state.rad[:, None] ) n_hits = (env.env_params["lidar"] > closeness_thresh).sum(axis=-1).astype(float) on_goal = d < env.env_params["goal_threshold"] * env.state.rad reward = ( rew_shape + env.env_params["final_reward"] * on_goal + env.env_params["collision_penalty"] * n_hits ) return reward.reshape(env.max_num_agents)

    @staticmethod
    @partial(jax.jit, inline=True)
    @partial(jax.named_call, name="MultiRoller.done")
    def done(env: Environment) -> jax.Array:
        return jnp.asarray(env.system.step_count > env.env_params["max_steps"])
@property def action_space_size(self) -> int: return 3 # 3D Torque @property def action_space_shape(self) -> Tuple[int]: return (3,) @property def observation_space_size(self) -> int: # Disp(3) + Vel(3) + AngVel(3) + Lidar(rays) return 9 + self.n_lidar_rays
__all__ = ["MultiRoller"]