Source code for jaxdem.utils.h5

import dataclasses
import importlib
import json

import h5py
import jax
import jax.numpy as jnp
import numpy as np

from .quaternion import Quaternion

_STR = h5py.string_dtype(encoding="utf-8")

def _qualname(cls) -> str:
    return f"{cls.__module__}:{cls.__qualname__}"

def _import_qualname(s: str):
    mod_name, _, qual = s.partition(":")
    mod = importlib.import_module(mod_name)
    obj = mod
    for part in qual.split("."):
        obj = getattr(obj, part)
    return obj

def _is_array(x) -> bool:
    return isinstance(x, (jax.Array, np.ndarray))

def _to_numpy(x):
    # ensure host numpy for h5py
    return np.asarray(jax.device_get(x))

def _write_any(g: h5py.Group, name: str, obj, *, allow_callables: bool = False):
    # None
    if obj is None:
        sg = g.create_group(name)
        sg.attrs["__kind__"] = "none"
        return

    # Quaternion
    if isinstance(obj, Quaternion):
        sg = g.create_group(name)
        sg.attrs["__kind__"] = "quaternion"
        sg.create_dataset("w", data=_to_numpy(obj.w))
        sg.create_dataset("xyz", data=_to_numpy(obj.xyz))
        return

    # Arrays
    if _is_array(obj):
        g.create_dataset(name, data=_to_numpy(obj))
        g[name].attrs["__kind__"] = "array"
        return

    # Scalars / strings (JSON-safe)
    if isinstance(obj, (bool, int, float, np.number, str)):
        ds = g.create_dataset(name, data=obj, dtype=_STR if isinstance(obj, str) else None)
        ds.attrs["__kind__"] = "scalar"
        return

    # Dict[str, ...]
    if isinstance(obj, dict):
        if not all(isinstance(k, str) for k in obj.keys()):
            raise TypeError(f"Only dict[str, ...] supported. Got keys: {list(obj.keys())[:5]}")
        sg = g.create_group(name)
        sg.attrs["__kind__"] = "dict"
        sg.attrs["__keys__"] = json.dumps(list(obj.keys()))
        for k, v in obj.items():
            _write_any(sg, k, v, allow_callables=allow_callables)
        return

    # list/tuple
    if isinstance(obj, (list, tuple)):
        sg = g.create_group(name)
        sg.attrs["__kind__"] = "list" if isinstance(obj, list) else "tuple"
        sg.attrs["__len__"] = len(obj)
        for i, v in enumerate(obj):
            _write_any(sg, str(i), v, allow_callables=allow_callables)
        return

    # Callable (optional, only if importable)
    if callable(obj):
        if not allow_callables:
            raise TypeError(f"Refusing to serialize callable {obj!r}. Set allow_callables=True or drop it.")
        mod = getattr(obj, "__module__", None)
        qual = getattr(obj, "__qualname__", None)
        if not mod or not qual:
            raise TypeError(f"Callable not importable by qualname: {obj!r}")
        sg = g.create_group(name)
        sg.attrs["__kind__"] = "callable"
        sg.attrs["__import__"] = f"{mod}:{qual}"
        return

    # Dataclass (covers State/System/MaterialTable/ForceModels/etc.)
    if dataclasses.is_dataclass(obj):
        sg = g.create_group(name)
        sg.attrs["__kind__"] = "dataclass"
        sg.attrs["__class__"] = _qualname(type(obj))
        for f in dataclasses.fields(obj):
            _write_any(sg, f.name, getattr(obj, f.name), allow_callables=allow_callables)
        return

    raise TypeError(f"Unsupported type at {name}: {type(obj)}")

def _read_any(node):
    # dataset
    if isinstance(node, h5py.Dataset):
        kind = node.attrs.get("__kind__", None)
        if kind in (None, "array", "scalar"):
            x = node[()]
            # bytes -> str
            if isinstance(x, (bytes, np.bytes_)):
                return x.decode("utf-8")
            # numeric -> jax array (keeps System/State happy)
            if np.isscalar(x):
                return jnp.asarray(x)
            return jnp.asarray(x)
        raise ValueError(f"Unknown dataset kind {kind!r}")

    # group
    g = node
    kind = g.attrs.get("__kind__", None)

    if kind == "none":
        return None

    if kind == "quaternion":
        w = jnp.asarray(g["w"][...])
        xyz = jnp.asarray(g["xyz"][...])
        return Quaternion.create(w=w, xyz=xyz)

    if kind == "dict":
        keys = json.loads(g.attrs["__keys__"])
        return {k: _read_any(g[k]) for k in keys}

    if kind in ("list", "tuple"):
        n = int(g.attrs["__len__"])
        items = [_read_any(g[str(i)]) for i in range(n)]
        return items if kind == "list" else tuple(items)

    if kind == "callable":
        fn = _import_qualname(g.attrs["__import__"])
        return fn

    if kind == "dataclass":
        cls = _import_qualname(g.attrs["__class__"])
        kw = {k: _read_any(g[k]) for k in g.keys()}
        return cls(**kw)

    raise ValueError(f"Unknown group kind {kind!r}")

[docs] def save(obj, path: str, *, overwrite: bool = True, allow_callables: bool = False): import os if os.path.exists(path): if overwrite: os.remove(path) else: raise FileExistsError(path) with h5py.File(path, "w") as f: _write_any(f, "root", obj, allow_callables=allow_callables)
[docs] def load(path: str): with h5py.File(path, "r") as f: return _read_any(f["root"])