jaxdem.forces

Force-law interfaces and concrete implementations.

Classes

ForceModel([laws])

Abstract base class for defining inter-particle force laws and their corresponding potential energies.

class jaxdem.forces.ForceModel(laws: Tuple[ForceModel, ...] = ())

Bases: Factory, ABC

Abstract base class for defining inter-particle force laws and their corresponding potential energies.

Concrete subclasses implement specific force and energy models, such as linear springs, Hertzian contacts, etc.

Notes

  • The force() and energy() methods should correctly handle the case where i and j refer to the same particle (i.e., i == j), because self-interaction calls may occur.
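
A minimal sketch of one way to satisfy this requirement in plain JAX follows, assuming a linear spring law. The function safe_pair_force and its arguments are hypothetical and do not follow the jaxdem signature. A small epsilon under the square root keeps the distance and its gradient finite when the separation is zero, and jnp.where zeroes the result explicitly when i == j.

```python
import jax
import jax.numpy as jnp

EPS = 1e-12  # hypothetical regularizer added to r^2 before the square root


def safe_pair_force(i, j, pos_i, pos_j, radius_i, radius_j, k):
    """Hypothetical linear-spring force on i due to j, safe when i == j."""
    rij = pos_i - pos_j                      # separation vector, zero when i == j
    r = jnp.sqrt(jnp.sum(rij * rij) + EPS)   # never exactly zero, so no 0/0
    s = jnp.maximum(0.0, (radius_i + radius_j) / r - 1.0)
    force = k * s * rij
    # Mask self-interactions explicitly instead of relying on rij == 0.
    return jnp.where(i == j, jnp.zeros_like(force), force)


p = jnp.array([0.0, 0.0])
f_self = safe_pair_force(0, 0, p, p, 0.5, 0.5, 1.0)
# Differentiating through the self-pair stays finite thanks to EPS.
g = jax.grad(lambda x: jnp.sum(safe_pair_force(0, 0, x, p, 0.5, 0.5, 1.0)))(p)
```

Without EPS the forward pass would still be masked by jnp.where, but the backward pass would propagate NaN from the zero-distance square root.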

Example

To define a custom force model, inherit from ForceModel and implement its abstract methods:

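>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.forces import ForceModel
>>> @ForceModel.register("myCustomForce")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomForce(ForceModel):
...     ...  # implement force(), energy(), and required_material_properties
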
laws: Tuple[ForceModel, ...]

A static tuple of other ForceModel instances that compose this force model.

This allows for creating composite force models (e.g., a total force being the sum of a spring force and a damping force).

abstractmethod static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

Compute the force and torque vectors acting on particle \(i\) due to particle \(j\).

Parameters:
  • i (int) – Index of the first particle (on which the interaction acts).

  • j (int) – Index of the second particle (which is exerting the interaction).

  • pos (jax.Array) – Array of particle positions used to evaluate the interaction.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple (force, torque) where force has shape (dim,) and torque has shape (1,) in 2D or (3,) in 3D.

Return type:

Tuple[jax.Array, jax.Array]

abstractmethod static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

Compute the potential energy of the interaction between particle \(i\) and particle \(j\).

Parameters:
  • i (int) – Index of the first particle.

  • j (int) – Index of the second particle.

  • pos (jax.Array) – Array of particle positions used to evaluate the interaction.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

Scalar JAX array representing the potential energy of the interaction between particles \(i\) and \(j\).

Return type:

jax.Array

property required_material_properties: Tuple[str, ...]

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.

class jaxdem.forces.LawCombiner(laws: Tuple[ForceModel, ...] = ())

Bases: ForceModel

Sum a tuple of elementary force laws.
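
The combination rule can be pictured with the plain-JAX sketch below. The elementary laws spring and constant_push and the helper combine are hypothetical and take only a separation vector, whereas the real ForceModel methods take (i, j, pos, state, system). The sketch only shows that the combined force and torque are the element-wise sums of the individual contributions; energies would be summed the same way.

```python
import jax.numpy as jnp


def spring(rij, k=10.0, sigma=1.0):
    # Hypothetical elementary law: linear repulsive spring.
    r = jnp.linalg.norm(rij)
    s = jnp.maximum(0.0, sigma / r - 1.0)
    return k * s * rij, jnp.zeros(1)


def constant_push(rij, f0=0.1):
    # Hypothetical elementary law: constant push along rij.
    return f0 * rij / jnp.linalg.norm(rij), jnp.zeros(1)


def combine(laws):
    """Hypothetical helper: a law whose force and torque are sums over `laws`."""
    def combined(rij):
        forces, torques = zip(*(law(rij) for law in laws))
        return sum(forces), sum(torques)
    return combined


total = combine((spring, constant_push))
f, t = total(jnp.array([0.5, 0.0]))
```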

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

class jaxdem.forces.ForceRouter(laws: Tuple[ForceModel, ...] = (), table: Tuple[Tuple[ForceModel, ...], ...] = ())

Bases: ForceModel

Selects the force law for each particle pair from a static species-to-force lookup table.
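
A static species-pair table can be dispatched inside jit-compiled code with jax.lax.switch, as in the sketch below. The law functions, the index layout, and the symmetric filling in table_from_dict are assumptions for illustration; they are not taken from ForceRouter.from_dict.

```python
import jax
import jax.numpy as jnp


def zero_law(r):
    return jnp.zeros_like(r)                     # no interaction


def soft_law(r):
    return jnp.maximum(0.0, 1.0 - r)             # hypothetical soft repulsion


def hard_law(r):
    return 10.0 * jnp.maximum(0.0, 1.0 - r)      # hypothetical stiff repulsion


def table_from_dict(S, mapping, laws):
    """Hypothetical builder: S x S array of branch indices (missing pairs -> 0).

    Assumption: the table is symmetric, so (a, b) also fills (b, a).
    """
    idx = [[0] * S for _ in range(S)]
    for (a, b), law in mapping.items():
        idx[a][b] = idx[b][a] = laws.index(law)
    return jnp.array(idx, dtype=jnp.int32)


LAWS = (zero_law, soft_law, hard_law)
TABLE = table_from_dict(2, {(0, 0): soft_law, (0, 1): hard_law}, LAWS)


@jax.jit
def routed_magnitude(species_i, species_j, r):
    r = jnp.asarray(r, dtype=jnp.float32)
    # Static lookup of the branch index, then dispatch to the selected law.
    return jax.lax.switch(TABLE[species_i, species_j], LAWS, r)
```

All branches must return arrays of the same shape and dtype, which is why r is cast to float32 before dispatch.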

table: Tuple[Tuple[ForceModel, ...], ...]

static from_dict(S: int, mapping: dict[Tuple[int, int], ForceModel]) → ForceRouter

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

class jaxdem.forces.SpringForce(laws: Tuple[ForceModel, ...] = ())

Bases: ForceModel

A ForceModel implementation for a linear spring-like interaction between particles.

Notes

  • The ‘effective Young’s modulus’ (\(k_{eff,\; ij}\)) is retrieved from the jaxdem.System.mat_table based on the material IDs of the interacting particles.

  • The force is zero if \(i == j\).

  • A small epsilon is added to the squared distance (\(r^2\)) before taking the square root to prevent division by zero or NaN issues when particles are perfectly co-located.

The penetration \(\delta\) (overlap) between two particles \(i\) and \(j\) is:

\[\delta = (R_i + R_j) - r\]

where \(R_i\) and \(R_j\) are the radii of particles \(i\) and \(j\), respectively, and \(r = \|r_{ij}\|\) is the distance between their centers.

The scalar overlap \(s\) is the penetration normalized by the center distance:

\[s = \max \left(0, \frac{R_i + R_j}{r} - 1 \right) = \max \left(0, \frac{\delta}{r} \right)\]

The force \(F_{ij}\) acting on particle \(i\) due to particle \(j\) is:

\[F_{ij} = k_{eff,\; ij} s r_{ij}\]

so its magnitude is \(k_{eff,\; ij} \max(0, \delta)\).

The potential energy \(E_{ij}\) of the interaction is:

\[E_{ij} = \frac{1}{2} k_{eff,\; ij} \max(0, \delta)^2\]

where \(k_{eff,\; ij}\) is the effective Young’s modulus for the particle pair.
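
The energy whose negative gradient reproduces this force is \(\tfrac{1}{2} k_{eff,\; ij} \delta^2\) for \(\delta > 0\). The plain-JAX sketch below checks that numerically for one overlapping pair. The function names, the epsilon value, and the convention that \(r_{ij} = r_i - r_j\) points from \(j\) to \(i\) are assumptions of this sketch, not the jaxdem implementation.

```python
import jax
import jax.numpy as jnp

EPS = 1e-12  # hypothetical regularizer, mirroring the epsilon note above


def spring_force(pos_i, pos_j, R_i, R_j, k):
    rij = pos_i - pos_j                          # assumed convention: points from j to i
    r = jnp.sqrt(jnp.sum(rij**2) + EPS)
    s = jnp.maximum(0.0, (R_i + R_j) / r - 1.0)  # scalar overlap s = max(0, delta / r)
    return k * s * rij


def spring_energy(pos_i, pos_j, R_i, R_j, k):
    r = jnp.sqrt(jnp.sum((pos_i - pos_j) ** 2) + EPS)
    delta = jnp.maximum(0.0, R_i + R_j - r)      # penetration depth
    return 0.5 * k * delta**2


pos_i = jnp.array([0.8, 0.1])
pos_j = jnp.array([0.0, 0.0])
F = spring_force(pos_i, pos_j, 0.5, 0.5, 100.0)
# The force on i should equal minus the gradient of the energy w.r.t. pos_i.
F_from_E = -jax.grad(spring_energy)(pos_i, pos_j, 0.5, 0.5, 100.0)
```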

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

Compute the linear spring-like interaction force acting on particle \(i\) due to particle \(j\).

Returns zero when \(i = j\).

Parameters:
  • i (int) – Index of the first particle.

  • j (int) – Index of the second particle.

  • pos (jax.Array) – Array of particle positions used to evaluate the interaction.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

A tuple (force, torque), where force is the force vector acting on particle \(i\) due to particle \(j\).

Return type:

Tuple[jax.Array, jax.Array]

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

Compute the linear spring-like interaction potential energy between particle \(i\) and particle \(j\).

Returns zero when \(i = j\).

Parameters:
  • i (int) – Index of the first particle.

  • j (int) – Index of the second particle.

  • pos (jax.Array) – Array of particle positions used to evaluate the interaction.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

Scalar JAX array representing the potential energy of the interaction between particles \(i\) and \(j\).

Return type:

jax.Array

property required_material_properties: Tuple[str, ...]

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.

class jaxdem.forces.WCA(laws: Tuple[ForceModel, ...] = ())

Bases: ForceModel

Weeks-Chandler-Andersen (WCA) purely repulsive Lennard-Jones interaction.

Uses material-pair parameter:
  • epsilon_eff[mi, mj]

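The length scale \(\sigma_{ij}\) is derived from the particle radii (as in SpringForce):

\[\sigma_{ij} = R_i + R_j\]

Potential (for \(r < r_c = 2^{1/6} \sigma_{ij}\)):

\[U(r) = 4 \epsilon \left[\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^6 \right] + \epsilon\]

else:

\[U(r) = 0\]

Force (for \(r < r_c\); zero otherwise):

\[\mathbf{F} = 24 \epsilon \left(2 \left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^6\right) \frac{1}{r^2}\, \mathbf{r}_{ij}\]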
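
A plain-JAX sketch of this pair interaction with \(\sigma_{ij} = R_i + R_j\) and a single \(\epsilon\) is given below. It checks that the analytic force matches \(-\nabla U\) and that the energy vanishes at the cutoff. The helper names and arguments are hypothetical and are not the WCA.force and WCA.energy signatures.

```python
import jax
import jax.numpy as jnp


def wca_energy(pos_i, pos_j, R_i, R_j, eps):
    """Hypothetical helper: WCA energy with sigma = R_i + R_j."""
    sigma = R_i + R_j
    r = jnp.linalg.norm(pos_i - pos_j)
    rc = 2.0 ** (1.0 / 6.0) * sigma
    x6 = (sigma / r) ** 6
    u = 4.0 * eps * (x6**2 - x6) + eps          # the + eps shift makes U(rc) = 0
    return jnp.where(r < rc, u, 0.0)


def wca_force(pos_i, pos_j, R_i, R_j, eps):
    """Hypothetical helper: analytic WCA force on i due to j."""
    sigma = R_i + R_j
    rij = pos_i - pos_j
    r2 = jnp.sum(rij**2)
    rc = 2.0 ** (1.0 / 6.0) * sigma
    x6 = (sigma**2 / r2) ** 3
    f = 24.0 * eps * (2.0 * x6**2 - x6) / r2 * rij
    return jnp.where(r2 < rc**2, f, jnp.zeros_like(f))


pos_i, pos_j = jnp.array([0.9, 0.3]), jnp.zeros(2)
F = wca_force(pos_i, pos_j, 0.5, 0.5, 1.0)
F_ref = -jax.grad(wca_energy)(pos_i, pos_j, 0.5, 0.5, 1.0)  # should match F
```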

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

property required_material_properties: Tuple[str, ...]

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.

class jaxdem.forces.LennardJones(laws: Tuple[ForceModel, ...] = ())

Bases: ForceModel

Lennard-Jones (LJ) 12-6 interaction with a per-pair cutoff and energy shift.

Uses material-pair parameter:
  • epsilon_eff[mi, mj]

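The length scale \(\sigma_{ij}\) is derived from the particle radii (as in SpringForce):

\[\sigma_{ij} = R_i + R_j\]

Potential (for \(r < r_c = 2.5 \sigma_{ij}\)):

\[U(r) = U_{\mathrm{LJ}}(r) - U_{\mathrm{LJ}}(r_c), \qquad U_{\mathrm{LJ}}(r) = 4 \epsilon \left[\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^6 \right]\]

with \(\sigma = \sigma_{ij}\) and \(\epsilon\) = epsilon_eff[mi, mj].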

else:

\[U(r) = 0\]

Force (for \(r < r_c\)):

\[\mathbf{F} = 24 \epsilon \left(2 \left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^6\right) \frac{1}{r^2}\, \mathbf{r}_{ij}\]
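
The energy shift can be sketched in plain JAX as follows, with lj_raw the unshifted 12-6 potential and the cutoff at 2.5 σ. The sketch also checks that the force vanishes at the potential minimum \(r = 2^{1/6}\sigma\). The helper names are hypothetical; only the cutoff factor comes from this page.

```python
import jax
import jax.numpy as jnp

RC_FACTOR = 2.5  # cutoff in units of sigma, as documented for LennardJones


def lj_raw(r, sigma, eps):
    """Hypothetical helper: unshifted 12-6 potential."""
    x6 = (sigma / r) ** 6
    return 4.0 * eps * (x6**2 - x6)


def lj_shifted(r, sigma, eps):
    """Hypothetical helper: energy-shifted LJ, zero at and beyond the cutoff."""
    rc = RC_FACTOR * sigma
    u = lj_raw(r, sigma, eps) - lj_raw(rc, sigma, eps)  # U(rc) = 0 by construction
    return jnp.where(r < rc, u, 0.0)


sigma, eps = 1.0, 1.0
r_min = 2.0 ** (1.0 / 6.0) * sigma
dU = jax.grad(lj_shifted)(r_min, sigma, eps)  # close to 0: no force at the minimum
```

The shift changes energies only, so the force below the cutoff is identical to the unshifted 12-6 force.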

RC_FACTOR: ClassVar[float] = 2.5

Cutoff distance in units of \(\sigma_{ij}\), i.e. \(r_c = 2.5\, \sigma_{ij}\).

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

property required_material_properties: Tuple[str, ...]

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.

class jaxdem.forces.WCAShifted(laws: Tuple[ForceModel, ...] = ())

Bases: ForceModel

Contact-start, force-shifted WCA/LJ repulsion.

This model enforces that the interaction “begins” at contact:

  • cutoff at \(r_c = \sigma_{ij}\) where \(\sigma_{ij} = R_i + R_j\)

  • \(U(r_c) = 0\)

  • \(F(r_c) = 0\) (force-shifted; smooth turn-on at contact)

Uses material-pair parameter:
  • epsilon_eff[mi, mj]
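
One standard way to obtain \(U(r_c) = 0\) and \(F(r_c) = 0\) at the same time is the force-shifted form \(U_{\mathrm{fs}}(r) = U(r) - U(r_c) - (r - r_c)\, U'(r_c)\) for \(r < r_c\). The sketch applies it to a 12-6 potential cut at \(r_c = \sigma_{ij}\). The choice of the LJ length scale (here sigma_lj = sigma_ij) and the helper names are assumptions, since the exact functional form used by WCAShifted is not documented here.

```python
import jax
import jax.numpy as jnp


def lj(r, sigma_lj, eps):
    """Hypothetical helper: unshifted 12-6 potential."""
    x6 = (sigma_lj / r) ** 6
    return 4.0 * eps * (x6**2 - x6)


def force_shifted(u, rc):
    """Hypothetical helper: u_fs(r) = u(r) - u(rc) - (r - rc) u'(rc) for r < rc, else 0."""
    du_rc = jax.grad(u)(rc)
    u_rc = u(rc)

    def u_fs(r):
        val = u(r) - u_rc - (r - rc) * du_rc
        return jnp.where(r < rc, val, 0.0)

    return u_fs


R_i, R_j, eps = 0.5, 0.5, 1.0
sigma_ij = R_i + R_j                    # contact distance, used as the cutoff
u_fs = force_shifted(lambda r: lj(r, sigma_ij, eps), sigma_ij)
```

With this choice both the energy and its derivative go continuously to zero as \(r \to \sigma_{ij}\), while the interaction stays repulsive inside contact.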

static force(i: int, j: int, pos: jax.Array, state: State, system: System) → Tuple[jax.Array, jax.Array]

static energy(i: int, j: int, pos: jax.Array, state: State, system: System) → jax.Array

property required_material_properties: Tuple[str, ...]

A static tuple of strings specifying the material properties required by this force model.

These properties (e.g., ‘young_eff’, ‘restitution’, …) must be present in the System.mat_table for the model to function correctly. This is used for validation.

class jaxdem.forces.ForceManager(gravity: jax.Array, external_force: jax.Array, external_force_com: jax.Array, external_torque: jax.Array, is_com_force: Tuple[bool, ...] = (), force_functions: Tuple[ForceFunction, ...] = (), energy_functions: Tuple[EnergyFunction | None, ...] = ())

Bases: object

Manage custom force contributions that are external to the collider. The manager also accumulates forces into the state after the collider has been applied, accounting for rigid bodies (clumps).

gravity: jax.Array

Constant acceleration applied to all particles. Shape (dim,).

external_force: jax.Array

Accumulated external force applied to all particles (at particle position). This buffer is cleared when apply() is invoked.

external_force_com: jax.Array

Accumulated external force applied to the center of mass (does not induce torque). This buffer is cleared when apply() is invoked.

external_torque: jax.Array

Accumulated external torque applied to all particles. This buffer is cleared when apply() is invoked.

is_com_force: Tuple[bool, ...]

Tuple of booleans with one entry per item in force_functions (length n_forces). If True, the force is applied to the center of mass (no induced torque). If False, the force is applied to the constituent particle (induces torque via the lever arm).

force_functions: Tuple[ForceFunction, ...]

Tuple of callables with signature (pos, state, system) returning per-particle force and torque arrays.

energy_functions: Tuple[EnergyFunction | None, ...]

Tuple of callables (or None) with signature (pos, state, system) returning per-particle potential energy arrays. Corresponds to force_functions.

static create(state_shape: Tuple[int, ...], *, gravity: jax.Array | None = None, force_functions: Sequence[ForceFunction | Tuple[ForceFunction, bool] | Tuple[ForceFunction, EnergyFunction] | Tuple[ForceFunction, EnergyFunction, bool]] = ()) → ForceManager

Create a ForceManager for a state with the given shape.

Parameters:
  • state_shape – Shape of the state position array, typically (..., dim).

  • gravity – Optional initial gravitational acceleration. Defaults to zeros of shape (dim,).

  • force_functions – Sequence of callables or tuples. Each item is normalized to a triple (func, energy, is_com):

    • func -> (func, None, False): applied at the particle position, no potential energy.

    • (func,) -> (func, None, False)

    • (func, bool) -> (func, None, bool): the boolean specifies whether it is a COM force.

    • (func, energy) -> (func, energy, False): includes a potential energy function.

    • (func, energy, bool) -> (func, energy, bool): includes the energy function and the COM flag.

    • (func, None, bool) -> (func, None, bool)

    Signature of a force function: (pos, state, system) -> (force, torque).

    Signature of an energy function: (pos, state, system) -> energy.
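
The normalization table above can be expressed as a small pure-Python function. normalize_force_spec is a hypothetical name used for illustration; it reproduces the documented mapping and is not the code inside ForceManager.create.

```python
from typing import Any, Callable, Optional, Tuple


# Hypothetical helper mirroring the documented normalization table.
def normalize_force_spec(item: Any) -> Tuple[Callable, Optional[Callable], bool]:
    """Map one force_functions entry to a (func, energy, is_com) triple."""
    if callable(item):                       # func
        return item, None, False
    item = tuple(item)
    if len(item) == 1:                       # (func,)
        return item[0], None, False
    if len(item) == 2:
        func, second = item
        if isinstance(second, bool):         # (func, bool)
            return func, None, second
        return func, second, False           # (func, energy)
    if len(item) == 3:                       # (func, energy or None, bool)
        func, energy, is_com = item
        return func, energy, bool(is_com)
    raise ValueError(f"unsupported force_functions entry: {item!r}")
```

The bool check must come before treating the second element as an energy function, otherwise (func, True) would be misread.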

static add_force(state: State, system: System, force: jax.Array, *, is_com: bool = False) → System

Accumulate an external force on all particles, to be applied on the next apply() call.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • force (jax.Array) – External force to add to every particle, given in the current particle order.

  • is_com (bool, optional) – If True, the force is applied to the center of mass (no induced torque). If False (default), the force is applied at the particle position (induces torque).

static add_force_at(state: State, system: System, force: jax.Array, idx: jax.Array, *, is_com: bool = False) → System

Add an external force to the particles whose IDs are given by idx.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • force (jax.Array) – External force to be added to particles with ID=idx.

  • idx (jax.Array) – IDs of the particles affected by the external force.

  • is_com (bool, optional) – If True, the force is applied to the center of mass (no induced torque). If False (default), the force is applied at the particle position (induces torque).

static add_torque(state: State, system: System, torque: jax.Array) → System

Accumulate an external torque on all particles, to be applied on the next apply() call.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • torque (jax.Array) – External torque to add to every particle, given in the current particle order.

static add_torque_at(state: State, system: System, torque: jax.Array, idx: jax.Array) → System

Add an external torque to the particles whose IDs are given by idx.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

  • torque (jax.Array) – External torque to be added to particles with ID=idx.

  • idx (jax.Array) – IDs of the particles affected by the external torque.

static apply(state: State, system: System) → Tuple[State, System]

Accumulate the managed per-particle contributions on top of the collider (contact) forces, then perform the final clump aggregation and broadcast.

Parameters:
  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

The updated state and system, with all force contributions accumulated.

Return type:

Tuple[State, System]
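
The clump aggregation and broadcast step can be illustrated with jax.ops.segment_sum. In this 3D sketch, aggregate_clumps, clump_id, and com are hypothetical stand-ins rather than jaxdem State fields. Forces on member spheres are summed per clump, each force applied at a particle position adds a lever-arm torque \((\mathbf{x}_i - \mathbf{x}_{\mathrm{com}}) \times \mathbf{F}_i\), and the clump totals are gathered back to every member. A force applied to the center of mass would skip the lever-arm term.

```python
import jax
import jax.numpy as jnp


def aggregate_clumps(pos, force, torque, clump_id, com, num_clumps):
    """Hypothetical helper: sum per-sphere forces/torques per clump, then broadcast."""
    lever = pos - com[clump_id]                       # arm from clump COM to each sphere
    total_torque = torque + jnp.cross(lever, force)   # own torque + lever-arm torque
    F = jax.ops.segment_sum(force, clump_id, num_segments=num_clumps)
    T = jax.ops.segment_sum(total_torque, clump_id, num_segments=num_clumps)
    return F[clump_id], T[clump_id]                   # every member sees its clump total


pos = jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
force = jnp.array([[0.0, 1.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 2.0]])
torque = jnp.zeros((3, 3))
clump_id = jnp.array([0, 0, 1])                       # spheres 0 and 1 form clump 0
com = jnp.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
F_b, T_b = aggregate_clumps(pos, force, torque, clump_id, com, num_clumps=2)
```

Here the two opposite forces on clump 0 cancel, but they form a couple, so the clump receives a net torque about z.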

static compute_potential_energy(state: State, system: System) → jax.Array

Compute the total potential energy of the system.

Notes

  • The energy of clump members is divided by the number of spheres in the clump.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A scalar JAX array representing the total potential energy of each particle.

Return type:

jax.Array
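
One reading of the note about clump members: if each member sphere reports the energy of its whole clump, dividing by the clump size before summing counts every clump once. The sketch below assumes a hypothetical clump_id array and per-particle energies; it is not the jaxdem implementation.

```python
import jax
import jax.numpy as jnp


def total_potential_energy(per_particle_energy, clump_id, num_clumps):
    """Hypothetical helper: divide each member's energy by its clump size, then sum."""
    counts = jax.ops.segment_sum(jnp.ones_like(per_particle_energy), clump_id,
                                 num_segments=num_clumps)
    return jnp.sum(per_particle_energy / counts[clump_id])


# Three spheres of clump 0 each report the clump energy 3.0; sphere 3 is its own clump.
E = total_potential_energy(jnp.array([3.0, 3.0, 3.0, 2.0]), jnp.array([0, 0, 0, 1]), 2)
```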

class jaxdem.forces.DeformableParticleContainer(elements: Array | None, edges: Array | None, element_adjacency: Array | None, element_adjacency_edges: Array | None, elements_ID: Array | None, edges_ID: Array | None, element_adjacency_ID: Array | None, num_bodies: int, initial_body_contents: Array | None, initial_element_measures: Array | None, initial_edge_lengths: Array | None, initial_bending: Array | None, inv_ref_shape: Array | None, inv_ref_tet_shape: Array | None, initial_tet_volumes: Array | None, em: Array | None, ec: Array | None, eb: Array | None, el: Array | None, gamma: Array | None, lame_lambda: Array | None, lame_mu: Array | None, use_tetrahedral_svk: bool)

Bases: object

Registry holding topology and reference configuration for deformable particles.

This container manages the mesh connectivity (elements, edges, etc.) and reference properties (initial measures, contents, lengths, angles) required to compute forces. It supports both 3D (volumetric bodies bounded by triangles) and 2D (planar bodies bounded by line segments).

Indices in elements, edges, etc. correspond to the ID of particles in State.

The general form of the potential energy of a deformable particle (body \(K\)) is:

\[E_K = E_{K,\mathrm{measure}} + E_{K,\mathrm{content}} + E_{K,\mathrm{bending}} + E_{K,\mathrm{edge}} + E_{K,\mathrm{strain}}\]

Definitions per dimension:

  • 3D: Measure (\(\mathcal{M}\)) is face area; content (\(\mathcal{C}\)) is volume; elements are triangles.

  • 2D: Measure (\(\mathcal{M}\)) is segment length; content (\(\mathcal{C}\)) is enclosed area; elements are segments.
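
For a 2D body stored as a closed loop of segments, the measure and content defined above reduce to segment lengths and the shoelace area. The sketch below computes both; the helpers are hypothetical, and it assumes segments lists each boundary segment as a pair of vertex indices in counter-clockwise order, like elements with shape (M, 2).

```python
import jax.numpy as jnp


def segment_lengths(vertices, segments):
    """Hypothetical helper: measure M of each 2D element (its segment length)."""
    a, b = vertices[segments[:, 0]], vertices[segments[:, 1]]
    return jnp.linalg.norm(b - a, axis=1)


def enclosed_area(vertices, segments):
    """Hypothetical helper: content C of a 2D body, signed shoelace area (> 0 if CCW)."""
    a, b = vertices[segments[:, 0]], vertices[segments[:, 1]]
    return 0.5 * jnp.sum(a[:, 0] * b[:, 1] - b[:, 0] * a[:, 1])


square = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
loop = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]])
```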

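Strain energy (StVK) on elements (triangles):

\[W = A_0 \cdot \left( \mu\, \mathrm{tr}(E^2) + \frac{\lambda}{2} (\mathrm{tr}\, E)^2 \right)\]

where \(A_0\) is the reference element area, \(E\) is the Green strain of the element, and \(\lambda\) and \(\mu\) are the first and second Lamé parameters (lame_lambda, lame_mu).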
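
A sketch of the StVK energy for a single triangle embedded in 3D follows. It assumes the common construction in which the reference triangle is expressed in a local 2D frame, inv_ref (analogous to inv_ref_shape) is the inverse of the 2×2 reference edge matrix \(D_m\), \(F = D_s D_m^{-1}\) is the 3×2 deformation gradient, and \(E = \tfrac{1}{2}(F^\top F - I)\) is the Green strain. The helper names are hypothetical.

```python
import jax.numpy as jnp


def reference_frame(X):
    """Hypothetical helper: inverse 2x2 reference edge matrix and area of triangle X (3x3)."""
    e1, e2 = X[1] - X[0], X[2] - X[0]
    n = jnp.cross(e1, e2)                       # normal; |n| = 2 * area
    t1 = e1 / jnp.linalg.norm(e1)               # first in-plane axis
    t2 = jnp.cross(n / jnp.linalg.norm(n), t1)  # second in-plane axis
    Dm = jnp.array([[e1 @ t1, e2 @ t1],
                    [e1 @ t2, e2 @ t2]])        # reference edges in local 2D coords
    return jnp.linalg.inv(Dm), 0.5 * jnp.linalg.norm(n)


def stvk_energy(x, inv_ref, area0, lam, mu):
    """W = A0 * (mu tr(E^2) + lam/2 (tr E)^2) for one deformed triangle x (3x3)."""
    Ds = jnp.stack([x[1] - x[0], x[2] - x[0]], axis=1)  # 3x2 deformed edge matrix
    F = Ds @ inv_ref                                     # 3x2 deformation gradient
    E = 0.5 * (F.T @ F - jnp.eye(2))                     # 2x2 Green strain
    return area0 * (mu * jnp.trace(E @ E) + 0.5 * lam * jnp.trace(E) ** 2)


X = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
inv_ref, area0 = reference_frame(X)
W_rest = stvk_energy(X, inv_ref, area0, lam=2.0, mu=1.0)           # no strain
W_stretch = stvk_energy(1.1 * X, inv_ref, area0, lam=2.0, mu=1.0)  # 10% uniform stretch
```

Because the energy depends on \(F^\top F\) only, rigid rotations of the deformed triangle leave W unchanged.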

All mesh properties are concatenated along axis=0.

elements: Array | None

Array of vertex indices forming the boundary elements. Shape: (M, 3) for 3D (Triangles) or (M, 2) for 2D (Segments). Indices refer to the particle unique ID corresponding to the State.pos array.

edges: Array | None

Array of vertex indices forming the unique edges (wireframe). Shape: (E, 2). Each row contains the indices of the two vertices forming the edge. Note: In 2D, the set of edges often overlaps with the set of elements (segments).

element_adjacency: Array | None

Array of element adjacency pairs (for bending/dihedral angles). Shape: (A, 2). Each row contains the indices of the two elements sharing a connection.

element_adjacency_edges: Array | None

Array of vertex IDs forming the shared edge for each adjacency. Shape: (A, 2).

elements_ID: Array | None

Array of body IDs for each boundary element. Shape: (M,). elements_ID[i] == k means element i belongs to body k.

edges_ID: Array | None

Array of body IDs for each unique edge. Shape: (E,). edges_ID[e] == k means edge e belongs to body k.

element_adjacency_ID: Array | None

Array of body IDs for each adjacency (bending hinge). Shape: (A,). element_adjacency_ID[a] == k means adjacency a belongs to body k.

num_bodies: int

Total number of distinct deformable bodies in the container, K. Per-body arrays such as initial_body_contents and em have shape (K,).

initial_body_contents: Array | None

Array of reference (stress-free) bulk content for each body. Shape: (K,). Represents volume in 3D or area in 2D.

initial_element_measures: Array | None

Array of reference (stress-free) measures for each element. Shape: (M,). Represents area in 3D or length in 2D.

initial_edge_lengths: Array | None

Array of reference (stress-free) lengths for each unique edge. Shape: (E,).

initial_bending: Array | None

Array of reference (stress-free) bending angles for each adjacency. Shape: (A,). Represents the dihedral angle in 3D or the vertex angle in 2D.

inv_ref_shape: Array | None

Inverse of the reference shape matrix for each element. Shape: (M, 2, 2) for triangles, or (M, 1, 1) for segments. Used to compute the deformation gradient F or Green strain E.

inv_ref_tet_shape: Array | None

Inverse of the reference shape matrix for tetrahedra formed by each boundary triangle and the corresponding body center. Shape: (M, 3, 3).

initial_tet_volumes: Array | None

Reference volumes for tetrahedra formed by each boundary triangle and the corresponding body center. Shape: (M,).

em: Array | None

Measure elasticity coefficient (modulus) for each body. Shape: (K,). Controls area stiffness in 3D and length stiffness in 2D.

ec: Array | None

Content elasticity coefficient (modulus) for each body. Shape: (K,). Controls volume stiffness in 3D and area stiffness in 2D.

eb: Array | None

Bending elasticity coefficient (rigidity) for each body. Shape: (K,).

el: Array | None

Edge length elasticity coefficient (modulus) for each body. Shape: (K,).

gamma: Array | None

Surface/line tension coefficient for each body. Shape: (K,).

lame_lambda: Array | None

First Lamé parameter for the StVK model, for each body. Shape: (K,).

lame_mu: Array | None

Second Lamé parameter (shear modulus) for the StVK model, for each body. Shape: (K,).

use_tetrahedral_svk: bool

If True, compute StVK strain energy on tetrahedra formed by each boundary triangle and the mesh center of its body (3D only). If False, use the existing shell-like element StVK model.

static create(vertices: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, elements: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, elements_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_element_measures: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_body_contents: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_bending: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_edge_lengths: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, em: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, ec: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, eb: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, el: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gamma: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_lambda: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_mu: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, use_tetrahedral_svk: bool = False) → DeformableParticleContainer

Factory method to create a new DeformableParticleContainer.

Calculates initial geometric properties (areas, volumes, bending angles, and edge lengths) from the provided vertices if they are not explicitly provided.

static merge(c1: DeformableParticleContainer, c2: DeformableParticleContainer) → DeformableParticleContainer

Merge two DeformableParticleContainer instances.

static add(container: DeformableParticleContainer, vertices: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, elements_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, elements: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_element_measures: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, element_adjacency: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_bending: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges_ID: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, edges: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_edge_lengths: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, initial_body_contents: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, em: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, ec: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, eb: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, el: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gamma: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_lambda: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, lame_mu: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, use_tetrahedral_svk: bool = False) → DeformableParticleContainer

Factory method to add bodies to a container.

static compute_potential_energy(pos: jax.Array, state: State, _system: System, container: DeformableParticleContainer) → Tuple[jax.Array, Dict[str, jax.Array]]

static create_force_function(container: DeformableParticleContainer) → ForceFunction

static create_force_energy_functions(container: DeformableParticleContainer) → Tuple[ForceFunction, EnergyFunction]

Modules

deformable_particle

Implementation of the deformable particle container.

force_manager

Utilities for managing external and custom force contributions that do not depend on the collider.

law_combiner

Composite force model that sums multiple force laws.

lennardjones

router

Force model router selecting laws based on species pairs.

spring

Linear spring force model.

wca

wca_shifted