# Source code for jaxdem.factory
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
The factory defines and instantiates specific simulation components.
"""
from abc import ABC
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, ClassVar, Dict, Generic, Type, TypeVar
from inspect import signature
import jax
T = TypeVar("T", bound="Factory")
[docs]
@partial(jax.tree_util.register_dataclass, drop_fields=["_registry"])
@dataclass
class Factory(ABC, Generic[T]):
"""
Base factory class for pluggable components. This abstract base class provides a mechanism for registering and creating
subclasses based on a string key.
Notes
-----
Each concrete subclass gets its own private registry. Keys are strings and not case sensitive.
Example
-------
Use Factory as a base class for a specific component type (e.g., `Foo`):
>>> class Foo(Factory["Foo"], ABC):
>>> ...
Register a concrete subclass of `Foo`:
>>> @Foo.register("bar")
>>> class bar:
>>> ...
To instantiate the subclass instance:
>>> Foo.create("bar", **bar_kw)
"""
__slots__ = ()
_registry: ClassVar[Dict[str, Type["Factory"]]] = {}
""" Dictionary to store the registered subclases."""
def __init_subclass__(cls, **kw):
super().__init_subclass__(**kw)
cls._registry = (
{}
) # subclass hook – each concrete root gets its own private registry
[docs]
@classmethod
def register(cls, key: str | None = None) -> Callable[[Type[T]], Type[T]]:
"""
Registers a subclass with the factory's registry.
This method returns a decorator that can be applied to a class
to register it under a specific `key`.
Parameters
----------
key : str or None, optional
The string key under which to register the subclass. If `None`,
the lowercase name of the subclass itself will be used as the key.
Returns
-------
Callable[[Type[T]], Type[T]]
A decorator function that takes a class and registers it, returning
the class unchanged.
Raises
------
ValueError
If the provided `key` (or the default class name) is already
registered in the factory's registry.
Example
-------
Register a class named "MyComponent" under the key "mycomp":
>>> @MyFactory.register("mycomp")
>>> class MyComponent:
>>> ...
Register a class named "DefaultComponent" using its own name as the key:
>>> @MyFactory.register()
>>> class DefaultComponent:
>>> ...
"""
def decorator(sub_cls: Type[T]) -> Type[T]:
k = (key or sub_cls.__name__).lower()
if k in cls._registry:
raise ValueError(
f"{cls.__name__}: key '{k}' already registered for {cls._registry[k].__name__}"
)
cls._registry[k] = sub_cls
return sub_cls
return decorator
[docs]
@classmethod
def create(cls: Type[T], key: str, /, **kw: Any) -> T:
"""
Creates and returns an instance of a registered subclass.
This method looks up the subclass associated with the given `key`
in the factory's registry and then calls its constructor with the
provided arguments.
Parameters
----------
key : str
The registration key of the subclass to be created.
**kw : Any
Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.
Returns
-------
T
An instance of the registered subclass.
Raises
------
KeyError
If the provided `key` is not found in the factory's registry.
TypeError
If the provided `**kw` arguments do not match the signature
of the registered subclass's constructor.
Example
-------
Given `Foo` factory and `Bar` registered:
>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
"""
try:
sub_cls = cls._registry[key.lower()]
except KeyError as err:
raise KeyError(
f"Unknown {cls.__name__} '{key}'. " f"Available: {list(cls._registry)}"
) from err
try:
signature(sub_cls).bind_partial(**kw)
except TypeError as err:
raise TypeError(
f"Invalid keyword(s) for {sub_cls.__name__}: {err}"
) from None
return sub_cls(**kw) # type: ignore[arg-type]