# Source code for jaxdem.factory

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
The factory defines and instantiates specific simulation components.
"""
from __future__ import annotations

import jax

from abc import ABC
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Type,
    TypeVar,
    Optional,
    cast,
    TYPE_CHECKING,
)
from inspect import signature

# Type variables that let the Factory classmethods keep precise types:
#   SubT  -- the concrete subclass flowing through the `register` decorator,
#            which hands back exactly the class it received.
#   RootT -- the factory root a classmethod is invoked on (`cls`), so that
#            `create` is typed as returning an instance of that root.
SubT = TypeVar("SubT", bound="Factory")
RootT = TypeVar("RootT", bound="Factory")


[docs] @partial( jax.tree_util.register_dataclass, drop_fields=["_registry", "__registry_name__"] ) @dataclass class Factory(ABC): """ Base factory class for pluggable components. This abstract base class provides a mechanism for registering and creating subclasses based on a string key. Notes ----- Each concrete subclass gets its own private registry. Keys are strings and not case sensitive. Example ------- Use Factory as a base class for a specific component type (e.g., `Foo`): >>> class Foo(Factory["Foo"], ABC): >>> ... Register a concrete subclass of `Foo`: >>> @Foo.register("bar") >>> class bar: >>> ... To instantiate the subclass instance: >>> Foo.create("bar", **bar_kw) """ if not TYPE_CHECKING: __slots__ = () __registry_name__: ClassVar[Optional[str]] _registry: ClassVar[Dict[str, Type["Factory"]]] = {} """ Dictionary to store the registered subclases.""" def __init_subclass__(cls, **kw: Any) -> None: super().__init_subclass__(**kw) # subclass hook – each concrete root gets its own private registry cls._registry = {} cls.__registry_name__ = None if "create" in cls.__dict__: raise TypeError( f"{cls.__name__} is not allowed to override the `create` method. " "Use `Create` instead for custom instantiation logic." )
[docs] @classmethod def registry_name(cls) -> str: """ Returns the key under which this class is registered. """ name = getattr(cls, "__registry_name__", None) if name is None: raise KeyError(f"{cls.__name__} is not registered in the Factory.") return str(name)
@property def type_name(self) -> str: """ Returns the key under which this instance's class is registered. """ return type(self).registry_name()
[docs] @classmethod @partial(jax.named_call, name="Factory.register") def register( cls: Type[RootT], key: str | None = None ) -> Callable[[Type[SubT]], Type[SubT]]: """ Registers a subclass with the factory's registry. This method returns a decorator that can be applied to a class to register it under a specific `key`. Parameters ---------- key : str or None, optional The string key under which to register the subclass. If `None`, the lowercase name of the subclass itself will be used as the key. Returns ------- Callable[[Type[T]], Type[T]] A decorator function that takes a class and registers it, returning the class unchanged. Raises ------ ValueError If the provided `key` (or the default class name) is already registered in the factory's registry. Example ------- Register a class named "MyComponent" under the key "mycomp": >>> @MyFactory.register("mycomp") >>> class MyComponent: >>> ... Register a class named "DefaultComponent" using its own name as the key: >>> @MyFactory.register() >>> class DefaultComponent: >>> ... """ def decorator(sub_cls: Type[SubT]) -> Type[SubT]: # Preserve an explicitly provided empty-string key instead of # defaulting to the subclass name. Only fall back to the class name # when the caller passes `None`. k = sub_cls.__name__.lower() if key is None else key.lower() if k in cls._registry: raise ValueError( f"{cls.__name__}: key '{k}' already registered for {cls._registry[k].__name__}" ) cls._registry[k] = sub_cls # Stamp the registered name on the class. # Only check for an explicit override on the subclass itself. Base # classes may already be registered (e.g., under the empty string # key) and should not block subclasses from choosing their own # registration name. existing = sub_cls.__dict__.get("__registry_name__", None) if existing is not None and existing != k: raise ValueError( f"{sub_cls.__name__} has __registry_name__={existing!r}, " f"but is being registered as {k!r}." 
) setattr(sub_cls, "__registry_name__", k) return sub_cls return decorator
[docs] @classmethod @partial(jax.named_call, name="Factory.create") def create(cls: Type[RootT], key: str, /, **kw: Any) -> RootT: """ Creates and returns an instance of a registered subclass. This method looks up the subclass associated with the given `key` in the factory's registry and then calls its constructor with the provided arguments. If the subclass defines a `Create` method (capitalized), that method will be called instead of the constructor. This allows subclasses to validate or preprocess arguments before instantiation. Parameters ---------- key : str The registration key of the subclass to be created. **kw : Any Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass. Returns ------- T An instance of the registered subclass. Raises ------ KeyError If the provided `key` is not found in the factory's registry. TypeError If the provided `**kw` arguments do not match the signature of the registered subclass's constructor. Example ------- Given `Foo` factory and `Bar` registered: >>> bar_instance = Foo.create("bar", value=42) >>> print(bar_instance) Bar(value=42) """ try: sub_cls = cls._registry[key.lower()] except KeyError as err: raise KeyError( f"Unknown {cls.__name__} '{key}'. " f"Available: {list(cls._registry)}" ) from err # Prefer 'Create' if present, else call the class constructor create_or_ctor = getattr(sub_cls, "Create", None) or sub_cls # Tell the type checker this callable returns RootT factory_callable = cast(Callable[..., RootT], create_or_ctor) # Optional: friendly arg check sig = signature(create_or_ctor) try: sig.bind_partial(**kw) except TypeError as err: raise TypeError( f"Invalid keyword(s) for {sub_cls.__name__}: {err}. " f"Expected signature: {sub_cls.__name__}.{create_or_ctor.__name__}{sig}" ) from None return factory_callable(**kw)