Source code for jaxdem.utils.environment

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions to handle environments.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp

from typing import TYPE_CHECKING, Callable, Tuple, Any
from functools import partial

if TYPE_CHECKING:
    from ..rl.environments import Environment


[docs] @partial(jax.jit, static_argnames=("model", "n", "stride")) @partial(jax.named_call, name="utils.env_trajectory_rollout") def env_trajectory_rollout( env: "Environment", model: Callable, key: jax.Array, *, n: int, stride: int = 1, **kw: Any, ) -> Tuple["Environment", "Environment"]: """ Roll out a trajectory by applying `model` in chunks of `stride` steps and collecting the environment after each chunk. Parameters ---------- env : Environment Initial environment pytree. model : Callable Callable with signature `model(obs, key, **kw) -> action`. n : int Number of chunks to roll out. Total internal steps = `n * stride`. stride : int Steps per chunk between recorded snapshots. **kw : Any Extra keyword arguments passed to `model` on every step. Returns ------- Environment Environment after `n * stride` steps. Environment Stacked pytree of environments with length `n`, each snapshot taken after a chunk of `stride` steps. Examples -------- >>> env, traj = env_trajectory_rollout(env, model, n=100, stride=5, objective=goal) """ def body(carry, _): env, key = carry key, subkey = jax.random.split(key) env = env_step(env, model, subkey, n=stride, **kw) return (env, key), env (env, key), env_traj = jax.lax.scan(body, (env, key), length=n, xs=None) return env, env_traj
[docs] @partial(jax.jit, static_argnames=("model", "n")) @partial(jax.named_call, name="utils.env_step") def env_step( env: "Environment", model: Callable, key: jax.Array, *, n: int = 1, **kw: Any ) -> "Environment": """ Advance the environment `n` steps using actions from `model`. Parameters ---------- env : Environment Initial environment pytree (batchable). model : Callable Callable with signature `model(obs, key, **kw) -> action`. n : int Number of steps to perform. **kw : Any Extra keyword arguments forwarded to `model`. Returns ------- Environment Environment after `n` steps. Examples -------- >>> env = env_step(env, model, n=10, objective=goal) """ def body(carry, _): env, key = carry key, subkey = jax.random.split(key) env = _env_step(env, model, subkey, **kw) return (env, key), None (env, key), _ = jax.lax.scan(body, (env, key), length=n, xs=None) return env
@partial(jax.jit, static_argnames=("model",)) @partial(jax.named_call, name="utils._env_step") def _env_step( env: "Environment", model: Callable, key: jax.Array, **kw: Any ) -> "Environment": """ Single environment step driven by `model`. Parameters ---------- env : Environment Current environment pytree. model : Callable Callable with signature `model(obs, key, **kw) -> action`. **kw : Any Extra keyword arguments passed to `model`. Returns ------- Environment Updated environment after applying `env.step(env, action)`. """ obs = env.observation(env) action = model(obs, key, **kw) env = env.step(env, action) return env
[docs] @jax.jit @partial(jax.named_call, name="utils.lidar") def lidar(env: "Environment") -> jax.Array: nbins = env.n_lidar_rays indices = jax.lax.iota(int, env.max_num_agents) def lidar_for_i(i: jax.Array) -> jax.Array: rij = jax.vmap( lambda j: env.system.domain.displacement( env.state.pos[i], env.state.pos[j], env.system ) )(indices) r = jnp.vecdot(rij, rij) r = jnp.sqrt(r) r = r.at[i].set(jnp.inf) theta = jnp.arctan2(rij[..., 1], rij[..., 0]) bins = jnp.floor((theta + jnp.pi) * (nbins / (2.0 * jnp.pi))).astype(int) d_in = jnp.where(r < env.env_params["lidar_range"], r, jnp.inf) d_bins = ( jnp.full((nbins,), jnp.inf, dtype=env.state.pos.dtype).at[bins].min(d_in) ) proximity = jnp.where( jnp.isfinite(d_bins), jnp.maximum(0.0, env.env_params["lidar_range"] - d_bins), 0.0, ) return proximity return jax.vmap(lidar_for_i)(indices)
__all__ = ["env_trajectory_rollout", "env_step", "lidar"]