# Source code for jaxdem.material

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Interface for defining materials and the MaterialTable. The MaterialTable stores the materials in a Structure-of-Arrays (SoA) container. Different material types can be mixed in one table as long as the force laws in use support them.
"""

from dataclasses import dataclass, fields  # Add field import if not present
from typing import Dict, Sequence, Tuple, ClassVar

import jax
import jax.numpy as jnp

from .factory import Factory
from .materialMatchmaker import MaterialMatchmaker


@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Material(Factory["Material"]):
    """
    Abstract base for every material model.

    A concrete subclass declares its physical parameters (for example
    `young`, `poisson` or `mu`) as scalar dataclass fields. When several
    materials are gathered into a :class:`MaterialTable`, each such field
    name becomes one entry of the :attr:`MaterialTable.props` dictionary.

    Notes
    -----
    - This base class declares no fields of its own. Concrete materials are
      registered under a string key through `Material.register(...)` and
      instantiated with `Material.create(key, **fields)`.
    """
@Material.register("elastic")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class Elastic(Material):
    """
    Purely elastic material described by two scalar parameters.

    Registered under the key ``"elastic"``, so it is usually built through
    :meth:`Material.create`.

    Attributes
    ----------
    young : float
        Young's modulus.
    poisson : float
        Poisson's ratio.

    Example
    -------
    >>> import jaxdem as jdem
    >>> elastic_steel = jdem.Material.create("elastic", young=2.0e11, poisson=0.3)
    """

    young: float
    poisson: float
@Material.register("elasticfrict")
@jax.tree_util.register_dataclass
@dataclass(slots=True)
class ElasticFriction(Material):
    """
    A concrete `Material` implementation for elastic properties with friction.

    It carries the same elastic parameters as :class:`Elastic` plus a
    coefficient of friction, making it suitable for models that include
    frictional contact. It derives directly from :class:`Material`, not from
    :class:`Elastic`. As a result, ``isinstance(obj, Elastic)`` is ``False``
    for instances of this class.

    Attributes
    ----------
    young : float
        Young's modulus.
    poisson : float
        Poisson's ratio.
    mu : float
        Coefficient of friction.

    Example
    -------
    >>> import jaxdem as jdem
    >>> frictional_rubber = jdem.Material.create("elasticfrict", young=1.0e7, poisson=0.49, mu=0.5)
    """

    young: float
    poisson: float
    mu: float
@jax.tree_util.register_dataclass
@dataclass(frozen=True, slots=True)
class MaterialTable:
    """
    A container for material properties, organized as a Structure of Arrays
    (SoA), together with pre-computed effective pair properties.

    This class centralizes material data, allowing efficient access to scalar
    properties of individual materials and to pre-calculated effective
    properties of material-pair interactions.

    Notes
    -----
    - Scalar properties can be accessed directly using dot notation
      (e.g., `material_table.young`).
    - Effective pair properties can also be accessed directly using dot
      notation (e.g., `material_table.young_eff`). If a name exists in both
      :attr:`props` and :attr:`pair`, attribute access returns the scalar
      property.

    Example
    -------
    Creating a `MaterialTable` from multiple material types:

    >>> import jaxdem as jdem
    >>> from jaxdem.material import MaterialTable
    >>> # Define different material instances
    >>> mat1 = jdem.Material.create("elastic", young=1.0e4, poisson=0.3)
    >>> mat2 = jdem.Material.create("elasticfrict", young=2.0e4, poisson=0.4, mu=0.5)
    >>> # Build the table, combining material pairs with a linear matcher
    >>> matcher = jdem.MaterialMatchmaker.create("linear")
    >>> mat_table = MaterialTable.from_materials([mat1, mat2], matcher=matcher)
    """

    props: Dict[str, jax.Array]
    """
    A dictionary mapping scalar material property names (e.g., "young",
    "poisson", "mu") to JAX arrays. Each array has shape `(M,)`, where `M` is
    the total number of distinct material types present in the table.
    """

    pair: Dict[str, jax.Array]  # key → (M, M)
    """
    A dictionary mapping effective pair property names (e.g., "young_eff",
    "mu_eff") to JAX arrays. Each array has shape `(M, M)`, representing the
    effective property for interactions between any two material types
    (M_i, M_j).
    """

    matcher: MaterialMatchmaker
    """
    The :class:`MaterialMatchmaker` instance that was used to compute the
    effective pair properties stored in the :attr:`pair` dictionary.
    """

    @staticmethod
    def from_materials(
        mats: Sequence[Material],
        *,
        matcher: MaterialMatchmaker,
        fill: float = 0.0,
    ) -> "MaterialTable":
        """
        Constructs a :class:`MaterialTable` from a sequence of :class:`Material` instances.

        Parameters
        ----------
        mats : Sequence[Material]
            A sequence of concrete :class:`Material` instances. Each instance
            represents a distinct material type in the simulation. The order
            in this sequence defines their material IDs (0 to `len(mats)-1`).
            Any iterable is accepted; it is materialized once internally.
        matcher : MaterialMatchmaker
            The :class:`MaterialMatchmaker` instance used to compute effective
            pair properties. For every scalar property `k`, its
            `get_effective_property` is called on the column and row views of
            the property array, and the result is stored as `pair[f"{k}_eff"]`.
        fill : float, optional
            Value used for a property that a given material does not define.
            For example, when :class:`Elastic` and :class:`ElasticFriction`
            materials are mixed, the :class:`Elastic` entries of `mu` are set
            to `fill`. Defaults to 0.0.

        Returns
        -------
        MaterialTable
            A new `MaterialTable` instance containing the scalar properties
            and pre-computed effective pair properties for all provided
            materials. Property keys appear in order of first appearance
            across `mats`.

        Raises
        ------
        TypeError
            If any element of `mats` is not a :class:`Material` instance.
        """
        # Materialize once: `mats` is traversed several times below, and a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted
        # after the first pass, silently producing empty property arrays.
        mats = tuple(mats)
        for idx, m in enumerate(mats):
            if not isinstance(m, Material):
                raise TypeError(
                    f"mats[{idx}] must be a Material instance, "
                    f"got {type(m).__name__}"
                )

        # Union of field names over all materials, de-duplicated in order of
        # first appearance. A plain set would make the key order depend on
        # string hash randomization, which differs between processes.
        all_keys = list(dict.fromkeys(f.name for m in mats for f in fields(m)))

        # One SoA column per property; materials lacking a field get `fill`.
        props = {
            k: jnp.asarray([getattr(m, k, fill) for m in mats], dtype=float)
            for k in all_keys
        }
        # Broadcast (M, 1) against (1, M) so the matcher sees every pair.
        pair = {
            f"{k}_eff": matcher.get_effective_property(a[:, None], a[None, :])
            for k, a in props.items()
        }
        return MaterialTable(props=props, pair=pair, matcher=matcher)

    def __getattr__(self, item: str) -> jax.Array:
        """
        Allows direct attribute access to scalar and effective pair properties.

        Parameters
        ----------
        item : str
            The name of the attribute being accessed (e.g., "young", "young_eff").

        Returns
        -------
        jax.Array
            The JAX array corresponding to the requested scalar or effective
            pair property. Scalar properties take precedence.

        Raises
        ------
        AttributeError
            If `item` is not found as a scalar property in :attr:`props` or an
            effective pair property in :attr:`pair`.
        """
        # Python only calls __getattr__ after normal lookup fails. If the
        # storage fields themselves are missing (e.g. an instance created
        # without __init__), resolving them here would re-enter this method
        # and recurse without bound, so fail fast instead.
        if item in ("props", "pair"):
            raise AttributeError(item)
        if item in self.props:
            return self.props[item]
        if item in self.pair:
            return self.pair[item]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {item!r}"
        )

    def __len__(self) -> int:
        """
        Returns the number of distinct material types stored in the table.

        Returns
        -------
        int
            The number of materials, `M`. This is the length of any scalar
            property array. It is 0 when the table stores no scalar
            properties (e.g. it was built from an empty sequence).
        """
        # All scalar arrays share length M, so the first one suffices. The
        # default avoids leaking StopIteration from an empty `props`.
        return next((arr.shape[0] for arr in self.props.values()), 0)
# TODO: add `add` and `merge` methods similar to State, returning the
# corresponding material ID when adding or merging.
# Will need to handle the underlying Dict[str, jax.Array] structures and
# recompute pair properties. This might require some JAX array manipulations
# within the `props` and `pair` dictionaries. The `MaterialTable` is frozen,
# so these methods would return new instances.
# Example placeholders for future methods:
# @staticmethod
# def merge(table1: "MaterialTable", table2: "MaterialTable") -> "MaterialTable":
#     """Merges two MaterialTable instances."""
#     # Logic would involve combining props, then recomputing pair based on
#     # the combined set and ensuring material IDs are consistent if coming
#     # from different tables.
#     pass
#
# def add_materials(self, mats: Sequence[Material], fill: float = 0.0) -> "MaterialTable":
#     """Adds new materials to the table, returning a new MaterialTable instance."""
#     # Logic would involve converting mats to a partial table, then merging with self.
#     pass