# Source code for jaxdem.rl.envWrapper
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Contains wrappers for modifying rl environments.
"""
import jax
import jax.numpy as jnp
from dataclasses import dataclass, fields
from typing import Callable
from .environment import Environment
def _wrap_env(env: "Environment", method_transform: Callable) -> "Environment":
"""
Internal helper to create a new environment subclass with transformed
static methods.
Parameters
----------
env : Environment
The environment instance to wrap.
method_transform : Callable
A function (name: str, func: callable) -> callable
that returns the transformed function for each staticmethod.
Returns
-------
Environment
A new environment instance with transformed static methods.
"""
cls = env.__class__
name_space: dict[str, object] = {}
for name, attr in cls.__dict__.items():
if isinstance(attr, staticmethod):
new_func = method_transform(name, attr.__func__)
name_space[name] = staticmethod(new_func)
NewCls = type(f"Wrapped{cls.__name__}", (cls,), name_space)
NewCls = dataclass(slots=True)(NewCls)
NewCls = jax.tree_util.register_dataclass(NewCls)
field_vals = {f.name: getattr(env, f.name) for f in fields(env)}
return NewCls(**field_vals)
# [docs]
def vectorise_env(env: "Environment") -> "Environment":
    """
    Build a batched (parallel) version of an environment instance.

    Each static method that ``_wrap_env`` collects from the environment's
    class is replaced by ``jax.vmap(method)``. This uses the default
    ``in_axes=0``, so every positional argument is mapped over its leading
    axis.

    Parameters
    ----------
    env : Environment
        The environment instance to vectorise.

    Returns
    -------
    Environment
        A wrapped environment instance whose static methods are vmapped.
    """

    def _batch(_name: str, func: Callable) -> Callable:
        # The method name is irrelevant: every collected method is vmapped.
        return jax.vmap(func)

    return _wrap_env(env, _batch)
# [docs]
def clip_action_env(
    env: "Environment", min_val: float = -1.0, max_val: float = 1.0
) -> "Environment":
    """
    Wrap an environment so that its `step` method clips the action
    before calling the original step.

    Parameters
    ----------
    env : Environment
        The environment instance to wrap.
    min_val : float, optional
        Lower clipping bound (default ``-1.0``). It is forwarded unchanged
        to ``jnp.clip``.
    max_val : float, optional
        Upper clipping bound (default ``1.0``). It is forwarded unchanged
        to ``jnp.clip``.

    Returns
    -------
    Environment
        A wrapped environment. Its ``step`` is a jitted function that calls
        the original ``step`` with ``jnp.clip(action, min_val, max_val)``.
        All other static methods are passed through unchanged.

    Raises
    ------
    ValueError
        If both bounds are given and any element of ``min_val`` exceeds
        the corresponding element of ``max_val``.
    """
    # jnp.clip does not check the ordering of its bounds. With inverted
    # bounds every action component would be pinned to max_val. Reject that
    # configuration up front. ``None`` bounds (one-sided clipping) are not
    # checked.
    if min_val is not None and max_val is not None:
        if bool(jnp.any(jnp.asarray(min_val) > jnp.asarray(max_val))):
            raise ValueError(
                f"clip_action_env: min_val ({min_val}) must not exceed "
                f"max_val ({max_val})."
            )

    def transform(name, fn):
        # Only ``step`` receives actions; leave every other method untouched.
        if name != "step":
            return fn

        @jax.jit
        def clipped_step(env_obj, action):
            clipped_action = jnp.clip(action, min_val, max_val)
            return fn(env_obj, clipped_action)

        return clipped_step

    return _wrap_env(env, transform)