Source code for jaxdem.rl.environments.single_roller

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""Environment where a single agent rolls towards a target on the floor."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from dataclasses import dataclass
from functools import partial
from typing import Tuple

from . import Environment
from ...state import State
from ...system import System
from ...utils.linalg import cross, unit


@partial(jax.named_call, name="single_roller.frictional_wall_force")
def frictional_wall_force(
    pos: jax.Array, state: State, system: System
) -> Tuple[jax.Array, jax.Array]:
    """Contact force and torque exerted on each sphere by the floor plane ``y = 0``.

    The plane passes through the origin with upward unit normal ``+y``.

    * Normal part: a linear spring of stiffness ``1e5`` pushing along ``+y``
      in proportion to the overlap ``max(0, rad - y)``.
    * Tangential part: friction of magnitude ``0.4 * |F_n|`` directed against
      the in-plane slip velocity of the contact point (one radius below the
      centre). The magnitude is always the sliding limit; there is no separate
      static-friction regime.
    * Torque: lever arm to the contact point crossed with the friction force.
      The normal force acts through the centre and contributes no torque.

    Parameters
    ----------
    pos : jax.Array
        Particle positions; presumably shape ``(N, 3)`` -- TODO confirm.
    state : State
        Supplies per-particle ``rad``, ``vel`` and ``angVel``.
    system : System
        Unused here; presumably kept so the function matches the signature the
        force manager calls ``force_functions`` with -- TODO confirm.

    Returns
    -------
    Tuple[jax.Array, jax.Array]
        ``(force, torque)`` for each particle.
    """
    stiffness = 1e5
    friction_coeff = 0.4
    normal = jnp.array([0.0, 1.0, 0.0])
    origin = jnp.array([0.0, 0.0, 0.0])

    # Lever arm from each centre to its contact point: one radius along -normal.
    lever = -state.rad[..., None] * normal

    # Signed surface-to-plane gap; only a negative gap (overlap) is clamped
    # through and turned into a repulsive spring force.
    gap = jnp.dot(pos - origin, normal) - state.rad
    f_normal = -stiffness * jnp.minimum(0.0, gap)[..., None] * normal

    # Velocity of the material point touching the floor, with its normal
    # component removed so only the in-plane slip remains.
    v_contact = state.vel + cross(state.angVel, lever)
    v_slip = v_contact - jnp.sum(v_contact * normal, axis=-1, keepdims=True) * normal

    # Friction at the sliding limit mu * |F_n|, opposing the slip.
    # NOTE(review): behaviour at exactly zero slip depends on how `unit`
    # treats zero-length vectors -- confirm it yields zeros rather than NaN.
    f_limit = friction_coeff * jnp.linalg.norm(f_normal, axis=-1, keepdims=True)
    f_friction = -f_limit * unit(v_slip)

    return f_normal + f_friction, cross(lever, f_friction)


@Environment.register("singleRoller3D")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class SingleRoller3D(Environment):
    """Single-agent task: roll a sphere across the floor to a target using torque.

    After ``reset``, one sphere of radius 0.05 has its centre at height
    ``rad`` (touching the plane ``y = 0``) inside a randomly sized box. The
    target point is placed at the same height. Each action is a 3D torque;
    gravity and the floor contact from ``frictional_wall_force`` turn spin
    into motion.

    Per-step reward is ``prev_shaping_factor * d_prev - shaping_factor * d``.
    Here ``d_prev`` is the goal distance recorded just before the last physics
    step and ``d`` is the current one. ``final_reward`` is added whenever
    ``d < goal_threshold * rad``.
    """

    @classmethod
    @partial(jax.named_call, name="SingleRoller3D.Create")
    def Create(
        cls,
        min_box_size: float = 1.0,
        max_box_size: float = 1.0,
        max_steps: int = 6000,
        final_reward: float = 2.0,
        shaping_factor: float = 1.0,
        prev_shaping_factor: float = 1.0,
        goal_threshold: float = 2 / 3,
    ) -> SingleRoller3D:
        """Build the environment with one zero-initialised agent in 3D.

        ``reset`` is what draws the actual box, start position and target.

        Parameters
        ----------
        min_box_size, max_box_size : float
            Bounds of the uniform draw for each box side in ``reset``.
        max_steps : int
            Step limit checked by ``done``.
        final_reward : float
            Bonus added while the agent is within the goal threshold.
        shaping_factor, prev_shaping_factor : float
            Weights on the current and previous goal distance in ``reward``.
        goal_threshold : float
            Goal radius expressed as a multiple of the agent radius.
        """
        n_agents, dim = 1, 3
        state = State.create(pos=jnp.zeros((n_agents, dim)))
        system = System.create(state.shape)
        to_float = partial(jnp.asarray, dtype=float)
        env_params = {
            "objective": jnp.zeros_like(state.pos),
            "min_box_size": to_float(min_box_size),
            "max_box_size": to_float(max_box_size),
            "max_steps": jnp.asarray(max_steps, dtype=int),
            "final_reward": to_float(final_reward),
            "shaping_factor": to_float(shaping_factor),
            "prev_shaping_factor": to_float(prev_shaping_factor),
            "goal_threshold": to_float(goal_threshold),
            # Goal distance cached by `step` before integrating; read by `reward`.
            "prev_rew": jnp.zeros_like(state.rad),
        }
        return cls(state=state, system=system, env_params=env_params)

    @staticmethod
    @partial(jax.jit, donate_argnames=("env",))
    @partial(jax.named_call, name="SingleRoller3D.reset")
    def reset(env: Environment, key: ArrayLike) -> Environment:
        """Draw a fresh box, start position, target and initial velocity.

        * Box sides are uniform in ``[min_box_size, max_box_size]``.
        * Agent and target are uniform in ``[rad, box - rad]`` per axis, with
          ``y`` then fixed to ``rad``.
        * Velocity is uniform in ``[-0.05, 0.05]`` on every axis.
        * The system uses a reflecting domain anchored at ``(0, -4 * rad, 0)``,
          gravity ``-10`` along ``y``, and ``frictional_wall_force``.
        """
        key_box, key_pos, key_objective, key_vel = jax.random.split(key, 4)
        n_agents = env.max_num_agents
        dim = env.state.dim
        radius = 0.05

        box = jax.random.uniform(
            key_box,
            (dim,),
            minval=env.env_params["min_box_size"],
            maxval=env.env_params["max_box_size"],
            dtype=float,
        )

        # Sampling bounds keep each coordinate at least one radius from 0 and box.
        lower = radius * jnp.ones_like(box)
        upper = box - lower

        def _on_floor(subkey: jax.Array) -> jax.Array:
            # Uniform point in [lower, upper], with the centre height pinned
            # to the radius so the sphere touches the plane y = 0.
            point = jax.random.uniform(
                subkey, (n_agents, dim), minval=lower, maxval=upper, dtype=float
            )
            return point.at[:, 1].set(radius)

        pos = _on_floor(key_pos)
        env.env_params["objective"] = _on_floor(key_objective)
        vel = jax.random.uniform(
            key_vel, (n_agents, dim), minval=-0.05, maxval=0.05, dtype=float
        )

        env.state = State.create(pos=pos, vel=vel, rad=radius * jnp.ones(n_agents))
        env.system = System.create(
            env.state.shape,
            domain_type="reflect",
            domain_kw={"box_size": box, "anchor": [0, -4 * radius, 0]},
            force_manager_kw={
                "gravity": [0.0, -10.0, 0.0],
                "force_functions": (frictional_wall_force,),
            },
        )
        env.env_params["prev_rew"] = jnp.zeros_like(env.state.rad)
        return env

    @staticmethod
    @partial(jax.jit, donate_argnames=("env",))
    @partial(jax.named_call, name="SingleRoller3D.step")
    def step(env: Environment, action: jax.Array) -> Environment:
        """Apply the torque action plus damping, cache the goal distance, integrate.

        The action is reshaped to ``(max_num_agents, 3)`` and used as torque,
        with ``-0.05 * angVel`` rotational damping. A ``-0.08 * vel`` drag
        force is also applied.
        """
        state = env.state
        torque = action.reshape(env.max_num_agents, 3) - 0.05 * state.angVel
        drag = -0.08 * state.vel

        system = env.system.force_manager.add_force(state, env.system, drag)
        system = system.force_manager.add_torque(state, system, torque)

        # Record the distance *before* the physics step; `reward` compares the
        # post-step distance against it.
        offset = system.domain.displacement(
            state.pos, env.env_params["objective"], system
        )
        env.env_params["prev_rew"] = jnp.linalg.norm(offset, axis=-1)

        env.state, env.system = system.step(state, system)
        return env

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="SingleRoller3D.observation")
    def observation(env: Environment) -> jax.Array:
        """Concatenate goal displacement, velocity and angular velocity.

        The 9 features are divided by the largest box side. Angular velocity
        is included so the policy can see the spin it is controlling.
        """
        to_goal = env.system.domain.displacement(
            env.state.pos, env.env_params["objective"], env.system
        )
        stacked = jnp.concatenate([to_goal, env.state.vel, env.state.angVel], axis=-1)
        scale = jnp.max(env.system.domain.box_size)
        return stacked / scale

    @staticmethod
    @jax.jit
    @partial(jax.named_call, name="SingleRoller3D.reward")
    def reward(env: Environment) -> jax.Array:
        """Distance-progress shaping plus a bonus while inside the goal radius."""
        params = env.env_params
        offset = env.system.domain.displacement(
            env.state.pos, params["objective"], env.system
        )
        dist = jnp.linalg.norm(offset, axis=-1)
        progress = (
            params["prev_shaping_factor"] * params["prev_rew"]
            - params["shaping_factor"] * dist
        )
        reached = dist < params["goal_threshold"] * env.state.rad
        total = progress + params["final_reward"] * reached
        return total.reshape(env.max_num_agents)

    @staticmethod
    @partial(jax.jit, inline=True)
    @partial(jax.named_call, name="SingleRoller3D.done")
    def done(env: Environment) -> jax.Array:
        """True once the system step counter exceeds ``max_steps``."""
        step_limit = env.env_params["max_steps"]
        return jnp.asarray(env.system.step_count > step_limit)

    @property
    def action_space_size(self) -> int:
        """Number of action components: one 3D torque vector."""
        return 3

    @property
    def action_space_shape(self) -> Tuple[int]:
        """Shape of a single agent's action (a 3D torque)."""
        return (3,)

    @property
    def observation_space_size(self) -> int:
        """Displacement, velocity and angular velocity, each of length ``dim``."""
        return self.state.dim * 3
# Public names exported by ``from ... import *`` on this module.
__all__ = ["SingleRoller3D"]