Source code for jaxdem.utils.serialization

import importlib
from typing import Callable


def encode_callable(fn: Callable) -> str:
    """Return a dotted import path for *fn*, e.g. ``'jax._src.nn.functions.gelu'``.

    The path is ``<fn.__module__>.<fn.__qualname__>``, falling back to
    ``__name__`` when the object has no ``__qualname__``. Using the qualified
    name keeps the enclosing class in the path for functions defined in a
    class body (e.g. ``'json.encoder.JSONEncoder.default'``), instead of
    producing a path that points at a non-existent (or unrelated)
    module-level name.

    Args:
        fn: A callable reachable by importing its module and then following
            attribute access along its qualified name.

    Returns:
        The dotted path string.

    Raises:
        TypeError: If ``fn`` lacks a truthy ``__module__`` or ``__name__``,
            or if it is a lambda or a function defined inside another
            function (its qualified name contains ``<...>``), since such
            objects cannot be re-imported from a path.
    """
    mod = getattr(fn, "__module__", None)
    name = getattr(fn, "__name__", None)
    if not (mod and name):
        raise TypeError(f"Activation must be a plain function, got: {fn!r}")
    qualname = getattr(fn, "__qualname__", None) or name
    # Lambdas ("<lambda>") and nested functions ("outer.<locals>.inner") have
    # no importable location; encoding them would yield a string that can
    # never be decoded, so fail here where the cause is obvious.
    if "<" in qualname:
        raise TypeError(
            "Callable must be defined at module or class level to be "
            f"encoded, got: {fn!r}"
        )
    return f"{mod}.{qualname}"
def decode_callable(path: str) -> Callable:
    """Import a callable from a dotted path string.

    The longest dotted prefix of *path* that is an importable module is
    imported, and the remaining components are resolved one at a time with
    ``getattr``. This handles plain module-level functions (``'json.dumps'``)
    as well as attributes nested in classes
    (``'json.encoder.JSONEncoder.default'``). For a ``module.attr`` path the
    behaviour is the same as ``getattr(import_module(module), attr)``.

    Args:
        path: Dotted path such as ``'package.module.function'``.

    Returns:
        The callable found at ``path``.

    Raises:
        ValueError: If ``path`` has fewer than two components or any empty
            component (e.g. ``''``, ``'fn'``, ``'.fn'``, ``'a..b'``).
        ModuleNotFoundError: If no prefix of ``path`` is an importable module,
            or if an existing module fails to import because one of its own
            dependencies is missing.
        AttributeError: If a component after the module part does not exist.
        TypeError: If the resolved object is not callable.
    """
    parts = path.split(".")
    if len(parts) < 2 or not all(parts):
        raise ValueError(f"Invalid callable path: {path!r}")

    first_error = None
    # Try the longest module prefix first so "pkg.mod.fn" behaves exactly like
    # a direct import_module("pkg.mod"); shorter prefixes are only tried when
    # the trailing components name class attributes rather than modules.
    for n_module_parts in range(len(parts) - 1, 0, -1):
        candidate = ".".join(parts[:n_module_parts])
        try:
            obj = importlib.import_module(candidate)
        except ModuleNotFoundError as err:
            missing = err.name or ""
            # Only fall back when the candidate itself (or its parent) is
            # absent; a missing dependency of a real module must surface.
            if candidate != missing and not candidate.startswith(missing + "."):
                raise
            if first_error is None:
                first_error = err
        else:
            break
    else:
        # No prefix is importable: report the failure for the longest prefix.
        raise first_error

    for attr in parts[n_module_parts:]:
        obj = getattr(obj, attr)

    if not callable(obj):
        raise TypeError(f"Imported object is not callable: {path!r}")
    return obj