# Source code for jaxdem.rl.environments

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""Reinforcement-learning environment interface."""

from __future__ import annotations

import jax
from jax.typing import ArrayLike

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type

from ...factory import Factory

if TYPE_CHECKING:  # pragma: no cover
    from ...state import State
    from ...system import System


@jax.tree_util.register_dataclass
@dataclass(slots=True, frozen=True)
class Environment(Factory, ABC):
    """
    Defines the interface for reinforcement-learning environments.

    - Let **A** be the number of agents (A ≥ 1). Single-agent environments
      still use A=1.
    - Observations and actions are flattened per agent to fixed sizes. Use
      ``action_space_shape`` to reshape inside the environment if needed.

    **Required shapes**

    - Observation: ``(A, observation_space_size)``
    - Action (input to :meth:`step`): ``(A, action_space_size)``
    - Reward: ``(A,)``
    - Done: scalar boolean for the whole environment

    TODO:
        - Truncated data field: per-agent termination flag
        - Render method

    Example
    -------
    To define a custom environment, inherit from :class:`Environment` and
    implement the abstract methods:

    >>> @Environment.register("MyCustomEnv")
    ... @jax.tree_util.register_dataclass
    ... @dataclass(slots=True, frozen=True)
    ... class MyCustomEnv(Environment):
    ...     ...
    """

    state: "State"
    """
    Simulation state.
    """

    system: "System"
    """
    Simulation system configuration.
    """

    env_params: Dict[str, Any]
    """
    Environment-specific parameters.
    """

    max_num_agents: int = field(default=0, metadata={"static": True})
    """
    Maximum number of active agents in the environment.
    """

    action_space_size: int = field(default=0, metadata={"static": True})
    """
    Flattened action size per agent. Actions passed to :meth:`step` have shape
    ``(A, action_space_size)``.
    """

    action_space_shape: Tuple[int, ...] = field(default=(), metadata={"static": True})
    """
    Original per-agent action shape (useful for reshaping inside the environment).
    """

    observation_space_size: int = field(default=0, metadata={"static": True})
    """
    Flattened observation size per agent. :meth:`observation` returns shape
    ``(A, observation_space_size)``.
    """

    # Concrete environment class whose ``reset`` is used by ``reset_if_done``.
    # NOTE(review): the assignment site is not visible in this file —
    # presumably set by the Factory registration machinery; confirm there.
    _base_env_cls: ClassVar[Type["Environment"]]

    @staticmethod
    @abstractmethod
    @jax.jit
    def reset(env: "Environment", key: ArrayLike) -> "Environment":
        """
        Initialize the environment to a valid start state.

        Parameters
        ----------
        env: Environment
            Instance of the environment.
        key : jax.random.PRNGKey
            JAX random number generator key.

        Returns
        -------
        Environment
            Freshly initialized environment.
        """
        raise NotImplementedError

    @staticmethod
    @jax.jit
    def reset_if_done(
        env: "Environment", done: jax.Array, key: ArrayLike
    ) -> "Environment":
        """
        Conditionally resets the environment if the environment has reached a
        terminal state.

        This method checks the `done` flag and, if `True`, calls the
        environment's `reset` method to reinitialize the state. Otherwise, it
        returns the current environment unchanged.

        Parameters
        ----------
        env : Environment
            The current environment instance.
        done : jax.Array
            A boolean flag indicating whether the environment has reached a
            terminal state.
        key : jax.random.PRNGKey
            JAX random number generator key used for reinitialization.

        Returns
        -------
        Environment
            Either the freshly reset environment (if `done` is True) or the
            unchanged environment (if `done` is False).
        """
        # Fall back to the instance's own class when the registration
        # machinery has not recorded a base environment class.
        base_cls = getattr(env.__class__, "_base_env_cls", env.__class__)
        # Both branches of lax.cond must return the same pytree structure;
        # `reset` must therefore produce an environment shaped like `env`.
        return jax.lax.cond(
            done,
            lambda _: base_cls.reset(env, key),
            lambda _: env,
            operand=None,
        )

    @staticmethod
    @abstractmethod
    @jax.jit
    def step(env: "Environment", action: jax.Array) -> "Environment":
        """
        Advance the simulation by one step using **per-agent** actions.

        Parameters
        ----------
        env : Environment
            The current environment.
        action : jax.Array
            The vector of actions each agent in the environment should take.

        Returns
        -------
        Environment
            The updated environment state.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def observation(env: "Environment") -> jax.Array:
        """
        Returns the per-agent observation vector.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            Vector corresponding to the environment observation.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def reward(env: "Environment") -> jax.Array:
        """
        Returns the per-agent immediate rewards.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            Vector corresponding to all the agent's rewards based on the
            current environment state.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    @jax.jit
    def done(env: "Environment") -> jax.Array:
        """
        Returns a boolean indicating whether the environment has ended.

        Parameters
        ----------
        env : Environment
            The current environment.

        Returns
        -------
        jax.Array
            A bool indicating when the environment ended
        """
        raise NotImplementedError

    @staticmethod
    @jax.jit
    def info(env: "Environment") -> Dict[str, Any]:
        """
        Return auxiliary diagnostic information.

        By default, returns an empty dict. Subclasses may override to provide
        environment specific information.

        Parameters
        ----------
        env : Environment
            The current state of the environment.

        Returns
        -------
        Dict
            A dictionary with additional information about the environment.
        """
        return dict()
# Concrete environments re-exported as part of the package's public API.
from .multi_navigator import MultiNavigator
from .single_navigator import SingleNavigator

__all__ = ["Environment", "MultiNavigator", "SingleNavigator"]