# Source code for jaxdem.forces.deformable_particle

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""Implementation of the deformable particle container."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Tuple, Optional, Dict
from functools import partial

from ..utils.linalg import cross, unit

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System
    from .force_manager import ForceFunction
    from .force_manager import EnergyFunction


[docs] @jax.tree_util.register_dataclass @dataclass(slots=True) class DeformableParticleContainer: # type: ignore[misc] r""" Registry holding topology and reference configuration for deformable particles. This container manages the mesh connectivity (`elements`, `edges`, etc.) and reference properties (initial measures, contents, lengths, angles) required to compute forces. It supports both 3D (volumetric bodies bounded by triangles) and 2D (planar bodies bounded by line segments). Indices in elements, edges, etc. correspond to the ID of particles in :class:`~jaxdem.state.State`. The general form of the deformable particle potential energy per particle is: .. math:: &E_K = E_{K,measure} + E_{K,content} + E_{K,bending} + E_{K,edge} + E_{K,strain} **Definitions per Dimension:** * **3D:** Measure ($\mathcal{M}$) is Face Area; Content ($\mathcal{C}$) is Volume; Elements are Triangles. * **2D:** Measure ($\mathcal{M}$) is Segment Length; Content ($\mathcal{C}$) is Enclosed Area; Elements are Segments. Strain Energy (StVK) on Elements (Triangles): .. math:: W = A_0 \cdot \left( \mu \mathrm{tr}(E^2) + \frac{\lambda}{2} (\mathrm{tr} E)^2 \right) All mesh properties are concatenated along axis=0. """ # --- Topology --- elements: Optional[jax.Array] """ Array of vertex indices forming the boundary elements. Shape: (M, 3) for 3D (Triangles) or (M, 2) for 2D (Segments). Indices refer to the particle unique ID corresponding to the `State.pos` array. """ edges: Optional[jax.Array] """ Array of vertex indices forming the unique edges (wireframe). Shape: (E, 2). Each row contains the indices of the two vertices forming the edge. Note: In 2D, the set of edges often overlaps with the set of elements (segments). """ element_adjacency: Optional[jax.Array] """ Array of element adjacency pairs (for bending/dihedral angles). Shape: (A, 2). Each row contains the indices of the two elements sharing a connection. 
""" element_adjacency_edges: Optional[jax.Array] """ Array of vertex IDs forming the shared edge for each adjacency. Shape: (A, 2). """ # --- ID Mappings --- elements_ID: Optional[jax.Array] """ Array of body IDs for each boundary element. Shape: (M,). `elements_ID[i] == k` means element `i` belongs to body `k`. """ edges_ID: Optional[jax.Array] """ Array of body IDs for each unique edge. Shape: (E,). `edges_ID[e] == k` means edge `e` belongs to body `k`. """ element_adjacency_ID: Optional[jax.Array] """ Array of body IDs for each adjacency (bending hinge). Shape: (A,). `element_adjacency_ID[a] == k` means adjacency `a` belongs to body `k`. """ num_bodies: int """ Total number of distinct deformable bodies in the container. Shape: (K,). """ # --- Reference Configuration --- initial_body_contents: Optional[jax.Array] """ Array of reference (stress-free) bulk content for each body. Shape: (K,). Represents Volume in 3D or Area in 2D. """ initial_element_measures: Optional[jax.Array] """ Array of reference (stress-free) measures for each element. Shape: (M,). Represents Area in 3D or Length in 2D. """ initial_edge_lengths: Optional[jax.Array] """ Array of reference (stress-free) lengths for each unique edge. Shape: (E,). """ initial_bending: Optional[jax.Array] """ Array of reference (stress-free) bending angles for each adjacency. Shape: (A,). Represents Dihedral Angle in 3D or Vertex Angle in 2D. """ # --- Strain Energy Data (SVK) --- inv_ref_shape: Optional[jax.Array] """ Inverse of the reference shape matrix for each element. Shape: (M, 2, 2) for triangles, or (M, 1, 1) for segments. Used to compute the deformation gradient F or Green strain E. """ inv_ref_tet_shape: Optional[jax.Array] """ Inverse of the reference shape matrix for tetrahedra formed by each boundary triangle and the corresponding body center. Shape: (M, 3, 3). 
""" initial_tet_volumes: Optional[jax.Array] """ Reference volumes for tetrahedra formed by each boundary triangle and the corresponding body center. Shape: (M,). """ # --- Coefficients --- em: Optional[jax.Array] """ Measure elasticity coefficient (Modulus) for each body. Shape: (K,). (Controls Area stiffness in 3D; Length stiffness in 2D). """ ec: Optional[jax.Array] """ Content elasticity coefficient (Modulus) for each body. Shape: (K,). (Controls Volume stiffness in 3D; Area stiffness in 2D). """ eb: Optional[jax.Array] """ Bending elasticity coefficient (Rigidity) for each body. Shape: (K,). """ el: Optional[jax.Array] """ Edge length elasticity coefficient (Modulus) for each body. Shape: (K,). """ gamma: Optional[jax.Array] """ Surface/Line tension coefficient for each body. Shape: (K,). """ lame_lambda: Optional[jax.Array] """ First Lamé parameter for StVK model. Shape: (K,). """ lame_mu: Optional[jax.Array] """ Second Lamé parameter (Shear Modulus) for StVK model. Shape: (K,). """ use_tetrahedral_svk: bool """ If True, compute StVK strain energy on tetrahedra formed by each boundary triangle and the mesh center of its body (3D only). If False, use the existing shell-like element StVK model. """
[docs] @staticmethod @partial(jax.named_call, name="DeformableParticleContainer.create") def create( vertices: ArrayLike, *, elements: Optional[ArrayLike] = None, edges: Optional[ArrayLike] = None, element_adjacency: Optional[ArrayLike] = None, # ID mappings elements_ID: Optional[ArrayLike] = None, edges_ID: Optional[ArrayLike] = None, element_adjacency_ID: Optional[ArrayLike] = None, # Reference states (computed if None) initial_element_measures: Optional[ArrayLike] = None, initial_body_contents: Optional[ArrayLike] = None, initial_bending: Optional[ArrayLike] = None, initial_edge_lengths: Optional[ArrayLike] = None, # Coefficients em: Optional[ArrayLike] = None, ec: Optional[ArrayLike] = None, eb: Optional[ArrayLike] = None, el: Optional[ArrayLike] = None, gamma: Optional[ArrayLike] = None, lame_lambda: Optional[ArrayLike] = None, lame_mu: Optional[ArrayLike] = None, use_tetrahedral_svk: bool = False, ) -> DeformableParticleContainer: r""" Factory method to create a new :class:`DeformableParticleContainer`. Calculates initial geometric properties (areas, volumes, bending angles, and edge lengths) from the provided `vertices` if they are not explicitly provided. """ v_ref = jnp.asarray(vertices, dtype=float) dim = v_ref.shape[-1] # 1. Standardize Topologies elements = jnp.asarray(elements, dtype=int) if elements is not None else None edges = jnp.asarray(edges, dtype=int) if edges is not None else None element_adjacency = ( jnp.asarray(element_adjacency, dtype=int) if element_adjacency is not None else None ) # 2. 
Densify Body IDs ids_to_check = [ x for x in [elements_ID, edges_ID, element_adjacency_ID] if x is not None ] if ids_to_check: unique_ids, dense_map = jnp.unique( jnp.concatenate([jnp.atleast_1d(x) for x in ids_to_check]), return_inverse=True, ) num_bodies = unique_ids.size # Helper to extract mapped slices cursor = 0 def get_dense(orig: Optional[ArrayLike]) -> Optional[jax.Array]: nonlocal cursor if orig is None: return None size = jnp.atleast_1d(orig).size res = dense_map[cursor : cursor + size] cursor += size return res elements_ID = get_dense(elements_ID) edges_ID = get_dense(edges_ID) element_adjacency_ID = get_dense(element_adjacency_ID) else: num_bodies = 1 elements_ID = ( jnp.zeros(elements.shape[0], dtype=int) if elements is not None else None ) edges_ID = ( jnp.zeros(edges.shape[0], dtype=int) if edges is not None else None ) element_adjacency_ID = ( jnp.zeros(element_adjacency.shape[0], dtype=int) if element_adjacency is not None else None ) # 3. Geometric computations # Trigger calculation if any relevant coeff is present or if purely creating topology calc_geom = True adj_edges = None if calc_geom and elements is not None: compute_fn = ( compute_element_properties_3D if dim == 3 else compute_element_properties_2D ) norms, measures, partial_contents = jax.vmap(compute_fn)(v_ref[elements]) initial_element_measures = ( jnp.asarray(initial_element_measures) if initial_element_measures is not None else measures ) if elements_ID is not None: if initial_body_contents is None: initial_body_contents = jax.ops.segment_sum( partial_contents, elements_ID, num_segments=num_bodies ) else: initial_body_contents = jnp.asarray(initial_body_contents) if element_adjacency is not None: # Bending angles n1, n2 = norms[element_adjacency[:, 0]], norms[element_adjacency[:, 1]] initial_bending = ( jnp.asarray(initial_bending) if initial_bending is not None else jax.vmap(angle_between_normals)(n1, n2) ) # 3D Winding preservation for bending if dim == 3: f1 = 
elements[element_adjacency[:, 0]] matches = ( f1[:, :, None] == elements[element_adjacency[:, 1]][:, None, :] ) missing_idx = jnp.argmin( jnp.any(matches, axis=2).astype(int), axis=1 ) v_start = jnp.take_along_axis( f1, ((missing_idx + 1) % 3)[:, None], axis=1 ) v_end = jnp.take_along_axis( f1, ((missing_idx + 2) % 3)[:, None], axis=1 ) adj_edges = jnp.concatenate([v_start, v_end], axis=1) if edges is not None: lengths = jnp.linalg.norm(v_ref[edges[:, 1]] - v_ref[edges[:, 0]], axis=-1) initial_edge_lengths = ( jnp.asarray(initial_edge_lengths) if initial_edge_lengths is not None else lengths ) # 4. Precompute SVK Reference Shape Inverses inv_ref_shape = None inv_ref_tet_shape = None initial_tet_volumes = None if (lame_lambda is not None or lame_mu is not None) and elements is not None: inv_ref_shape = jax.vmap(compute_inverse_reference_shape)(v_ref[elements]) if use_tetrahedral_svk and dim == 3 and elements.shape[1] == 3: ref_centers = compute_body_centers_from_elements( v_ref, elements, elements_ID, num_bodies ) ref_tets = jnp.concatenate( [ref_centers[elements_ID][:, None, :], v_ref[elements]], axis=1 ) inv_ref_tet_shape = jax.vmap(compute_inverse_reference_shape_tet)( ref_tets ) initial_tet_volumes = jax.vmap(compute_tetra_volume)(ref_tets) return DeformableParticleContainer( elements=elements, edges=edges, element_adjacency=element_adjacency, element_adjacency_edges=adj_edges, elements_ID=jnp.asarray(elements_ID) if elements_ID is not None else None, edges_ID=jnp.asarray(edges_ID) if edges_ID is not None else None, element_adjacency_ID=( jnp.asarray(element_adjacency_ID) if element_adjacency_ID is not None else None ), num_bodies=num_bodies, initial_element_measures=( jnp.asarray(initial_element_measures) if initial_element_measures is not None else None ), initial_body_contents=( jnp.asarray(initial_body_contents) if initial_body_contents is not None else None ), initial_bending=( jnp.asarray(initial_bending) if initial_bending is not None else None ), 
initial_edge_lengths=( jnp.asarray(initial_edge_lengths) if initial_edge_lengths is not None else None ), inv_ref_shape=inv_ref_shape, inv_ref_tet_shape=inv_ref_tet_shape, initial_tet_volumes=initial_tet_volumes, em=jnp.asarray(em) if em is not None else None, ec=jnp.asarray(ec) if ec is not None else None, eb=jnp.asarray(eb) if eb is not None else None, el=jnp.asarray(el) if el is not None else None, gamma=jnp.asarray(gamma) if gamma is not None else None, lame_lambda=jnp.asarray(lame_lambda) if lame_lambda is not None else None, lame_mu=jnp.asarray(lame_mu) if lame_mu is not None else None, use_tetrahedral_svk=bool(use_tetrahedral_svk), )
[docs] @staticmethod @partial(jax.named_call, name="DeformableParticleContainer.merge") def merge( c1: DeformableParticleContainer, c2: DeformableParticleContainer, ) -> DeformableParticleContainer: r""" Merges two :class:`DeformableParticleContainer` instances. """ if c2.elements_ID is not None: c2.elements_ID += c1.num_bodies if c2.edges_ID is not None: c2.edges_ID += c1.num_bodies if c2.element_adjacency_ID is not None: c2.element_adjacency_ID += c1.num_bodies def cat(a: jax.Array, b: jax.Array) -> jax.Array: if isinstance(a, jax.Array) and isinstance(b, jax.Array): return jnp.concatenate((a, b), axis=0) elif a is None and b is None: return None else: return a if a is not None else b merged = jax.tree_util.tree_map(cat, c1, c2) merged.num_bodies = c1.num_bodies + c2.num_bodies return merged
[docs] @staticmethod @partial(jax.named_call, name="DeformableParticleContainer.add") def add( container: DeformableParticleContainer, vertices: ArrayLike, *, elements_ID: Optional[ArrayLike] = None, elements: Optional[ArrayLike] = None, initial_element_measures: Optional[ArrayLike] = None, element_adjacency_ID: Optional[ArrayLike] = None, element_adjacency: Optional[ArrayLike] = None, initial_bending: Optional[ArrayLike] = None, edges_ID: Optional[ArrayLike] = None, edges: Optional[ArrayLike] = None, initial_edge_lengths: Optional[ArrayLike] = None, initial_body_contents: Optional[ArrayLike] = None, em: Optional[ArrayLike] = None, ec: Optional[ArrayLike] = None, eb: Optional[ArrayLike] = None, el: Optional[ArrayLike] = None, gamma: Optional[ArrayLike] = None, lame_lambda: Optional[ArrayLike] = None, lame_mu: Optional[ArrayLike] = None, use_tetrahedral_svk: bool = False, ) -> DeformableParticleContainer: r""" Factory method to add bodies to a container. """ new_part = DeformableParticleContainer.create( vertices=vertices, elements=elements, edges=edges, elements_ID=elements_ID, initial_element_measures=initial_element_measures, element_adjacency_ID=element_adjacency_ID, element_adjacency=element_adjacency, initial_bending=initial_bending, edges_ID=edges_ID, initial_edge_lengths=initial_edge_lengths, initial_body_contents=initial_body_contents, em=em, ec=ec, eb=eb, el=el, gamma=gamma, lame_lambda=lame_lambda, lame_mu=lame_mu, use_tetrahedral_svk=use_tetrahedral_svk, ) return DeformableParticleContainer.merge(container, new_part)
[docs] @staticmethod def compute_potential_energy( pos: jax.Array, state: State, _system: System, container: DeformableParticleContainer, ) -> Tuple[jax.Array, Dict[str, jax.Array]]: vertices = pos dim = state.dim if dim == 3: compute_element_properties = compute_element_properties_3D elif dim == 2: compute_element_properties = compute_element_properties_2D else: raise ValueError( f"DeformableParticleContainer only supports 2D or 3D, got dim={dim}." ) idx_map = ( jnp.zeros((state.N,), dtype=int) .at[state.unique_ID] .set(jnp.arange(state.N)) ) E_element = jnp.array(0.0, dtype=float) E_content = jnp.array(0.0, dtype=float) E_gamma = jnp.array(0.0, dtype=float) E_bending = jnp.array(0.0, dtype=float) E_edge = jnp.array(0.0, dtype=float) E_strain = jnp.array(0.0, dtype=float) current_element_indices = idx_map[container.elements] element_normal, element_measure, partial_content = jax.vmap( compute_element_properties )(vertices[current_element_indices]) # Element elastic energy if ( container.em is not None and container.initial_element_measures is not None and container.elements_ID is not None ): # (M - M0)^2 / M0 diff = element_measure - container.initial_element_measures norm_strain_energy = jnp.square(diff) / container.initial_element_measures temp_elements = jax.ops.segment_sum( norm_strain_energy, container.elements_ID, num_segments=container.num_bodies, ) E_element = 0.5 * jnp.sum(container.em * temp_elements) # Content elastic energy if ( container.ec is not None and container.initial_body_contents is not None and container.elements_ID is not None ): content = jax.ops.segment_sum( partial_content, container.elements_ID, num_segments=container.num_bodies, ) # (V - V0)^2 / V0 diff = content - container.initial_body_contents norm_vol_energy = jnp.square(diff) / container.initial_body_contents E_content = 0.5 * jnp.sum(container.ec * norm_vol_energy) # Surface tension if container.gamma is not None and container.elements_ID is not None: element = 
jax.ops.segment_sum( element_measure, container.elements_ID, num_segments=container.num_bodies, ) E_gamma = -jnp.sum(container.gamma * element) # Bending energy if ( container.eb is not None and container.element_adjacency is not None and container.initial_bending is not None and container.element_adjacency_ID is not None ): n1 = element_normal[container.element_adjacency[:, 0]] n2 = element_normal[container.element_adjacency[:, 1]] cos = jnp.sum(n1 * n2, axis=-1) if dim == 3 and container.element_adjacency_edges is not None: hinge_idx = idx_map[container.element_adjacency_edges] # (A, 2) h_verts = vertices[hinge_idx] # (A, 2, 3) tangent_vec = h_verts[:, 1, :] - h_verts[:, 0, :] tangent = unit(tangent_vec) cross_prod = cross(n1, n2) sin = jnp.sum(cross_prod * tangent, axis=-1) else: sin = cross(n1, n2) sin = jnp.squeeze(sin) bending = jnp.atan2(sin, cos) diff = bending - container.initial_bending temp_angles = jax.ops.segment_sum( jnp.square(diff), container.element_adjacency_ID, num_segments=container.num_bodies, ) E_bending = 0.5 * jnp.sum(container.eb * temp_angles) / 2 # Edge length energy if ( container.el is not None and container.edges is not None and container.initial_edge_lengths is not None and container.edges_ID is not None ): current_edge_indices = idx_map[container.edges] edge_vecs = ( vertices[current_edge_indices[:, 0]] - vertices[current_edge_indices[:, 1]] ) edge_lengths2 = jnp.sum(edge_vecs * edge_vecs, axis=-1) edge_lengths = jnp.sqrt(edge_lengths2) diff = edge_lengths - container.initial_edge_lengths norm_edge_energy = jnp.square(diff) / container.initial_edge_lengths temp_edges = jax.ops.segment_sum( norm_edge_energy, container.edges_ID, num_segments=container.num_bodies, ) E_edge = 0.5 * jnp.sum(container.el * temp_edges) # StVK Strain Energy (Corrected per-element) if ( container.lame_lambda is not None and container.lame_mu is not None and container.elements is not None and container.elements_ID is not None ): if ( 
container.use_tetrahedral_svk and dim == 3 and container.inv_ref_tet_shape is not None and container.initial_tet_volumes is not None ): curr_verts = vertices[current_element_indices] curr_centers = compute_body_centers_from_elements( vertices, current_element_indices, container.elements_ID, container.num_bodies, ) curr_d_vecs = jnp.swapaxes( curr_verts - curr_centers[container.elements_ID][:, None, :], -1, -2 ) F = curr_d_vecs @ container.inv_ref_tet_shape C = jnp.swapaxes(F, -1, -2) @ F E = 0.5 * (C - jnp.eye(3)) tr_E = jnp.trace(E, axis1=-2, axis2=-1) tr_E2 = jnp.sum(E * E, axis=(-1, -2)) mu = container.lame_mu[container.elements_ID] lam = container.lame_lambda[container.elements_ID] W = mu * tr_E2 + 0.5 * lam * (tr_E**2) E_strain = jnp.sum(W * container.initial_tet_volumes) elif container.inv_ref_shape is not None: # Compute deformation using vectorized batch operations (M, dim, rank) # 1. Gather current vertices: (M, rank+1, dim) curr_verts = vertices[current_element_indices] # 2. Compute current edge vectors d: (M, dim, rank) # d_j = x_{j+1} - x_0 d_vecs = jnp.swapaxes(curr_verts[:, 1:] - curr_verts[:, 0:1], -1, -2) # 3. Compute Deformation Gradient F = d @ D_inv # d_vecs: (M, dim, rank), inv_ref_shape: (M, rank, rank) -> F: (M, dim, rank) F = d_vecs @ container.inv_ref_shape # 4. Compute Green-Lagrange Strain E = 0.5 * (F.T @ F - I) # C = F.T @ F (Right Cauchy-Green, pulled back to local 2D/1D ref manifold) C = jnp.swapaxes(F, -1, -2) @ F rank = container.inv_ref_shape.shape[-1] I = jnp.eye(rank) E = 0.5 * (C - I) # 5. Compute Invariants # tr(E): Trace of (M, rank, rank) -> (M,) tr_E = jnp.trace(E, axis1=-2, axis2=-1) # tr(E^2) = sum(E_ij * E_ji) -> sum(E_ij^2) for symmetric E tr_E2 = jnp.sum(E * E, axis=(-1, -2)) # 6. 
Map coefficients and compute Energy Density W mu = container.lame_mu[container.elements_ID] lam = container.lame_lambda[container.elements_ID] # Energy Density W (Energy per unit measure) # W = mu * tr(E^2) + 0.5 * lambda * tr(E)^2 W = mu * tr_E2 + 0.5 * lam * (tr_E**2) # 7. Total Strain Energy # E_total = sum(W_i * A0_i) # Note: Thickness is excluded as requested. E_strain = jnp.sum(W * container.initial_element_measures) aux = dict( E_element=E_element, E_content=E_content, E_gamma=E_gamma, E_bending=E_bending, E_edge=E_edge, E_strain=E_strain, ) return E_element + E_content + E_gamma + E_bending + E_edge + E_strain, aux
[docs] @staticmethod def create_force_function( container: DeformableParticleContainer, ) -> ForceFunction: force_fn, _ = DeformableParticleContainer.create_force_energy_functions( container ) return force_fn
[docs] @staticmethod def create_force_energy_functions( container: DeformableParticleContainer, ) -> Tuple[ForceFunction, EnergyFunction]: def force_function( pos: jax.Array, state: State, system: System ) -> Tuple[jax.Array, jax.Array]: energy_grad, _ = jax.grad( DeformableParticleContainer.compute_potential_energy, has_aux=True )(pos, state, system, container) return -energy_grad, jnp.zeros_like(state.torque) def energy_function(pos: jax.Array, state: State, system: System) -> jax.Array: total, _ = DeformableParticleContainer.compute_potential_energy( pos, state, system, container ) idx_map = ( jnp.zeros((state.N,), dtype=int) .at[state.unique_ID] .set(jnp.arange(state.N)) ) mask = jnp.zeros((state.N,), dtype=bool) if container.elements is not None: mask = mask.at[idx_map[container.elements].reshape(-1)].set(True) if container.edges is not None: mask = mask.at[idx_map[container.edges].reshape(-1)].set(True) count = jnp.sum(mask.astype(float)) count = jnp.where(count == 0.0, 1.0, count) return (total / count) * mask.astype(float) return force_function, energy_function
def angle_between_normals(n1: jax.Array, n2: jax.Array) -> jax.Array:
    r"""
    Unsigned angle between two normals.

    The angle is ``atan2(|n1 x n2|, n1 . n2)``, so it lies in ``[0, pi]``
    regardless of orientation.

    Parameters
    ----------
    n1 : jax.Array
        First normal vector.
    n2 : jax.Array
        Second normal vector.

    Returns
    -------
    jax.Array
        Angle between the two normals in radians.
    """
    perp = cross(n1, n2)
    sin_magnitude = jnp.sqrt(jnp.sum(perp * perp, axis=-1))
    cos_value = jnp.sum(n1 * n2, axis=-1)
    return jnp.atan2(sin_magnitude, cos_value)
def compute_element_properties_3D(
    simplex: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    r"""
    Normal, area and signed partial volume of one triangle.

    The partial volume is ``(x0 . n_area) / 3``, where ``n_area`` is the
    half cross product of the two edges from vertex 0. Summed over a closed
    triangulated surface, these partial volumes give the enclosed (signed)
    volume.

    Parameters
    ----------
    simplex : jax.Array
        Shape (3, 3): the coordinates of the 3 vertices.

    Returns
    -------
    Tuple[jax.Array, jax.Array, jax.Array]
        (unit_normal, area, partial_volume). A degenerate (zero-area)
        triangle yields a zero normal instead of NaN.
    """
    origin = simplex[0]
    area_vector = cross(simplex[1] - origin, simplex[2] - origin) / 2
    area = jnp.sqrt(jnp.sum(area_vector * area_vector, axis=-1))
    partial_volume = jnp.sum(area_vector * origin, axis=-1) / 3
    safe_area = jnp.where(area == 0, 1, area)
    return area_vector / safe_area, area, partial_volume
def compute_element_properties_2D(
    simplex: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    r"""
    Normal, length and signed partial area of one segment.

    The normal is the edge vector rotated by -90 degrees, ``(e_y, -e_x)``,
    then normalized. The partial area is ``0.5 * (r1 x r2)``; summed over a
    closed polygon, these give its signed enclosed area.

    Parameters
    ----------
    simplex : jax.Array
        Shape (2, 2): the coordinates of the 2 vertices.

    Returns
    -------
    Tuple[jax.Array, jax.Array, jax.Array]
        (unit_normal, length, partial_area). A zero-length segment yields a
        zero normal instead of NaN.
    """
    start, end = simplex[0], simplex[1]
    edge = end - start
    length = jnp.sqrt(jnp.sum(edge * edge, axis=-1))
    normal = jnp.stack([edge[1], -edge[0]]) / jnp.where(length == 0, 1.0, length)
    partial_area = 0.5 * (start[0] * end[1] - start[1] * end[0])
    return normal, length, partial_area
def compute_inverse_reference_shape(simplex: jax.Array) -> jax.Array:
    """
    Inverse of the reference shape matrix of one boundary element.

    Triangle (3 vertices): edges ``d1 = X2 - X1`` and ``d2 = X3 - X1`` are
    expressed in a local orthonormal basis ``(u, v)``, with ``u`` along ``d1``
    and ``v`` obtained by Gram-Schmidt from ``d2``. This gives the upper
    triangular matrix ``B = [[|d1|, d2.u], [0, d2.v]]``, which is inverted
    analytically. If ``|det B| < 1e-12``, the determinant is replaced by 1 to
    avoid NaN.

    Segment (2 vertices): returns ``[[1 / L0]]``. If ``L0 < 1e-12``, ``L0`` is
    replaced by 1.

    Any other vertex count returns ``eye(n_verts - 1)``.

    Returns:
        (2, 2) matrix for triangles, or (1, 1) for segments.
    """
    num_vertices = simplex.shape[0]

    if num_vertices == 2:
        rest_length = jnp.linalg.norm(simplex[1] - simplex[0])
        rest_length = jnp.where(rest_length < 1e-12, 1.0, rest_length)
        return jnp.array([[1.0 / rest_length]])

    if num_vertices != 3:
        # Fallback (not expected for the supported element shapes).
        return jnp.eye(num_vertices - 1)

    edge_a = simplex[1] - simplex[0]
    edge_b = simplex[2] - simplex[0]
    axis_u = unit(edge_a)
    axis_v = unit(edge_b - jnp.dot(edge_b, axis_u) * axis_u)

    b11 = jnp.linalg.norm(edge_a)
    b12 = jnp.dot(edge_b, axis_u)
    b22 = jnp.dot(edge_b, axis_v)

    determinant = b11 * b22
    determinant = jnp.where(jnp.abs(determinant) < 1e-12, 1.0, determinant)
    return (1.0 / determinant) * jnp.array([[b22, -b12], [0.0, b11]])
def compute_body_centers_from_elements(
    vertices: jax.Array,
    elements: jax.Array,
    elements_ID: jax.Array,
    num_bodies: int,
) -> jax.Array:
    """
    Mean position of the distinct vertices referenced by each body's elements.

    A vertex that appears in several elements of the same body is counted
    once. Bodies with no elements get a center of zero.

    Parameters
    ----------
    vertices : jax.Array
        Vertex positions, shape (N, dim), indexed by the values in `elements`.
    elements : jax.Array
        Vertex indices per element, shape (M, k).
    elements_ID : jax.Array
        Body index of each element, shape (M,).
    num_bodies : int
        Number of bodies K (must be static).

    Returns
    -------
    jax.Array
        Body centers, shape (K, dim).
    """
    dtype = vertices.dtype
    owner = jnp.broadcast_to(elements_ID[:, None], elements.shape)
    membership = (
        jnp.zeros((num_bodies, vertices.shape[0]), dtype=bool)
        .at[owner, elements]
        .set(True)
        .astype(dtype)
    )
    vertex_counts = membership.sum(axis=1)
    vertex_counts = jnp.where(vertex_counts == 0.0, 1.0, vertex_counts)
    return (membership @ vertices) / vertex_counts[:, None]
def compute_inverse_reference_shape_tet(tet: jax.Array) -> jax.Array:
    """
    Inverse of a tetrahedron's reference shape matrix.

    The shape matrix has columns ``tet[i] - tet[0]`` for i = 1, 2, 3. When
    ``|det| < 1e-12`` (degenerate tetrahedron), the identity is inverted
    instead, so the result is the identity matrix.

    Parameters
    ----------
    tet : jax.Array
        Shape (4, 3): the four vertex positions.

    Returns
    -------
    jax.Array
        Shape (3, 3) inverse shape matrix.
    """
    origin = tet[0]
    shape_matrix = jnp.stack([tet[i] - origin for i in (1, 2, 3)], axis=-1)
    degenerate = jnp.abs(jnp.linalg.det(shape_matrix)) < 1e-12
    return jnp.linalg.inv(jnp.where(degenerate, jnp.eye(3), shape_matrix))
def compute_tetra_volume(tet: jax.Array) -> jax.Array:
    """
    Unsigned volume of a tetrahedron given its four vertices.

    Computed as ``|d1 . (d2 x d3)| / 6``, where ``d_i = tet[i] - tet[0]``.

    Parameters
    ----------
    tet : jax.Array
        Shape (4, 3): the four vertex positions.

    Returns
    -------
    jax.Array
        Scalar volume.
    """
    base = tet[0]
    edge1, edge2, edge3 = (tet[k] - base for k in range(1, 4))
    triple_product = jnp.dot(edge1, cross(edge2, edge3))
    return jnp.abs(triple_product) / 6.0