Source code for jaxdem.factory
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
The factory defines and instantiates specific simulation components.
"""
from __future__ import annotations
import jax
from abc import ABC
from dataclasses import dataclass
from functools import partial
from typing import (
Any,
Callable,
ClassVar,
Dict,
Type,
TypeVar,
Optional,
cast,
TYPE_CHECKING,
)
from inspect import signature
# TypeVars for type-preserving decorators & methods.
# RootT annotates the `cls` argument of the Factory classmethods and the
# return value of `Factory.create`.
# SubT annotates the class that passes through the decorator returned by
# `Factory.register`, so the decorated class keeps its own type.
RootT = TypeVar("RootT", bound="Factory")
SubT = TypeVar("SubT", bound="Factory")
[docs]
@partial(
jax.tree_util.register_dataclass, drop_fields=["_registry", "__registry_name__"]
)
@dataclass
class Factory(ABC):
"""
Base factory class for pluggable components. This abstract base class provides a mechanism for registering and creating
subclasses based on a string key.
Notes
-----
Each concrete subclass gets its own private registry. Keys are strings and not case sensitive.
Example
-------
Use Factory as a base class for a specific component type (e.g., `Foo`):
>>> class Foo(Factory["Foo"], ABC):
>>> ...
Register a concrete subclass of `Foo`:
>>> @Foo.register("bar")
>>> class bar:
>>> ...
To instantiate the subclass instance:
>>> Foo.create("bar", **bar_kw)
"""
if not TYPE_CHECKING:
__slots__ = ()
__registry_name__: ClassVar[Optional[str]]
_registry: ClassVar[Dict[str, Type["Factory"]]] = {}
""" Dictionary to store the registered subclases."""
def __init_subclass__(cls, **kw: Any) -> None:
super().__init_subclass__(**kw)
# subclass hook – each concrete root gets its own private registry
cls._registry = {}
cls.__registry_name__ = None
if "create" in cls.__dict__:
raise TypeError(
f"{cls.__name__} is not allowed to override the `create` method. "
"Use `Create` instead for custom instantiation logic."
)
[docs]
@classmethod
def registry_name(cls) -> str:
"""
Returns the key under which this class is registered.
"""
name = getattr(cls, "__registry_name__", None)
if name is None:
raise KeyError(f"{cls.__name__} is not registered in the Factory.")
return str(name)
@property
def type_name(self) -> str:
"""
Returns the key under which this instance's class is registered.
"""
return type(self).registry_name()
[docs]
@classmethod
@partial(jax.named_call, name="Factory.register")
def register(
cls: Type[RootT], key: str | None = None
) -> Callable[[Type[SubT]], Type[SubT]]:
"""
Registers a subclass with the factory's registry.
This method returns a decorator that can be applied to a class
to register it under a specific `key`.
Parameters
----------
key : str or None, optional
The string key under which to register the subclass. If `None`,
the lowercase name of the subclass itself will be used as the key.
Returns
-------
Callable[[Type[T]], Type[T]]
A decorator function that takes a class and registers it, returning
the class unchanged.
Raises
------
ValueError
If the provided `key` (or the default class name) is already
registered in the factory's registry.
Example
-------
Register a class named "MyComponent" under the key "mycomp":
>>> @MyFactory.register("mycomp")
>>> class MyComponent:
>>> ...
Register a class named "DefaultComponent" using its own name as the key:
>>> @MyFactory.register()
>>> class DefaultComponent:
>>> ...
"""
def decorator(sub_cls: Type[SubT]) -> Type[SubT]:
# Preserve an explicitly provided empty-string key instead of
# defaulting to the subclass name. Only fall back to the class name
# when the caller passes `None`.
k = sub_cls.__name__.lower() if key is None else key.lower()
if k in cls._registry:
raise ValueError(
f"{cls.__name__}: key '{k}' already registered for {cls._registry[k].__name__}"
)
cls._registry[k] = sub_cls
# Stamp the registered name on the class.
# Only check for an explicit override on the subclass itself. Base
# classes may already be registered (e.g., under the empty string
# key) and should not block subclasses from choosing their own
# registration name.
existing = sub_cls.__dict__.get("__registry_name__", None)
if existing is not None and existing != k:
raise ValueError(
f"{sub_cls.__name__} has __registry_name__={existing!r}, "
f"but is being registered as {k!r}."
)
setattr(sub_cls, "__registry_name__", k)
return sub_cls
return decorator
[docs]
@classmethod
@partial(jax.named_call, name="Factory.create")
def create(cls: Type[RootT], key: str, /, **kw: Any) -> RootT:
"""
Creates and returns an instance of a registered subclass.
This method looks up the subclass associated with the given `key`
in the factory's registry and then calls its constructor with the
provided arguments. If the subclass defines a `Create` method (capitalized),
that method will be called instead of the constructor. This allows
subclasses to validate or preprocess arguments before instantiation.
Parameters
----------
key : str
The registration key of the subclass to be created.
**kw : Any
Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.
Returns
-------
T
An instance of the registered subclass.
Raises
------
KeyError
If the provided `key` is not found in the factory's registry.
TypeError
If the provided `**kw` arguments do not match the signature
of the registered subclass's constructor.
Example
-------
Given `Foo` factory and `Bar` registered:
>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
"""
try:
sub_cls = cls._registry[key.lower()]
except KeyError as err:
raise KeyError(
f"Unknown {cls.__name__} '{key}'. " f"Available: {list(cls._registry)}"
) from err
# Prefer 'Create' if present, else call the class constructor
create_or_ctor = getattr(sub_cls, "Create", None) or sub_cls
# Tell the type checker this callable returns RootT
factory_callable = cast(Callable[..., RootT], create_or_ctor)
# Optional: friendly arg check
sig = signature(create_or_ctor)
try:
sig.bind_partial(**kw)
except TypeError as err:
raise TypeError(
f"Invalid keyword(s) for {sub_cls.__name__}: {err}. "
f"Expected signature: {sub_cls.__name__}.{create_or_ctor.__name__}{sig}"
) from None
return factory_callable(**kw)