Source code for jaxdem.rl.envWrappers
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Contains wrappers for modifying rl environments.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import dataclass, fields
from typing import Callable, Type
from ..environments import Environment
def _wrap_env(
env: "Environment", method_transform: Callable, prefix: str = "Wrapped"
) -> "Environment":
"""
Internal helper to create a new environment subclass with transformed
static methods.
Parameters
----------
env : Environment
The environment instance to wrap.
method_transform : Callable
A function (name: str, func: callable) -> callable
that returns the transformed function for each staticmethod.
Returns
-------
Environment
A new environment instance with transformed static methods.
"""
cls = env.__class__
name_space: dict[str, object] = {}
for name, attr in cls.__dict__.items():
if isinstance(attr, staticmethod) and name not in name_space:
new_func = method_transform(name, attr.__func__)
name_space[name] = staticmethod(new_func)
# Customizable name
NewCls = type(f"{prefix}{cls.__name__}", (cls,), name_space)
NewCls = dataclass(slots=True, frozen=True)(NewCls)
NewCls = jax.tree_util.register_dataclass(NewCls)
# Remember the scalar base class
# Preserve the original base class if already wrapped
base_cls = getattr(cls, "_base_env_cls", cls)
NewCls._base_env_cls = base_cls
field_vals = {f.name: getattr(env, f.name) for f in fields(env)}
return NewCls(**field_vals)
[docs]
def vectorise_env(env: "Environment") -> "Environment":
    """
    Promote an environment instance to a parallel version.

    Every static method of the environment class is replaced by
    ``jax.vmap(method)`` (default ``in_axes=0``), so the methods of the
    returned environment map over a leading batch axis of their arguments.

    Parameters
    ----------
    env : Environment
        The environment instance to vectorise.

    Returns
    -------
    Environment
        A new instance of a ``Vec``-prefixed subclass with the same field
        values as ``env``.
    """

    def _batch(_name, method):
        # The method name is irrelevant: every static method gets batched.
        return jax.vmap(method)

    return _wrap_env(env, _batch, prefix="Vec")
[docs]
def clip_action_env(
    env: "Environment", min_val: float = -1.0, max_val: float = 1.0
) -> "Environment":
    """
    Wrap an environment so that its ``step`` static method clips the action
    to ``[min_val, max_val]`` before calling the original ``step``.

    Parameters
    ----------
    env : Environment
        The environment instance to wrap.
    min_val : float, optional
        Lower clipping bound for actions. Defaults to ``-1.0``.
    max_val : float, optional
        Upper clipping bound for actions. Defaults to ``1.0``.

    Returns
    -------
    Environment
        A new instance of a ``Clipped``-prefixed subclass. Its ``step`` is a
        ``jax.jit``-compiled function that clips with ``jnp.clip`` and then
        delegates. All other static methods are left as they were.
    """

    def transform(name, fn):
        # Only ``step`` receives an action; everything else passes through.
        if name != "step":
            return fn

        def clipped_step(env_obj, action):
            return fn(env_obj, jnp.clip(action, min_val, max_val))

        return jax.jit(clipped_step)

    return _wrap_env(env, transform, prefix="Clipped")
[docs]
def is_wrapped(env: "Environment") -> bool:
    """
    Check whether an environment instance is a wrapped environment.

    Parameters
    ----------
    env : Environment
        The environment instance to check.

    Returns
    -------
    bool
        True if the instance's class has a ``_base_env_cls`` attribute that
        refers to a class other than itself, False otherwise (including when
        the attribute is absent).
    """
    env_cls = env.__class__
    # Per the original note, ``_base_env_cls`` is only a ClassVar annotation
    # on the base class, and an annotation alone does not create the
    # attribute. Fall back to the class itself so unwrapped envs compare equal.
    origin = getattr(env_cls, "_base_env_cls", env_cls)
    return origin is not env_cls
[docs]
def unwrap(env: "Environment") -> "Environment":
    """
    Unwrap an environment to its original base class while preserving all
    current field values.

    An environment counts as wrapped when its class has a ``_base_env_cls``
    attribute referring to a different class. This is the same test
    ``is_wrapped`` performs.

    Parameters
    ----------
    env : Environment
        The (possibly) wrapped environment instance.

    Returns
    -------
    Environment
        ``env`` itself (the same object) if it is not wrapped. Otherwise, a new
        instance of the original base environment class. It is built from the
        same field values, which are passed through unchanged.
    """
    cls = env.__class__
    # Single lookup: the same value decides "is it wrapped?" and "unwrap to what?".
    base_cls: Type["Environment"] = getattr(cls, "_base_env_cls", cls)
    if base_cls is cls:
        return env  # already the base class
    # dataclasses.fields() ignores ClassVar entries, so ``_base_env_cls`` and
    # other class-level variables are never passed to the constructor.
    field_vals = {f.name: getattr(env, f.name) for f in fields(env)}
    return base_cls(**field_vals)
# Public API of this module.
__all__ = [
    "vectorise_env",
    "clip_action_env",
    "is_wrapped",
    "unwrap",
]