jaxdem.utils
Utility functions used to set up simulations and analyze the output.
- class jaxdem.utils.Quaternion(w: Array, xyz: Array)
Bases: object
Quaternion representation for 2D and 3D particle orientations.
A quaternion \(q\) is represented as a scalar part \(w\) and a vector part \(\vec{v}_{xyz} = (x, y, z)\):
\[q = w + x\mathbf{i} + y\mathbf{j} + z\mathbf{k}\]
In 2D, the rotation axis is restricted to the z-axis: \(\vec{v}_{xyz} = (0, 0, z)\).
- w
The scalar component of the quaternion. Shape is (…, N, 1).
- Type:
jax.Array
- xyz
The vector components of the quaternion. Shape is (…, N, 3).
- Type:
jax.Array
- w: Array
- xyz: Array
- static create(w: Array | ndarray | bool | number | bool | int | float | complex | None = None, xyz: Array | ndarray | bool | number | bool | int | float | complex | None = None) Quaternion
Create a Quaternion instance.
- Parameters:
w (ArrayLike, optional) – The scalar component. If None, defaults to ones.
xyz (ArrayLike, optional) – The vector component. If None, defaults to zeros.
- Returns:
The created quaternion.
- Return type:
Quaternion
- static unit(q: Quaternion) Quaternion
Normalize a quaternion to have unit norm.
\[q_{unit} = \frac{q}{\|q\|} = \frac{q}{\sqrt{w^2 + x^2 + y^2 + z^2}}\]
- Parameters:
q (Quaternion) – The quaternion to normalize.
- Returns:
The normalized unit quaternion.
- Return type:
Quaternion
- static conj(q: Quaternion) Quaternion
Compute the conjugate of a quaternion.
\[q^* = w - x\mathbf{i} - y\mathbf{j} - z\mathbf{k}\]
- Parameters:
q (Quaternion) – The quaternion.
- Returns:
The conjugate quaternion.
- Return type:
Quaternion
- static inv(q: Quaternion) Quaternion
Compute the inverse of a quaternion.
The inverse is the conjugate divided by the squared norm. For a unit quaternion it reduces to the conjugate:
\[q^{-1} = \frac{q^*}{\|q\|^2}\]
- Parameters:
q (Quaternion) – The quaternion.
- Returns:
The inverse quaternion.
- Return type:
Quaternion
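The normalization, conjugate and inverse identities above can be checked numerically. The following is a standalone NumPy sketch, not jaxdem's implementation. It stores a quaternion as a plain `(w, x, y, z)` array, and the helper names `qmul`, `qconj`, `qunit` and `qinv` are hypothetical. It checks that \(q\,q^{-1}\) is the identity quaternion and that the inverse of a unit quaternion equals its conjugate.

```python
import numpy as np

def qmul(p, q):
    """Hamilton product of two quaternions stored as (w, x, y, z)."""
    pw, pv = p[0], p[1:]
    qw, qv = q[0], q[1:]
    w = pw * qw - np.dot(pv, qv)
    v = pw * qv + qw * pv + np.cross(pv, qv)
    return np.concatenate([[w], v])

def qconj(q):
    # q* = w - x i - y j - z k
    return np.concatenate([[q[0]], -q[1:]])

def qunit(q):
    # q / ||q|| with ||q|| = sqrt(w^2 + x^2 + y^2 + z^2)
    return q / np.linalg.norm(q)

def qinv(q):
    # q^{-1} = q* / ||q||^2, valid for any non-zero quaternion
    return qconj(q) / np.dot(q, q)

q = np.array([1.0, 2.0, -0.5, 0.25])
identity = qmul(q, qinv(q))   # (1, 0, 0, 0) up to rounding
u = qunit(q)
```

For a unit quaternion `u`, `qinv(u)` and `qconj(u)` differ only by rounding. This is why rotations with unit quaternions can use the cheaper conjugate.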
- static rotate(q: Quaternion, v: Array) Array
Rotates a vector \(\vec{v}\) from the body reference frame to the lab reference frame.
In 3D, the rotation of a vector \(\vec{v}\) by a unit quaternion \(q = (w, \vec{q}_{xyz})\) is:
\[\vec{v}' = \vec{v} + 2 w (\vec{q}_{xyz} \times \vec{v}) + 2 (\vec{q}_{xyz} \times (\vec{q}_{xyz} \times \vec{v}))\]
In 2D, where rotation is restricted to the z-axis, the rotation by angle \(\theta\) (corresponding to quaternion components \(w = \cos(\theta/2)\) and \(q_z = \sin(\theta/2)\)) is:
\[\begin{split}x' &= x \cos(\theta) - y \sin(\theta) \\ y' &= x \sin(\theta) + y \cos(\theta)\end{split}\]
- Parameters:
q (Quaternion) – The rotation quaternion.
v (jax.Array) – The vector to rotate. Shape is (…, dim).
- Returns:
The rotated vector in the lab frame. Shape is (…, dim).
- Return type:
jax.Array
- static rotate_back(q: Quaternion, v: Array) Array
Rotates a vector \(\vec{v}\) from the lab reference frame to the body reference frame.
This performs the inverse rotation using the quaternion conjugate \(q^* = (w, -\vec{q}_{xyz})\):
\[\vec{v}' = \vec{v} - 2 w (\vec{q}_{xyz} \times \vec{v}) + 2 (\vec{q}_{xyz} \times (\vec{q}_{xyz} \times \vec{v}))\]
- Parameters:
q (Quaternion) – The rotation quaternion.
v (jax.Array) – The vector to rotate back. Shape is (…, dim).
- Returns:
The rotated vector in the body frame. Shape is (…, dim).
- Return type:
jax.Array
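The two rotation formulas above can be sketched in NumPy for a single quaternion, given as a scalar `w` and a 3-vector `u`. The names `rotate` and `rotate_back` here are illustrative stand-ins, not jaxdem's code. The sketch checks three things: a 90° rotation about z maps x̂ to ŷ, the 3D formula reproduces the 2D angle formula, and `rotate_back` undoes `rotate`.

```python
import numpy as np

def rotate(w, u, v):
    # v' = v + 2w (u x v) + 2 u x (u x v): body frame -> lab frame
    t = np.cross(u, v)
    return v + 2.0 * w * t + 2.0 * np.cross(u, t)

def rotate_back(w, u, v):
    # Same formula with the conjugate (w, -u): lab frame -> body frame
    t = np.cross(u, v)
    return v - 2.0 * w * t + 2.0 * np.cross(u, t)

theta = np.pi / 2
w = np.cos(theta / 2)
u = np.array([0.0, 0.0, np.sin(theta / 2)])      # axis restricted to z (2D case)

v_lab = rotate(w, u, np.array([1.0, 0.0, 0.0]))  # approximately (0, 1, 0)

# The z-axis rotation must match x' = x cos θ - y sin θ, y' = x sin θ + y cos θ
x, y = 0.3, -1.2
v2 = rotate(w, u, np.array([x, y, 0.0]))
expected_2d = np.array([x * np.cos(theta) - y * np.sin(theta),
                        x * np.sin(theta) + y * np.cos(theta)])

# Round trip with a random unit quaternion
rng = np.random.default_rng(0)
q = rng.normal(size=4)
q /= np.linalg.norm(q)
vec = rng.normal(size=3)
roundtrip = rotate_back(q[0], q[1:], rotate(q[0], q[1:], vec))
```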
- jaxdem.utils.angle(v1: Array, v2: Array) Array
Angle from v1 -> v2 in \([0, \pi]\).
Calculates the unsigned angle between two vectors \(\vec{v}_1\) and \(\vec{v}_2\) using a numerically stable half-angle formula:
\[\begin{split}\hat{v}_1 &= \text{unit}(\vec{v}_1) \\ \hat{v}_2 &= \text{unit}(\vec{v}_2) \\ y &= \|\hat{v}_1 - \hat{v}_2\| \\ x &= \|\hat{v}_1 + \hat{v}_2\| \\ \theta &= 2 \cdot \text{atan2}(y, x)\end{split}\]
- Parameters:
v1 (jax.Array) – First vector. Shape (…, dim).
v2 (jax.Array) – Second vector. Shape (…, dim).
- Returns:
Unsigned angle in radians.
- Return type:
jax.Array
- jaxdem.utils.angle_x(v1: Array) Array
Angle from v1 -> \(\hat{x}\) in \([0, \pi]\).
Calculates the unsigned angle of a vector \(\vec{v}_1\) relative to the positive x-axis \((1, 0, \dots)\):
\[\begin{split}\hat{v}_1 &= \text{unit}(\vec{v}_1) \\ y &= \sqrt{2(1 - \hat{v}_{1,x})} \\ x &= \sqrt{2(1 + \hat{v}_{1,x})} \\ \theta &= 2 \cdot \text{atan2}(y, x)\end{split}\]
- Parameters:
v1 (jax.Array) – The input vector. Shape (…, dim).
- Returns:
Unsigned angle in radians.
- Return type:
jax.Array
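The half-angle formulas avoid the precision loss that \(\arccos(\hat{v}_1 \cdot \hat{v}_2)\) suffers near 0 and \(\pi\). The following is a standalone NumPy sketch of both formulas as written above. The names `angle_sketch` and `angle_x_sketch` are hypothetical, and the sketch is not jaxdem's code. It checks the results against known right, antiparallel, parallel and 45° angles.

```python
import numpy as np

def unit(v):
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def angle_sketch(v1, v2):
    # θ = 2 atan2(|v̂1 - v̂2|, |v̂1 + v̂2|), unsigned, in [0, π]
    a, b = unit(v1), unit(v2)
    y = np.linalg.norm(a - b, axis=-1)
    x = np.linalg.norm(a + b, axis=-1)
    return 2.0 * np.arctan2(y, x)

def angle_x_sketch(v1):
    # Specialization to +x: |v̂ ∓ x̂|^2 = 2(1 ∓ v̂_x); clamp guards tiny rounding below 0
    vx = unit(v1)[..., 0]
    y = np.sqrt(np.maximum(2.0 * (1.0 - vx), 0.0))
    x = np.sqrt(np.maximum(2.0 * (1.0 + vx), 0.0))
    return 2.0 * np.arctan2(y, x)

v1 = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
v2 = np.array([[0.0, 2.0], [-3.0, 0.0], [1.0, 1.0]])
angles = angle_sketch(v1, v2)   # right angle, antiparallel, parallel
ax = angle_x_sketch(np.array([[0.0, 1.0], [-1.0, 0.0], [2.0, 2.0]]))
```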
- jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) State
- jaxdem.utils.compute_energy(state: State, system: System) jax.Array
Compute the total mechanical energy of the system.
\[E_{total} = E_{pot, total} + E_{trans, total} + E_{rot, total}\]
- jaxdem.utils.compute_particle_volume(state: State) jax.Array
Return the total particle volume.
- jaxdem.utils.compute_potential_energy(state: State, system: System) jax.Array
Compute the total potential energy of the system. The energy is collected from two sources: the force models in the collider, and those force-manager entries (gravity and force functions) that have an associated potential energy.
\[E_{pot, total} = \sum_{i} U(r_i)\]
- jaxdem.utils.compute_rotational_kinetic_energy(state: State) jax.Array
Compute the total rotational kinetic energy of the system.
\[E_{rot, total} = \sum_{i} \frac{1}{2} \vec{\omega}_i^T I_i \vec{\omega}_i\]
- Parameters:
state (State) – The current state of the system.
- Returns:
The scalar sum of rotational kinetic energy across all particles.
- Return type:
jax.Array
- jaxdem.utils.compute_rotational_kinetic_energy_per_particle(state: State) jax.Array
Compute the rotational kinetic energy per particle.
\[E_{rot} = \frac{1}{2} \vec{\omega}^T I \vec{\omega}\]
Notes
The energy of clump members is divided by the number of spheres in the clump.
- Parameters:
state (State) – The current state of the system containing inertia, orientation, and angular velocity.
- Returns:
An array containing the rotational kinetic energy for each particle.
- Return type:
jax.Array
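The quadratic form \(\frac{1}{2} \vec{\omega}^T I \vec{\omega}\) can be sketched in NumPy. The sketch assumes, hypothetically, that inertia is stored as principal moments (so \(I\) is diagonal in the body frame) and that \(\vec{\omega}\) is expressed in that same frame. jaxdem's actual storage and frame conventions may differ. The general-tensor form is included to show that both forms agree.

```python
import numpy as np

def rot_ke(inertia, omega):
    # 1/2 ω^T I ω with diagonal I (principal moments): 1/2 Σ_k I_k ω_k^2.
    # In 2D both arrays would have shape (N, 1).
    return 0.5 * np.sum(inertia * omega**2, axis=-1)

def rot_ke_full(I, omega):
    # General symmetric inertia tensor, shape (N, 3, 3)
    return 0.5 * np.einsum("...i,...ij,...j->...", omega, I, omega)

inertia = np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])  # assumed principal moments
omega = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 2.0]])    # assumed body-frame ω

per_particle = rot_ke(inertia, omega)
same = rot_ke_full(np.stack([np.diag(row) for row in inertia]), omega)
```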
- jaxdem.utils.compute_temperature(state: State, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) float
Compute the temperature for a state.
- Parameters:
state (State) – Current simulation state.
can_rotate (bool) – Whether to include rigid body rotations.
subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).
k_B (float, optional) – Boltzmann constant (default is 1.0).
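A common way to define such a temperature is the textbook equipartition estimate \(k_B T = 2 E_{kin} / n_{dof}\). Here the optional drift subtraction removes the centre-of-mass velocity and `dim` degrees of freedom. The NumPy sketch below implements that standard estimator under the assumption that jaxdem does something similar. The exact degree-of-freedom counting in jaxdem (for example, for clumps) is not confirmed by this page. `temperature_sketch` is a hypothetical name, and passing `inertia`/`omega` plays the role of `can_rotate=True`.

```python
import numpy as np

def temperature_sketch(mass, vel, inertia=None, omega=None, subtract_drift=False, k_B=1.0):
    """Equipartition estimate k_B T = 2 E_kin / n_dof (textbook form, not jaxdem's code)."""
    n, dim = vel.shape
    if subtract_drift:
        # Remove the centre-of-mass velocity; this also removes `dim` degrees of freedom.
        v_com = (mass[:, None] * vel).sum(axis=0) / mass.sum()
        vel = vel - v_com
    ke = 0.5 * np.sum(mass[:, None] * vel**2)
    n_dof = n * dim - (dim if subtract_drift else 0)
    if inertia is not None:
        # Rotational part: diagonal inertia, one column per angular degree of freedom.
        ke += 0.5 * np.sum(inertia * omega**2)
        n_dof += inertia.size
    return 2.0 * ke / (n_dof * k_B)

mass = np.ones(4)
vel = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
T = temperature_sketch(mass, vel)
T_drift = temperature_sketch(mass, vel + 5.0, subtract_drift=True)
T_rot = temperature_sketch(mass, vel, inertia=np.ones((4, 1)), omega=np.ones((4, 1)))
```

Subtracting the drift matters mostly for small systems: the `dim` removed degrees of freedom are a noticeable fraction of `N * dim` only when `N` is small.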
- jaxdem.utils.compute_translational_kinetic_energy(state: State) jax.Array
Compute the total translational kinetic energy of the system.
\[E_{trans, total} = \sum_{i} \frac{1}{2} m_i |v_i|^2\]
- Parameters:
state (State) – The current state of the system.
- Returns:
The scalar sum of translational kinetic energy across all particles.
- Return type:
jax.Array
- jaxdem.utils.compute_translational_kinetic_energy_per_particle(state: State) jax.Array
Compute the translational kinetic energy per particle.
\[E_{trans} = \frac{1}{2} m |v|^2\]
Notes
The energy of clump members is divided by the number of spheres in the clump.
- Parameters:
state (State) – The current state of the system containing particle masses and velocities.
- Returns:
An array containing the translational kinetic energy for each particle.
- Return type:
jax.Array
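The clump rule in the Notes ensures that summing the per-particle values counts each rigid clump's energy exactly once. The NumPy sketch below rests on one assumption that this page does not confirm: every member sphere stores the clump's total mass and centre-of-mass velocity. `trans_ke_per_particle` is a hypothetical name.

```python
import numpy as np

def trans_ke_per_particle(mass, vel, clump_id):
    # 1/2 m |v|^2 for every sphere
    ke = 0.5 * mass * np.sum(vel**2, axis=-1)
    # Share each clump's energy equally among its member spheres.
    _, inverse, counts = np.unique(clump_id, return_inverse=True, return_counts=True)
    return ke / counts[inverse]

# Assumed layout: members carry the clump's total mass and COM velocity.
# Clump 0 has 3 spheres; clump 1 is a single free sphere.
clump_id = np.array([0, 0, 0, 1])
mass = np.array([3.0, 3.0, 3.0, 2.0])
vel = np.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [0.0, -1.0]])

per_particle = trans_ke_per_particle(mass, vel, clump_id)
total = per_particle.sum()   # equals the sum over clumps of 1/2 M |V|^2
```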
- jaxdem.utils.count_clump_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) tuple[State, System, jax.Array]
Count force-bearing clump-level neighbors per clump.
For every pair of clumps, sums the sphere-sphere contact forces into a clump-clump total and marks the pair as “in contact” iff the resulting total force has nonzero norm. Returns, per clump, the number of such neighbors. This matches the clump-pair convention used by compute_clump_pair_friction(): two clumps count as in contact when their net contact interaction is nonzero.
- Parameters:
state (State) – Current simulation state.
system (System) – System definition.
cutoff (float, optional) – Neighbor search cutoff distance.
max_neighbors (int, optional) – Maximum number of neighbors per particle.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
contacts (jax.Array) – (N_clumps,) integer array of force-bearing clump-level neighbors per clump.
- jaxdem.utils.count_vertex_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) tuple[State, System, jax.Array]
Count force-bearing vertex-level contacts per clump.
For each clump, returns the number of sphere-sphere contacts with nonzero contact force that involve one of the clump’s vertex spheres. Each unique physical contact between clumps \(I\) and \(J\) increments both clump \(I\)’s and clump \(J\)’s count by one (the neighbor list lists the pair in both directions).
This is the contact-count quantity entering the Maxwell / isostaticity condition: the sum over clumps equals twice the number of distinct force-bearing vertex contacts, so the mean over clumps is the average coordination number \(Z\).
- Parameters:
state (State) – Current simulation state.
system (System) – System definition.
cutoff (float, optional) – Neighbor search cutoff distance.
max_neighbors (int, optional) – Maximum number of neighbors per particle.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
contacts (jax.Array) – (N_clumps,) integer array of force-bearing vertex contacts per clump.
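The two contact conventions can disagree, and a small example shows how. The NumPy sketch below assumes a pair list of the kind returned by get_pair_forces_and_ids(): every contact appears in both directions, and `forces[k]` acts on sphere `pair_ids[k, 0]`. The same-clump exclusion and the `tol` threshold are also assumptions. `count_contacts` is a hypothetical helper, not jaxdem's implementation. In the example, clumps 0 and 1 touch at two vertices whose forces cancel. Both contacts count at the vertex level, but the clump pair does not count at the clump level.

```python
import numpy as np

def count_contacts(pair_ids, forces, clump_of_sphere, n_clumps, tol=0.0):
    """Return (vertex_contacts, clump_neighbors) per clump from a both-direction pair list."""
    ci = clump_of_sphere[pair_ids[:, 0]]
    cj = clump_of_sphere[pair_ids[:, 1]]
    inter = ci != cj                                   # ignore intra-clump pairs
    active = inter & (np.linalg.norm(forces, axis=-1) > tol)

    # Vertex level: each force-bearing sphere-sphere contact counts once for clump i.
    vertex = np.bincount(ci[active], minlength=n_clumps)

    # Clump level: sum forces per ordered clump pair, then test the net force.
    net = np.zeros((n_clumps, n_clumps, forces.shape[1]))
    np.add.at(net, (ci[inter], cj[inter]), forces[inter])
    clump = np.sum(np.linalg.norm(net, axis=-1) > tol, axis=1)
    return vertex, clump

clump_of_sphere = np.array([0, 0, 1, 1, 2])
pair_ids = np.array([[0, 2], [2, 0], [1, 3], [3, 1], [1, 4], [4, 1]])
forces = np.array([[1.0, 0.0], [-1.0, 0.0],    # clump 0 <-> clump 1, first vertex
                   [-1.0, 0.0], [1.0, 0.0],    # clump 0 <-> clump 1, cancels the first
                   [0.0, 1.0], [0.0, -1.0]])   # clump 0 <-> clump 2
vertex, clump = count_contacts(pair_ids, forces, clump_of_sphere, n_clumps=3)
```

The vertex counts sum to twice the number of distinct contacts (here 6 = 2 × 3), so `vertex.mean()` is the coordination number \(Z\) described above.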
- jaxdem.utils.cross(a: Array, b: Array) Array
Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.
For 3D vectors (\(D=3\)), the result is a vector orthogonal to both ‘a’ and ‘b’:
\[\vec{c} = \vec{a} \times \vec{b} = (a_y b_z - a_z b_y) \mathbf{i} + (a_z b_x - a_x b_z) \mathbf{j} + (a_x b_y - a_y b_x) \mathbf{k}\]
For 2D vectors (\(D=2\)), the vectors are treated as lying in the xy-plane and the result is the signed z-component of their 3D cross product:
\[c = a_x b_y - a_y b_x\]
- Parameters:
a (jax.Array) – First vector. Shape (…, D), where D is the dimension (2 or 3).
b (jax.Array) – Second vector. Shape (…, D).
- Returns:
The cross product.
If D=3: shape is (…, 3).
If D=2: shape is (…, 1).
- Return type:
jax.Array
- Raises:
ValueError – If the last dimension is not 2 or 3, or if the dimensions of ‘a’ and ‘b’ do not match.
- jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) Array
Computes the cross product of an angular velocity vector (w) and a position vector (r), typically to obtain the tangential velocity v = w × r.
For 3D vectors, the standard 3D cross product is used:
\[\vec{v} = \vec{w} \times \vec{r}\]
For 2D vectors, the angular velocity \(w\) is a scalar (the z-component) and the position \(\vec{r}\) is 2D:
\[\vec{v} = (-w \cdot r_y, \, w \cdot r_x)\]
- Parameters:
w (jax.Array) – Angular velocity. Shape (…, 3) in 3D, or (…, 1) or (…) in 2D.
r (jax.Array) – Position vector. Shape (…, 3) or (…, 2).
- Returns:
Tangential velocity. Shape matches r.
- Return type:
jax.Array
- Raises:
ValueError – If dimensions are incompatible.
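The 2D conventions of cross() and cross_3X3D_1X2D() can be reproduced with NumPy. The 2D result of cross() keeps a trailing axis of length 1, and the 2D tangential velocity matches the in-plane part of the 3D product with \(\vec{w} = (0, 0, w)\). The function names below are hypothetical stand-ins, not jaxdem's code.

```python
import numpy as np

def cross_sketch(a, b):
    # 3D: ordinary cross product. 2D: z-component a_x b_y - a_y b_x, shape (..., 1)
    if a.shape[-1] != b.shape[-1] or a.shape[-1] not in (2, 3):
        raise ValueError("last dimension must be 2 or 3 and must match")
    if a.shape[-1] == 3:
        return np.cross(a, b)
    return np.expand_dims(a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0], -1)

def tangential_velocity(w, r):
    # v = w x r. In 2D, w is the scalar z-component and v = (-w r_y, w r_x)
    if r.shape[-1] == 3:
        return np.cross(w, r)
    w = np.reshape(w, r.shape[:-1])          # accept w of shape (...) or (..., 1)
    return np.stack([-w * r[..., 1], w * r[..., 0]], axis=-1)

c2 = cross_sketch(np.array([1.0, 0.0]), np.array([0.0, 1.0]))
c3 = cross_sketch(np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0]))
v2 = tangential_velocity(np.array([2.0]), np.array([1.0, 0.0]))
v3 = tangential_velocity(np.array([0.0, 0.0, 2.0]), np.array([1.0, 0.0, 0.0]))
```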
- jaxdem.utils.cross_lidar_2d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_bins: int, max_neighbors: int) tuple[jax.Array, jax.Array, jax.Array]
2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.
Computes all-pairs displacements from pos_a to pos_b, bins them by azimuthal angle, and returns per-bin proximity and closest-target IDs.
- Parameters:
pos_a (jax.Array) – Sensor positions, shape (N_A, dim).
pos_b (jax.Array) – Target positions, shape (N_B, dim).
system (System) – System configuration.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_bins (int) – Number of angular bins spanning \([-\pi, \pi)\).
max_neighbors (int) – Unused. Kept for backward compatibility.
- Returns:
(proximity, ids, overflow), where proximity and ids have shape (N_A, n_bins) and overflow is always False. Empty bins get ids = -1.
- Return type:
Tuple[jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.
Examples
>>> prox, ids, overflow = cross_lidar_2d(agents, obstacles, system,
...                                       lidar_range=5.0, n_bins=36,
...                                       max_neighbors=64)
- jaxdem.utils.cross_lidar_3d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int) tuple[jax.Array, jax.Array, jax.Array]
3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.
Computes all-pairs displacements from pos_a to pos_b, bins them on a spherical grid, and returns per-bin proximity and closest-target IDs.
- Parameters:
pos_a (jax.Array) – Sensor positions, shape (N_A, 3).
pos_b (jax.Array) – Target positions, shape (N_B, 3).
system (System) – System configuration.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_azimuth (int) – Number of azimuthal bins.
n_elevation (int) – Number of elevation bins.
max_neighbors (int) – Unused. Kept for backward compatibility.
- Returns:
(proximity, ids, overflow), where proximity and ids have shape (N_A, n_azimuth * n_elevation) and overflow is always False. Empty bins get ids = -1.
- Return type:
Tuple[jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.
Examples
>>> prox, ids, overflow = cross_lidar_3d(agents, obstacles, system,
...                                       lidar_range=5.0, n_azimuth=36,
...                                       n_elevation=18, max_neighbors=64)
- jaxdem.utils.decode_callable(path: str) Callable[[...], Any]
Import a callable from a dotted path string.
- jaxdem.utils.dot(a: Array, b: Array) Array
Dot product of vectors along the last axis.
\[c = \vec{a} \cdot \vec{b} = \sum_{i} a_i b_i\]
- Parameters:
a (jax.Array) – First vector. Shape (…, D).
b (jax.Array) – Second vector. Shape (…, D).
- Returns:
The dot product. Shape (…).
- Return type:
jax.Array
- jaxdem.utils.encode_callable(fn: Callable[[...], Any]) str
Return a dotted path like ‘jax._src.nn.functions.gelu’.
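A dotted-path round trip like the one encode_callable() and decode_callable() describe can be sketched with the standard library alone. This sketch assumes the callable is reachable as `module.qualname` with a single attribute lookup. It does not handle nested attributes such as methods (`Class.method`) or locally defined functions, and jaxdem's own handling of those cases is not documented here. The `*_sketch` names are hypothetical.

```python
import importlib
import math
from typing import Any, Callable

def encode_callable_sketch(fn: Callable[..., Any]) -> str:
    # e.g. math.sqrt -> "math.sqrt"
    return f"{fn.__module__}.{fn.__qualname__}"

def decode_callable_sketch(path: str) -> Callable[..., Any]:
    # Split "pkg.mod.name" into module "pkg.mod" and attribute "name"
    module_name, _, attr = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

path = encode_callable_sketch(math.sqrt)
fn = decode_callable_sketch(path)
```

Storing a string rather than the function object keeps configurations serializable, for example in an h5 file.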
- jaxdem.utils.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, skip_frames: int = 1, **kw: Any) tuple[Environment, jax.Array]
Advance the environment n steps using actions from model.
- Parameters:
env (Environment) – Initial environment pytree (batchable).
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.
n (int) – Number of steps to perform.
skip_frames (int) – Number of additional frames to repeat the action.
**kw (Any) – Extra keyword arguments forwarded to model.
- Returns:
Updated environment and the advanced random key.
- Return type:
Tuple[Environment, jax.Array]
Examples
>>> env, key = env_step(env, model, key, n=10, objective=goal)
- jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, skip_frames: int = 1, **kw: Any) tuple[Environment, jax.Array, Environment]
Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.
- Parameters:
env (Environment) – Initial environment pytree.
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.
n (int) – Number of chunks to roll out. Total internal steps = n * stride.
stride (int) – Steps per chunk between recorded snapshots.
skip_frames (int) – Number of additional frames to repeat the action.
**kw (Any) – Extra keyword arguments passed to model on every step.
- Returns:
Final environment, advanced random key, and a stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.
- Return type:
Tuple[Environment, jax.Array, Environment]
Examples
>>> env, key, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
- jaxdem.utils.generate_arclength_mesh(nv: int, N: int, dim: int = 2, aspect_ratio: Any = None, n_fine: int | None = None) Array
2D mesh with equal arc-length spacing along the ellipse / circle perimeter.
This is the closed-form analogue of the converged (n_steps → ∞) Thomson ground state in 2D. On a circle the two coincide exactly: both give the regular nv-gon. On an ellipse, “uniform neighbor distance” (= equal arc-length spacing) is exactly the α → ∞ packing-problem limit of Riesz-energy minimization and a very close approximation to the α = 1 Thomson ground state. Use this when you would otherwise run Thomson with a huge n_steps budget on a 2D particle. It arrives at essentially the same answer in one numerical integration, deterministically.
The algorithm inverts the arc-length parameterization of the ellipse (a cos t, b sin t) numerically. It integrates the arc-length integrand with the trapezoidal rule on a fine grid, then linearly interpolates the angles corresponding to nv equally spaced target arc lengths. Accuracy is controlled by n_fine (auto-sized to max(10_000, 200 * nv) by default).
- Parameters:
nv (int) – Number of surface vertices. Must be >= 3.
N (int) – Number of bodies (all identical copies; per-body orientation is handled downstream).
dim (int) – Must be 2. In 3D, “uniform neighbor distance” is generically the multi-point Thomson problem; use generate_thomson_mesh() instead.
aspect_ratio (None, scalar, or (2,) array-like) – Ellipse semi-axes in the usual “normalize to max=1” convention. None gives a circle.
n_fine (int, optional) – Grid resolution for the arc-length integral. None auto-sizes.
- Returns:
Shape (nv, 2) if N == 1 else (N, nv, 2).
- Return type:
jax.Array
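The arc-length inversion described above can be sketched in NumPy for a single body with explicit semi-axes `a` and `b`. The sketch omits aspect-ratio normalization and the `N` batch dimension. It follows the stated steps (cumulative trapezoidal rule, then linear interpolation), but it is an illustration with hypothetical names, not jaxdem's code.

```python
import numpy as np

def arclength_mesh_sketch(nv, a=1.0, b=1.0, n_fine=None):
    """Place nv points on the ellipse (a cos t, b sin t) at equal arc-length spacing."""
    if n_fine is None:
        n_fine = max(10_000, 200 * nv)
    t = np.linspace(0.0, 2.0 * np.pi, n_fine + 1)
    # |d/dt (a cos t, b sin t)| = sqrt(a^2 sin^2 t + b^2 cos^2 t)
    speed = np.sqrt((a * np.sin(t)) ** 2 + (b * np.cos(t)) ** 2)
    # Cumulative trapezoidal rule gives the arc length s(t)
    s = np.concatenate([[0.0], np.cumsum(0.5 * (speed[1:] + speed[:-1]) * np.diff(t))])
    targets = np.arange(nv) * s[-1] / nv        # nv equally spaced arc lengths
    t_targets = np.interp(targets, s, t)        # invert s(t) by linear interpolation
    return np.stack([a * np.cos(t_targets), b * np.sin(t_targets)], axis=-1)

circle = arclength_mesh_sketch(8)               # regular octagon
ellipse = arclength_mesh_sketch(40, a=1.0, b=0.5)
```

On the ellipse the arc lengths are equal by construction, so consecutive chord lengths differ only slightly, with the largest deviation where curvature is highest.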
- jaxdem.utils.generate_faceted_mesh(nv: int, N: int, dim: int, n_facets: int = 6, aspect_ratio: Any = None) Array
Regular n-gon (2D) or icosahedron (3D) with vertex + surface-filler asperities.
Unlike the smooth-sphere family (thomson / icosphere / fibonacci), this mesh keeps the particle genuinely faceted: the shape is a polygon or polyhedron with sharp vertices and flat faces, and the asperities sit on those flat features rather than being projected to a circumscribed sphere.
Asperity layout
Vertex asperities are always placed at the corners of the shape (n_facets in 2D, 12 for the icosahedron in 3D).
The remaining nv - n_vertices asperities are distributed uniformly across the surface primitives (edges in 2D, triangular face interiors in 3D) to achieve an approximately uniform surface density. If nv - n_vertices is not divisible by the number of primitives, the first few primitives get one extra asperity so that the total exactly matches nv.
In 2D, edge-interior asperities are evenly spaced along each edge (excluding the endpoints, which are already vertex asperities).
In 3D, face-interior asperities are quasi-uniformly sampled inside each triangular face via a barycentric-coordinate remap. They stay on the flat face instead of being projected to the circumscribing sphere, so the particle’s faceted character is preserved (asperities on faces sit closer to the center than vertex asperities).
- Parameters:
nv (int) – Total number of asperities. Must be >= n_facets in 2D and >= 12 in 3D. Larger nv → higher surface density.
N (int) – Number of bodies (all identical copies; per-body orientation is handled downstream).
dim (int) – Spatial dimension (2 or 3).
n_facets (int) – 2D only: number of polygon sides / vertices. Must be >= 3. Ignored in 3D (the icosahedron has a fixed 12 vertices / 20 faces).
aspect_ratio (None, scalar, or (dim,) array-like) – Axis stretch in the usual “normalize to max=1” convention.
- Returns:
Shape (nv, dim) if N == 1 else (N, nv, dim). Vertex asperities appear first in the array, followed by edge/face fillers grouped by primitive.
- Return type:
jax.Array
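The 2D asperity layout described above can be sketched as a short NumPy function. It places corners first, then evenly spaced edge-interior fillers, with the remainder going to the first edges. The sketch assumes a unit circumradius, starts the first corner on the +x axis, and ignores `aspect_ratio` and `N`. jaxdem's actual scaling and orientation may differ, and `faceted_polygon_sketch` is a hypothetical name.

```python
import numpy as np

def faceted_polygon_sketch(nv, n_facets=6):
    """2D only: n-gon corner asperities first, then edge fillers grouped by edge."""
    if n_facets < 3 or nv < n_facets:
        raise ValueError("need n_facets >= 3 and nv >= n_facets")
    ang = 2.0 * np.pi * np.arange(n_facets) / n_facets
    corners = np.stack([np.cos(ang), np.sin(ang)], axis=-1)   # assumed unit circumradius
    n_fill = nv - n_facets
    # Spread fillers over edges; the first (n_fill % n_facets) edges get one extra.
    per_edge = np.full(n_facets, n_fill // n_facets)
    per_edge[: n_fill % n_facets] += 1
    fillers = []
    for e, k in enumerate(per_edge):
        p0, p1 = corners[e], corners[(e + 1) % n_facets]
        s = np.arange(1, k + 1) / (k + 1)       # interior points only, endpoints excluded
        fillers.append(p0 + s[:, None] * (p1 - p0))
    return np.concatenate([corners] + fillers, axis=0)

pts = faceted_polygon_sketch(nv=20, n_facets=6)   # 6 corners + 14 fillers (3,3,2,2,2,2)
```

Because the fillers stay on the straight edges, they lie strictly inside the circumscribed circle. This is the faceted character the docstring contrasts with the smooth-sphere meshes.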
- jaxdem.utils.generate_fibonacci_sphere_mesh(nv: int, N: int, dim: int, aspect_ratio: Any = None) Array
Generate N Fibonacci-sphere meshes (3D) or circles (2D).
In 3D the points are laid out by a golden-angle spiral (the sunflower / Fibonacci lattice): z is stratified in (-1, 1) and the azimuth advances by the golden angle on each step. The result is a near-optimal, deterministic, low-discrepancy covering of the sphere for any nv >= 1. It is the “I want uniform points fast” default and a drop-in alternative to Thomson that skips the minimization.
In 2D the output is nv evenly spaced points on the unit circle (there is no non-trivial 1D analogue of the spiral).
The mesh is deterministic, so all N bodies are identical copies; per-body random orientation is typically applied downstream by distribute_bodies().
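The golden-angle construction can be sketched in NumPy. The sketch uses one common variant of the Fibonacci lattice, with z stratified at cell midpoints \(z_i = 1 - (2i+1)/n_v\). The exact offsets jaxdem uses are not stated on this page, and `fibonacci_sphere_sketch` is a hypothetical name.

```python
import numpy as np

def fibonacci_sphere_sketch(nv):
    """Golden-angle spiral on the unit sphere (one common Fibonacci-lattice variant)."""
    i = np.arange(nv)
    z = 1.0 - (2.0 * i + 1.0) / nv               # stratified, strictly inside (-1, 1)
    r = np.sqrt(1.0 - z**2)
    golden_angle = np.pi * (3.0 - np.sqrt(5.0))  # about 2.39996 rad
    phi = golden_angle * i
    return np.stack([r * np.cos(phi), r * np.sin(phi), z], axis=-1)

pts = fibonacci_sphere_sketch(200)
```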
- jaxdem.utils.generate_helix_mesh(nv: int, N: int, dim: int, n_turns: float = 3.0, helix_radius: float = 0.3, aspect_ratio: Any = None) Array
Generate N helical meshes (3D) or Archimedean spirals (2D).
In 3D the points trace a right-handed helix along the z axis: nv points evenly spaced in the arc parameter, making n_turns full turns from z = -1 to z = 1 on a circle of radius helix_radius. This gives chiral, rod-like bodies with a controllable pitch, useful for studies of enantiomeric packing or helical-fiber clumps.
In 2D the helix degenerates to an Archimedean spiral centered at the origin, with nv points going from near the origin out to the unit circle over n_turns turns.
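The 3D helix above can be sketched in a few lines of NumPy for one body, ignoring `aspect_ratio`. "Right-handed" is taken here to mean that the azimuth increases counter-clockwise as z increases. That is the usual convention, but this page does not spell it out. `helix_sketch` is a hypothetical name.

```python
import numpy as np

def helix_sketch(nv, n_turns=3.0, helix_radius=0.3):
    """Right-handed helix along z: nv points from z = -1 to z = 1."""
    s = np.linspace(0.0, 1.0, nv)                # evenly spaced arc parameter
    theta = 2.0 * np.pi * n_turns * s            # counter-clockwise as z increases
    z = -1.0 + 2.0 * s
    return np.stack([helix_radius * np.cos(theta), helix_radius * np.sin(theta), z], axis=-1)

pts = helix_sketch(61, n_turns=3.0)
```

Because the pitch is constant, equal steps in the parameter are also equal steps in arc length.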
- jaxdem.utils.generate_icosphere_mesh(nv: int, N: int, dim: int, aspect_ratio: Any = None) Array
Generate N icosphere meshes (3D) or regular polygons (2D).
In 3D, nv must be 10 * frequency**2 + 2 for some integer frequency >= 1 (i.e. one of {12, 42, 92, 162, 252, ...}). Frequencies that are powers of two use recursive midpoint subdivision; other frequencies use direct triangular geodesic subdivision. Use generate_fibonacci_sphere_mesh() if you need an arbitrary vertex count on a sphere.
In 2D, nv can be any integer and the output is a regular nv-gon on the unit circle.
The mesh is deterministic, so all N bodies are identical copies; per-body random orientation is typically applied downstream by distribute_bodies().
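Before requesting a 3D icosphere, it can help to check that nv is admissible. The helper below is a hypothetical, standalone check of the stated constraint \(n_v = 10 f^2 + 2\). It is not part of jaxdem.

```python
import math

def icosphere_frequency(nv):
    """Return f such that nv == 10 * f**2 + 2 (3D icosphere), else raise ValueError."""
    if nv < 12 or (nv - 2) % 10 != 0:
        raise ValueError(f"nv={nv} is not of the form 10*f**2 + 2")
    f = math.isqrt((nv - 2) // 10)
    if 10 * f * f + 2 != nv:
        raise ValueError(f"nv={nv} is not of the form 10*f**2 + 2")
    return f

valid = [10 * f * f + 2 for f in range(1, 6)]   # [12, 42, 92, 162, 252]
```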
- jaxdem.utils.generate_thomson_mesh(nv: int, N: int, dim: int, alpha: float = 1.0, lr: float = 0.01, steps: int = 1000, aspect_ratio: Any = None, use_uniform_sampling: bool = True, batch_size: int | None = None, seed: int | None = None) tuple[Array, Array]
Generate and minimize charges constrained to a hyper-ellipsoid surface.
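This page only summarizes the Thomson generator, so the following is a generic illustration, not jaxdem's optimizer. It runs projected gradient descent of the Riesz energy \(E = \sum_{i<j} |x_i - x_j|^{-\alpha}\) on the unit sphere, taking a gradient step and then renormalizing each point. It handles only the sphere (no ellipsoid, batching or uniform initial sampling). It uses a smaller `lr` than the documented default of 0.01 to keep plain gradient descent stable. The parameter names mirror the signature, but the update rule is an assumption.

```python
import numpy as np

def riesz_energy(x, alpha):
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    iu = np.triu_indices(len(x), k=1)
    return np.sum(d[iu] ** (-alpha))

def thomson_sketch(nv, alpha=1.0, lr=1e-3, steps=500, seed=0):
    """Projected gradient descent of the Riesz energy on the unit sphere."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(nv, 3))
    x /= np.linalg.norm(x, axis=1, keepdims=True)
    for _ in range(steps):
        diff = x[:, None, :] - x[None, :, :]          # x_i - x_j
        d2 = np.sum(diff**2, axis=-1)
        np.fill_diagonal(d2, np.inf)                   # no self-interaction
        # dE/dx_i = -alpha * sum_j (x_i - x_j) |x_i - x_j|^(-alpha - 2)
        grad = -alpha * np.sum(diff * d2[..., None] ** (-(alpha + 2) / 2), axis=1)
        x = x - lr * grad
        x /= np.linalg.norm(x, axis=1, keepdims=True)  # project back onto the sphere
    return x

x0 = np.random.default_rng(0).normal(size=(12, 3))
x0 /= np.linalg.norm(x0, axis=1, keepdims=True)       # same start as thomson_sketch(12)
x = thomson_sketch(12)
```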
- jaxdem.utils.generate_torus_mesh(nv: int, N: int, dim: int = 3, tube_ratio: float = 0.3, aspect_ratio: Any = None) Array
Generate N torus surface meshes with nv quasi-uniform points each.
The torus is parameterized by two radii: the major radius R (center of the tube → center of the torus) and the minor radius r (tube half-thickness). tube_ratio sets r directly under the “longest axis has extent 1” convention: r = tube_ratio and R = 1 - tube_ratio, so the torus fits in x, y ∈ [-1, 1] and z ∈ [-r, r].
Points are placed by stratified angular sampling around the major axis (theta, evenly spaced) paired with a golden-ratio quasi-random phase around the tube (phi). This gives good 2D coverage of the torus surface for any nv, with a slight over-representation of the inner rim relative to the outer rim. Exactly uniform sampling would require an (R + r cos(phi)) area weighting, which is skipped for simplicity.
Useful for non-convex, genus-1 particles, e.g. studies of interlocking or linking in packings where non-convexity matters.
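The parameterization above, together with the standard torus embedding \((R + r\cos\phi)(\cos\theta, \sin\theta)\), \(z = r\sin\phi\), can be sketched in NumPy. The golden-ratio phase here is \(\phi_i = 2\pi\,\mathrm{frac}(i\varphi)\). That is a plausible reading of "golden-ratio quasi-random phase", but the exact sequence jaxdem uses is an assumption. The function name is hypothetical.

```python
import numpy as np

def torus_sketch(nv, tube_ratio=0.3):
    """Quasi-uniform torus points: stratified theta, golden-ratio phi (no area weighting)."""
    r = tube_ratio                   # minor radius (tube half-thickness)
    R = 1.0 - tube_ratio             # major radius, so R + r = 1
    i = np.arange(nv)
    theta = 2.0 * np.pi * i / nv                           # around the major axis
    golden = (1.0 + np.sqrt(5.0)) / 2.0
    phi = 2.0 * np.pi * np.mod(i * golden, 1.0)            # around the tube
    rho = R + r * np.cos(phi)
    return np.stack([rho * np.cos(theta), rho * np.sin(theta), r * np.sin(phi)], axis=-1)

pts = torus_sketch(500, tube_ratio=0.3)
```

Every point satisfies the implicit torus equation \((\sqrt{x^2 + y^2} - R)^2 + z^2 = r^2\), which is a quick sanity check for any torus generator.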
- jaxdem.utils.get_clump_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None, check_contact_rank: bool = False, contact_rank_tol: float | None = None) tuple[State, System, jax.Array, jax.Array]
Identify rattler clumps by iteratively removing under-coordinated clumps.
A clump is a rattler if its total vertex-contact count is below the coordination threshold zc. Optionally, clumps whose contacts do not span their rigid-body generalized force space are also treated as rattlers.
- Parameters:
state (State) – Current simulation state.
system (System) – System definition.
cutoff (float, optional) – Neighbor search cutoff distance.
max_neighbors (int, optional) – Maximum number of neighbors per particle.
zc (int, optional) – Minimum contact count. Defaults to dim + angular_dof + 1, the mechanical stability threshold for a rigid body. A clump with dim + angular_dof or fewer force-bearing vertex contacts is unstable (the tangential softening of each contact at finite overlap can give the sub-Hessian a negative eigenvalue), so dim + angular_dof + 1 non-degenerate contacts are needed for a positive-definite local Hessian.
check_contact_rank (bool, optional) – If True, also remove clumps whose active force-bearing contacts have generalized force rank below dim + angular_dof.
contact_rank_tol (float, optional) – Absolute tolerance passed to jax.numpy.linalg.matrix_rank for the optional generalized force rank check.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
rattler_ids (jax.Array) – 1-D array of rattler clump IDs.
non_rattler_ids (jax.Array) – 1-D array of non-rattler clump IDs.
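The optional rank check can be understood through a 2D example. Each active contact with unit normal \(\hat{n}\) at lever arm \(\vec{r}\) from the clump centre contributes a generalized force \((n_x, n_y, r_x n_y - r_y n_x)\). If the stacked rows have rank below dim + angular_dof = 3, some rigid-body motion is unconstrained. The NumPy sketch below assumes this construction and uses `np.linalg.matrix_rank` in place of the jax.numpy version. How jaxdem assembles the rows is not shown on this page, and `generalized_force_rank` is a hypothetical name.

```python
import numpy as np

def generalized_force_rank(lever_arms, normals, tol=None):
    """2D rigid body: rank of stacked generalized forces (f_x, f_y, tau_z) from contacts."""
    torque = lever_arms[:, 0] * normals[:, 1] - lever_arms[:, 1] * normals[:, 0]
    G = np.column_stack([normals, torque])     # shape (n_contacts, dim + angular_dof) = (n, 3)
    return np.linalg.matrix_rank(G, tol=tol)

# Three contacts whose lines of action all pass through the centre: no torque can be
# resisted, so the rank is 2 < 3 even though the clump has 3 contacts.
arms = np.array([[1.0, 0.0], [-0.5, 0.866], [-0.5, -0.866]])
radial = -arms / np.linalg.norm(arms, axis=1, keepdims=True)
rank_radial = generalized_force_rank(arms, radial)

# Tilting one contact normal off the radial direction restores full rank.
tilted = radial.copy()
tilted[0] = np.array([-0.8, 0.6])
rank_tilted = generalized_force_rank(arms, tilted)
```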
- jaxdem.utils.get_pair_forces_and_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) tuple[State, System, jax.Array, jax.Array]
Compute pairwise contact forces and their associated particle IDs.
- Parameters:
state (State) – Current simulation state.
system (System) – System definition.
cutoff (float, optional) – Neighbor search cutoff distance.
max_neighbors (int, optional) – Maximum number of neighbors per particle.
- Returns:
state (State) – Potentially updated state (after neighbor-list rebuild).
system (System) – Potentially updated system.
pair_ids (jax.Array) – (M, 2) array of (i, j) sphere index pairs.
forces (jax.Array) – (M, dim) array of pairwise force vectors, one per pair.
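A typical use of the returned arrays is to accumulate per-sphere net contact forces. The NumPy sketch below assumes, hypothetically, that `forces[k]` is the force on sphere `pair_ids[k, 0]` due to `pair_ids[k, 1]` and that each contact is listed in both directions. It uses unbuffered `np.add.at` so that repeated indices are all added. With JAX arrays, the analogous operation is `net.at[idx].add(vals)`.

```python
import numpy as np

def net_contact_force(pair_ids, forces, n_spheres):
    """Sum pairwise forces onto the first sphere of each (i, j) pair."""
    net = np.zeros((n_spheres, forces.shape[1]))
    np.add.at(net, pair_ids[:, 0], forces)   # unbuffered: repeated i are all accumulated
    return net

pair_ids = np.array([[0, 1], [1, 0], [0, 2], [2, 0]])
forces = np.array([[-1.0, 0.0], [1.0, 0.0], [0.0, -2.0], [0.0, 2.0]])
net = net_contact_force(pair_ids, forces, n_spheres=3)
```

With both directions listed and equal-and-opposite forces, the per-sphere totals sum to zero, which is a useful consistency check.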
- jaxdem.utils.get_sphere_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None, check_contact_rank: bool = False, contact_rank_tol: float | None = None) tuple[State, System, jax.Array, jax.Array]
Identify rattler spheres by iteratively removing under-coordinated particles.
- Parameters:
state (State) – Current simulation state.
system (System) – System definition.
cutoff (float, optional) – Neighbor search cutoff distance.
max_neighbors (int, optional) – Maximum number of neighbors per particle.
zc (int, optional) – Minimum contact count. Defaults to dim + 1, the mechanical stability threshold for a point particle. A sphere with dim or fewer force-bearing contacts is unstable: the tangential softening of each contact at finite overlap can give the sub-Hessian a negative eigenvalue, so dim + 1 non-degenerate contacts are needed for a positive-definite local Hessian.
check_contact_rank (bool, optional) – If True, also remove particles whose active force-bearing contacts have force-direction rank below dim.
contact_rank_tol (float, optional) – Absolute tolerance passed to jax.numpy.linalg.matrix_rank for the optional force-direction rank check.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
rattler_ids (jax.Array) – 1-D array of rattler sphere indices.
non_rattler_ids (jax.Array) – 1-D array of non-rattler sphere indices.
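The iterative removal matters because deleting one rattler can drop its neighbors below the threshold in turn. The NumPy sketch below shows that cascade with the textbook pruning loop. It omits the optional rank check and assumes a pair list with every contact in both directions, as returned by get_pair_forces_and_ids(). `rattler_mask` is a hypothetical helper, not jaxdem's implementation. In the example, sphere 4 is removed first, which then leaves sphere 5 with only two contacts.

```python
import numpy as np

def rattler_mask(pair_ids, forces, n, zc, tol=0.0):
    """Iteratively flag particles with fewer than zc force-bearing contacts among non-rattlers."""
    active = np.linalg.norm(forces, axis=-1) > tol
    i, j = pair_ids[active, 0], pair_ids[active, 1]
    rattler = np.zeros(n, dtype=bool)
    while True:
        keep = ~rattler[i] & ~rattler[j]         # contacts between current non-rattlers
        z = np.bincount(i[keep], minlength=n)     # both directions listed: count per i
        new = ~rattler & (z < zc)
        if not new.any():
            return rattler
        rattler |= new

edges = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3],   # rigid core
                  [3, 4], [4, 5], [5, 0], [5, 1]])                  # weakly attached 4, 5
pair_ids = np.concatenate([edges, edges[:, ::-1]])   # every contact in both directions
forces = np.ones((len(pair_ids), 2))                 # placeholder non-zero forces
mask = rattler_mask(pair_ids, forces, n=6, zc=3)     # zc = dim + 1 in 2D
```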
- jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) State
Create a state where particles sit on a rectangular lattice.
Random values can be sampled for particle radii, masses and velocities by specifying the *_range arguments, which are interpreted as (min, max) bounds for a uniform distribution. When a range is not provided, the corresponding radius or mass argument is used for all particles and the velocity components are sampled in [-1, 1].
- Parameters:
n_per_axis (tuple[int]) – Number of spheres along each axis.
spacing (tuple[float] | float) – Centre-to-centre distance; a scalar is broadcast to every axis.
radius (float) – Shared radius for all particles when radius_range is not provided.
mass (float) – Shared mass for all particles when mass_range is not provided.
jitter (float) – Add a uniform random offset in the range [-jitter, +jitter] for non-perfect grids (useful to break symmetry).
vel_range (ArrayLike | None) – (min, max) values for the velocity components.
radius_range (ArrayLike | None) – (min, max) values for the radii.
mass_range (ArrayLike | None) – (min, max) values for the masses.
seed (int) – Integer seed used when key is not supplied.
key (PRNG key, optional) – Controls randomness. If None, a key will be created from seed.
- Return type:
State
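The lattice-plus-jitter placement can be sketched in NumPy. The sketch assumes the lattice starts at the origin with "ij" axis ordering, and it uses a seeded NumPy generator instead of a JAX PRNG key. jaxdem may offset or order the lattice differently. `lattice_positions` is a hypothetical name.

```python
import numpy as np

def lattice_positions(n_per_axis, spacing, jitter=0.0, seed=0):
    """Rectangular-lattice centres (N, dim), optionally jittered uniformly in [-jitter, jitter]."""
    n_per_axis = tuple(n_per_axis)
    dim = len(n_per_axis)
    spacing = np.broadcast_to(np.asarray(spacing, dtype=float), (dim,))  # scalar -> every axis
    idx = np.stack(np.meshgrid(*[np.arange(n) for n in n_per_axis], indexing="ij"), axis=-1)
    pos = idx.reshape(-1, dim) * spacing
    rng = np.random.default_rng(seed)
    return pos + rng.uniform(-jitter, jitter, size=pos.shape)

pos = lattice_positions((3, 2), spacing=2.0)
pos_aniso = lattice_positions((2, 2, 2), spacing=(1.0, 2.0, 3.0), jitter=0.1)
```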
- jaxdem.utils.lidar_2d(state: State, system: System, lidar_range: float, n_bins: int, max_neighbors: int, sense_edges: bool = False) tuple[State, System, jax.Array, jax.Array, jax.Array]
2-D LIDAR proximity readings and neighbor IDs.
For every particle in state, the displacement vectors to all other particles are projected onto the \(xy\)-plane and binned by azimuthal angle into n_bins uniform sectors spanning \([-\pi, \pi)\). Each bin stores the proximity value and the index of the closest neighbor in that sector:
\[p_k = \max(0,\; r_{\max} - d_{\min,k})\]
where \(r_{\max}\) is lidar_range and \(d_{\min,k}\) is the distance to the closest neighbor in sector \(k\). This works identically for 2-D and 3-D position data; in the 3-D case the \(z\)-component of the displacement is simply ignored during binning while the full Euclidean distance is used for proximity.
- Parameters:
state (State) – Simulation state (positions, radii, etc.).
system (System) – System configuration including domain.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_bins (int) – Number of angular bins (rays) spanning \([-\pi, \pi)\).
max_neighbors (int) – Unused. Kept for backward compatibility.
sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Only meaningful for bounded domains. Default is False.
- Returns:
(state, system, proximity, ids, overflow), where state and system are unchanged, proximity and ids have shape (N, n_bins), and overflow is always False. Bins with no detection have ids set to the particle’s own index.
- Return type:
tuple[State, System, jax.Array, jax.Array, jax.Array]
Notes
This function computes all-pairs displacements directly from state.pos and does not invoke the collider. The returned ids are indices into state.pos in whatever order it has at call time, so results are correct regardless of whether a cell-list collider has reordered the state.
Examples
>>> state, system, prox, ids, overflow = lidar_2d(state, system,
...     lidar_range=5.0, n_bins=36, max_neighbors=64)
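The binning and proximity rule above can be sketched as an explicit all-pairs loop in NumPy. It implements \(p_k = \max(0, r_{\max} - d_{\min,k})\), skips self-pairs, and leaves empty bins with the particle's own index, as documented. It omits `sense_edges` and periodic domains and favors loops over vectorization for clarity. It is an illustration, not jaxdem's implementation, and `lidar_2d_sketch` is a hypothetical name.

```python
import numpy as np

def lidar_2d_sketch(pos, lidar_range, n_bins):
    """All-pairs 2D LIDAR: per-sector proximity max(0, r_max - d_min) and closest-neighbor id."""
    n = len(pos)
    disp = pos[None, :, :] - pos[:, None, :]           # disp[i, j] = pos[j] - pos[i]
    dist = np.linalg.norm(disp, axis=-1)               # full Euclidean distance
    ang = np.arctan2(disp[..., 1], disp[..., 0])       # azimuth in (-π, π]; π wraps to bin 0
    b = np.floor((ang + np.pi) / (2 * np.pi) * n_bins).astype(int) % n_bins
    prox = np.zeros((n, n_bins))
    ids = np.tile(np.arange(n)[:, None], (1, n_bins))  # empty bins keep the particle's own index
    for i in range(n):
        for j in range(n):
            if i == j or dist[i, j] >= lidar_range:
                continue
            p = lidar_range - dist[i, j]
            if p > prox[i, b[i, j]]:                   # keep the closest neighbor per sector
                prox[i, b[i, j]] = p
                ids[i, b[i, j]] = j
    return prox, ids

pos = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [0.0, -2.0]])
prox, ids = lidar_2d_sketch(pos, lidar_range=5.0, n_bins=4)
```

For particle 0, particles 1 and 2 fall in the same sector, and only the closer one (particle 1) is reported.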
- jaxdem.utils.lidar_3d(state: State, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int, sense_edges: bool = False) tuple[State, System, jax.Array, jax.Array, jax.Array]
3-D LIDAR proximity readings and neighbor IDs.
Similar to lidar_2d(), but bins neighbors on a spherical grid defined by n_azimuth azimuthal sectors in \([-\pi, \pi)\) and n_elevation elevation bands in \([-\pi/2, \pi/2]\). The returned proximity and ID arrays have shape (N, n_azimuth * n_elevation) with flat indexing az * n_elevation + el.
- Parameters:
state (State) – Simulation state.
system (System) – System configuration including domain.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_azimuth (int) – Number of azimuthal bins.
n_elevation (int) – Number of elevation bins.
max_neighbors (int) – Unused. Kept for backward compatibility.
sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Default is False.
- Returns:
(state, system, proximity, ids, overflow), where state and system are unchanged, proximity and ids have shape (N, n_azimuth * n_elevation), and overflow is always False.
- Return type:
tuple[State, System, jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. Returned ids index into state.pos in its current order, so results are correct regardless of collider-induced reordering.
Examples
>>> state, system, prox, ids, overflow = lidar_3d(state, system,
...     lidar_range=5.0, n_azimuth=36, n_elevation=18, max_neighbors=64)
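The flat index `az * n_elevation + el` can be computed from displacement vectors as sketched below in NumPy. The sketch assumes uniform azimuth sectors over \([-\pi, \pi)\) and uniform elevation bands over \([-\pi/2, \pi/2]\), with the straight-up direction clipped into the top band. jaxdem's exact handling of band edges is not stated on this page, and `lidar_3d_bin` is a hypothetical name.

```python
import numpy as np

def lidar_3d_bin(disp, n_azimuth, n_elevation):
    """Flat spherical-grid bin index az * n_elevation + el for displacements (..., 3)."""
    az = np.arctan2(disp[..., 1], disp[..., 0])                          # in (-π, π]
    el = np.arctan2(disp[..., 2], np.hypot(disp[..., 0], disp[..., 1]))  # in [-π/2, π/2]
    ia = np.floor((az + np.pi) / (2 * np.pi) * n_azimuth).astype(int) % n_azimuth
    ie = np.floor((el + np.pi / 2) / np.pi * n_elevation).astype(int)
    ie = np.clip(ie, 0, n_elevation - 1)                                 # el = +π/2 -> top band
    return ia * n_elevation + ie

d = np.array([[1.0, 1.0, 0.5],      # azimuth  π/4, above the horizon
              [-1.0, -1.0, -0.5],   # azimuth -3π/4, below
              [1.0, -1.0, -0.5],    # azimuth -π/4, below
              [-1.0, 1.0, 0.5],     # azimuth  3π/4, above
              [0.0, 0.0, 1.0]])     # straight up
bins = lidar_3d_bin(d, n_azimuth=4, n_elevation=2)
```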
- jaxdem.utils.load_legacy_dp(path: str, ref_pos: jax.Array | None = None, dim: int = 3) DeformableParticleModel[source]#
Load an old DeformableParticleContainer h5 file and return a new DeformableParticleModel.
- Parameters:
path (str) – Path to the .h5 file containing the saved DP container.
ref_pos (jax.Array, optional) – Reference vertex positions, shape (N, dim). Required for 3D to compute the new w_b bending normalization. You can obtain this from the legacy state: ref_pos = state.pos.
dim (int) – Spatial dimension (2 or 3). Needed to choose the correct w_b computation.
- Returns:
A new model instance with fields mapped from the old container.
- Return type:
- jaxdem.utils.load_legacy_simulation(state_path: str, system_path: str, dp_path: str | None = None) -> tuple[State, System]
Load state, system, and (optionally) a deformable-particle container from old-format h5 files and wire them into a ready-to-use (State, System) pair.
When dp_path is given, the DP model is attached to the system via system.bonded_force_model and its force/energy functions are registered in the force manager.
- Parameters:
state_path (str) – Path to the legacy State h5 file.
system_path (str) – Path to the legacy System h5 file.
dp_path (str, optional) – Path to the legacy DeformableParticleContainer h5 file.
- Returns:
state (State) – The loaded state with current field names.
system (System) – The loaded system, with bonded forces wired up if dp_path was given.
Example
from jaxdem.utils.load_legacy import load_legacy_simulation

state, system = load_legacy_simulation(
    "old_data/state.h5",
    "old_data/system.h5",
    dp_path="old_data/dp.h5",
)
- jaxdem.utils.load_legacy_state(path: str) -> State
Load a State saved with the old field naming convention (angVel, clump_ID, deformable_ID, unique_ID).
- Parameters:
path (str) – Path to the .h5 file containing the saved State.
- Returns:
A new State constructed with the current field names.
- Return type:
- jaxdem.utils.load_legacy_system(path: str, state_shape: tuple[int, ...] | None = None) -> System
Load a System saved with the old schema (no bonded_force_model or interact_same_bond_id fields).
The current System.create factory is used to produce a valid skeleton; scalar fields (dt, time, step_count, key) and nested component dataclasses that still exist (collider, domain, force_model, mat_table, force_manager, integrators) are overwritten from the file where the schemas still match.
- Parameters:
path (str) – Path to the .h5 file containing the saved System.
state_shape (tuple of int, optional) – Shape hint (N, dim) passed to System.create to build default components. If None, inferred from the stored force_manager/external_force or force_manager/external_force_com dataset.
- Returns:
A new System instance populated with as much data from the file as possible.
bonded_force_model defaults to None and interact_same_bond_id defaults to False.
- Return type:
- jaxdem.utils.make_save_steps_linear(*, num_steps: int, save_freq: int, include_step0: bool = True) -> ndarray
- jaxdem.utils.make_save_steps_pseudolog(*, num_steps: int, reset_save_decade: int, min_save_decade: int, decade: int = 10, include_step0: bool = True, cap: int | None = None) -> ndarray
Pseudo-log schedule compatible with the BaseLogGroup logic.
Parameters are interpreted on the integer timestep grid 0..num_steps (inclusive).
- jaxdem.utils.norm(v: Array) -> Array
Norm (magnitude) of vectors along the last axis.
\[\|v\| = \sqrt{\sum_{i} v_i^2}\]
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
The norm. Shape (…).
- Return type:
jax.Array
- jaxdem.utils.norm2(v: Array) -> Array
Squared norm of vectors along the last axis.
\[\|v\|^2 = \vec{v} \cdot \vec{v} = \sum_{i} v_i^2\]
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
The squared norm. Shape (…).
- Return type:
jax.Array
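Both formulas reduce over the last axis only, so any leading batch shape (…) is preserved. A minimal NumPy sketch of that shape behavior, for illustration rather than as jaxdem's source:

```python
import numpy as np


def norm2(v):
    # Squared norm along the last axis: (..., D) -> (...)
    return np.sum(v * v, axis=-1)


def norm(v):
    # Norm along the last axis: square root of the squared norm
    return np.sqrt(norm2(v))


v = np.array([[3.0, 4.0], [0.0, 0.0]])
lengths = norm(v)  # one length per row
```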
- jaxdem.utils.quasistatic_compress_to_packing_fraction(state: State, system: System, target_phi: float, *, step: float = 0.001, phi_tolerance: float = 1e-10, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, max_n_min_steps_per_outer: int = 1000000, max_n_outer_steps: int = 1000000, progress: bool = False) -> tuple[State, System, jax.Array, jax.Array]
Quasi-statically compress (or decompress) toward target_phi.
Alternates scale_to_packing_fraction() with minimize() in steps no larger than step in packing fraction. The final step is truncated so the target is hit exactly (within phi_tolerance). Works in both directions: if target_phi > current_phi the box shrinks and particles are pushed closer; if target_phi < current_phi the box grows and the system relaxes.
The state is minimized once up front, so a non-equilibrium input is safe. Above the jamming point the minimizer may exit with residual PE; the final PE is returned so the caller can detect this.
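The outer-loop bookkeeping described above (increments capped at step, last increment truncated to land on the target, either direction) can be sketched in plain Python. phi_schedule is a hypothetical helper that only produces the sequence of intermediate packing fractions; in the real routine each value would be followed by a rescale and a minimization.

```python
def phi_schedule(current_phi, target_phi, step, phi_tolerance=1e-10):
    """Hypothetical helper: intermediate phi targets for a quasistatic ramp.

    Each increment has magnitude at most `step`; the final one is truncated so
    the last value equals `target_phi`. Works for compression and decompression.
    """
    targets = []
    phi = current_phi
    while abs(target_phi - phi) > phi_tolerance:
        delta = target_phi - phi
        if abs(delta) <= step:
            phi = target_phi  # truncated final step hits the target exactly
        else:
            phi += step if delta > 0 else -step
        targets.append(phi)
    return targets


compress = phi_schedule(0.80, 0.8025, 0.001)    # three outer steps
decompress = phi_schedule(0.85, 0.8485, 0.001)  # two outer steps
```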
- Parameters:
state – Current state.
system – Current system; any domain type that scale_to_packing_fraction() supports is allowed.
target_phi – Target packing fraction.
step – Maximum magnitude of the per-outer-step increment in phi. Smaller values are more quasistatic (costlier); 1e-3 is a reasonable default for dense compressions.
phi_tolerance – Absolute tolerance on the terminal packing fraction.
pe_tol – Minimizer convergence tolerances.
pe_diff_tol – Minimizer convergence tolerances.
max_n_min_steps_per_outer – FIRE iterations allowed per minimization (per outer step).
max_n_outer_steps – Hard cap on outer iterations (safety net).
progress – If True and tqdm is importable, wraps the outer loop in a progress bar. Otherwise silent.
- Returns:
final_phi is compute_packing_fraction(state, system) at exit; final_pe is the PE after the last minimization.
- Return type:
- jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) -> State
Generate N particles placed uniformly at random in an axis-aligned box, without checking for overlaps.
- Parameters:
N – Number of particles.
dim – Spatial dimension (2 or 3).
box_size – Edge lengths of the domain.
box_anchor – Coordinate of the lower box corner.
radius_range – min and max values that the radius can take.
mass_range – min and max values that the mass can take.
vel_range – min and max values that the velocity components can take.
seed – Integer for reproducibility.
- Returns:
A fully initialised State instance.
- Return type:
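A minimal NumPy sketch of the sampling step, assuming each quantity is drawn uniformly between its bounds (the reference states this for positions but not for radii or masses). The helper random_positions_radii is hypothetical and only covers positions and radii; it is not jaxdem's implementation.

```python
import numpy as np


def random_positions_radii(N, dim, box_size, box_anchor, radius_range, seed=0):
    """Hypothetical sketch: uniform positions in [anchor, anchor + size) and
    radii drawn uniformly between the two bounds. Overlaps are not checked."""
    rng = np.random.default_rng(seed)  # fixed seed -> reproducible draws
    box_size = np.broadcast_to(np.asarray(box_size, dtype=float), (dim,))
    box_anchor = np.broadcast_to(np.asarray(box_anchor, dtype=float), (dim,))
    pos = box_anchor + rng.random((N, dim)) * box_size
    r_min, r_max = radius_range
    rad = rng.uniform(r_min, r_max, size=N)
    return pos, rad


pos, rad = random_positions_radii(100, 2, box_size=[4.0, 2.0], box_anchor=[1.0, -1.0],
                                  radius_range=(0.1, 0.2), seed=3)
```

Because overlaps are not checked, such a configuration usually needs an energy minimization before dense-packing work.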
- jaxdem.utils.randomize_orientations(state: State, key: jax.Array) -> State
Randomize orientations for clumps (particles with repeated state.clump_id), leaving spheres unchanged.
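One way to realize this is sketched below in NumPy: a clump is any clump_id that occurs more than once, every clump gets one random unit quaternion shared by its spheres, and lone spheres keep the identity. The helper name, the (w, x, y, z) row layout, and the use of a normalized 4-D Gaussian (a standard way to draw uniformly random 3-D rotations) are assumptions; in 2D the sketch restricts the rotation to the z-axis, matching the Quaternion convention above.

```python
import numpy as np


def random_clump_quaternions(clump_id, dim=3, seed=0):
    """Hypothetical sketch: one random unit quaternion (w, x, y, z) per clump,
    copied to each member sphere; spheres with a unique clump_id get (1, 0, 0, 0)."""
    rng = np.random.default_rng(seed)
    ids, inverse, counts = np.unique(np.asarray(clump_id), return_inverse=True,
                                     return_counts=True)
    n = len(ids)
    if dim == 3:
        # A normalized 4-D Gaussian sample is uniform on the unit 3-sphere,
        # i.e. a uniformly random 3-D rotation
        q = rng.normal(size=(n, 4))
        q /= np.linalg.norm(q, axis=1, keepdims=True)
    else:
        # 2-D: rotation about z only, q = (cos(t/2), 0, 0, sin(t/2))
        t = rng.uniform(-np.pi, np.pi, size=n)
        q = np.zeros((n, 4))
        q[:, 0] = np.cos(t / 2)
        q[:, 3] = np.sin(t / 2)
    q[counts == 1] = np.array([1.0, 0.0, 0.0, 0.0])  # single spheres stay unrotated
    return q[inverse]


q3 = random_clump_quaternions([0, 0, 1, 2, 2, 2], dim=3, seed=0)
```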
- jaxdem.utils.remove_rattlers(state: State, system: System, rattler_clump_ids: jax.Array) -> tuple[State, System]
Remove all spheres belonging to rattler clumps and rebuild a matching system.
The state’s rattler spheres are dropped and its clump_id/bond_id/unique_id arrays are re-indexed. The returned system is a dataclasses.replace() copy of the input (so every field, including domain, mat_table, integrators, user hooks, dt, time, and any future additions to System, is preserved by default) with only the state-size-dependent fields refreshed:
collider is rebuilt via its Create() method for stateful colliders (NeighborList, cell lists, sweep-and-prune) and passed through unchanged for stateless ones (naive). Create’s config kwargs are recovered from the current collider via introspection (see _refresh_collider()).
force_manager is rebuilt so that its per-particle buffers (external_force, external_force_com, external_torque) are sized for the reduced state. gravity, force_functions, energy_functions, and is_com_force are preserved.
- Parameters:
- Returns:
state (State) – New state with rattler spheres removed and IDs re-indexed.
system (System) – New system with matching state shape.
Notes
DP / bonded force models. When system.bonded_force_model is a DeformableParticleModel, its topology arrays (elements, edges, element_adjacency, …) reference vertices by unique_id. remove_rattlers() re-indexes unique_id in the state but does not remap the bonded-model topology, because the correct behavior is ambiguous (an element that partially straddles removed vertices could be dropped, re-triangulated, or flagged). A warning is emitted when a bonded model is present; users with DP systems should handle the topology remap manually.
Custom collider settings. Any collider Create-kwarg whose name is not a field on the current collider (e.g. number_density and safety_factor on NeighborList) gets Create’s default value, not the value originally used. If you need to preserve such settings, rebuild the system yourself.
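The re-indexing step can be illustrated with a short NumPy sketch, assuming "re-indexed" means the surviving IDs are relabeled to consecutive integers in their original order. drop_and_reindex is a hypothetical helper; the same mask-and-relabel idea would apply to bond_id and unique_id. It also shows why bonded topology arrays that still hold old unique_id values would point at the wrong vertices afterwards.

```python
import numpy as np


def drop_and_reindex(clump_id, rattler_clump_ids):
    """Hypothetical sketch: drop spheres of rattler clumps and relabel the
    surviving clump IDs to 0..K-1, preserving their relative order."""
    clump_id = np.asarray(clump_id)
    keep = ~np.isin(clump_id, rattler_clump_ids)
    # np.unique sorts the surviving IDs, so return_inverse gives each sphere
    # the rank of its old ID among survivors, i.e. a contiguous new ID
    _, new_id = np.unique(clump_id[keep], return_inverse=True)
    return keep, new_id


keep, new_id = drop_and_reindex([0, 0, 1, 2, 3, 3], rattler_clump_ids=[1])
```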
- jaxdem.utils.run_packing_fraction_protocol(state: State, system: System, *, strides: jax.Array, phi_at_frames: jax.Array, unroll: int = 2) -> tuple[State, System, tuple[State, System]]
Integrate + scheduled box-rescale protocol, saving one frame per event.
For each frame i in range(K):
1. Advance strides[i] integration steps via System.step().
2. Rescale the periodic box to phi_at_frames[i] via scale_to_packing_fraction().
3. Record (state, system) as the frame.
All dynamics (pairwise forces, bonded forces, thermostat integrators, neighbor-list rebuilds, etc.) are delegated to system.step. That means temperature control, if desired, is set up at System.create time by picking linear_integrator_type="verlet_rescaling" (deterministic velocity rescaling) or "langevin" (stochastic). This function then runs whatever integrator is on the System and adds the box-rescale schedule on top.
- Parameters:
state – Initial state.
system – Initial system; its integrators, collider, and force model determine what happens between rescales.
strides – 1D integer array of per-frame integration step counts. Its length K sets the number of frames.
phi_at_frames – 1D float array of target packing fractions, one per frame, applied after the frame’s integration strides. Must have the same length as strides.
unroll – Unroll factor for the outer jax.lax.scan() (same semantics as System.trajectory_rollout()).
- Returns:
Final state/system and the per-frame trajectory, stacked along leading axis K, with the same layout as trajectory_rollout’s default save_fn.
- Return type:
Notes
For a non-rescaling rollout, call System.trajectory_rollout() directly.
To drive from a save_steps array produced by make_save_steps_pseudolog() or make_save_steps_linear(), pass strides=np.diff(save_steps) and a matching phi_at_frames array.
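The frame loop maps naturally onto jax.lax.scan with a variable-length inner loop. The sketch below is self-contained and uses hypothetical stand-ins: toy_step and toy_rescale replace System.step() and scale_to_packing_fraction(), and the "state" is just (steps taken, current phi) so the schedule is easy to check. It also shows the strides=np.diff(save_steps) conversion from a linear save schedule built directly with NumPy.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_step(carry):
    # Hypothetical stand-in for one System.step(): just count steps
    steps, phi = carry
    return steps + 1, phi


def toy_rescale(carry, phi_target):
    # Hypothetical stand-in for scale_to_packing_fraction(): record the new phi
    steps, _ = carry
    return steps, phi_target


def run_protocol(carry, strides, phi_at_frames):
    def frame(carry, xs):
        n_steps, phi_target = xs
        # Traced trip count: fori_loop lowers to a while_loop here
        carry = jax.lax.fori_loop(0, n_steps, lambda _, c: toy_step(c), carry)
        carry = toy_rescale(carry, phi_target)
        return carry, carry  # second output is the recorded frame
    return jax.lax.scan(frame, carry, (strides, phi_at_frames))


save_steps = np.arange(0, 1001, 250)                   # 0, 250, 500, 750, 1000
strides = jnp.asarray(np.diff(save_steps), jnp.int32)  # 250 steps per frame
phi_at_frames = jnp.linspace(0.80, 0.84, len(strides), dtype=jnp.float32)
init = (jnp.int32(0), jnp.float32(0.80))
final, (steps_at_frame, phi_at_frame) = run_protocol(init, strides, phi_at_frames)
```

Each recorded frame reflects the state after both the integration stride and the rescale, which matches the order of steps listed above.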
- jaxdem.utils.scale_to_packing_fraction(state: State, system: System, new_packing_fraction: float) -> tuple[State, System]
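No description is given for this function; elsewhere on this page, raising the target packing fraction shrinks the box. The underlying geometry can be sketched under two assumptions: phi is the total particle volume divided by the box volume, and every box edge is scaled by one common factor s. rescaled_box is hypothetical; an affine variant would also multiply positions (relative to the box anchor) by s.

```python
import numpy as np


def rescaled_box(box_size, particle_volume, new_phi):
    """Hypothetical sketch: box edge lengths giving packing fraction new_phi.

    phi = V_particles / prod(box_size); scaling every edge by s multiplies the
    box volume by s**dim, so s = (phi_old / phi_new) ** (1 / dim).
    """
    box_size = np.asarray(box_size, dtype=float)
    dim = box_size.size
    phi_old = particle_volume / np.prod(box_size)
    s = (phi_old / new_phi) ** (1.0 / dim)
    return box_size * s, s


new_box, s = rescaled_box([10.0, 10.0], particle_volume=50.0, new_phi=0.8)  # compress
```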
- jaxdem.utils.scale_to_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) -> State
Scale the velocities of a state to a desired temperature.
- Parameters:
state (State) – Current simulation state.
target_temperature (float) – Desired target temperature.
can_rotate (bool) – Whether to include rigid body rotations.
subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).
k_B (float, optional) – Boltzmann constant (default is 1.0).
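Velocity rescaling usually rests on equipartition, T = sum(m v^2) / (n_dof k_B), followed by multiplying every velocity by sqrt(T_target / T). The NumPy sketch below covers translational degrees of freedom only (can_rotate would add rotational kinetic energy) and assumes that removing the drift also removes dim degrees of freedom; jaxdem may count degrees of freedom differently, and scale_velocities is a hypothetical helper.

```python
import numpy as np


def scale_velocities(vel, mass, target_temperature, subtract_drift, k_B=1.0):
    """Hypothetical sketch: rescale translational velocities to a target temperature."""
    vel = np.asarray(vel, dtype=float)
    mass = np.asarray(mass, dtype=float)[:, None]
    N, dim = vel.shape
    n_dof = N * dim
    if subtract_drift:
        # Remove the center-of-mass velocity; this also removes dim degrees of freedom
        v_com = np.sum(mass * vel, axis=0) / np.sum(mass)
        vel = vel - v_com
        n_dof -= dim
    temperature = np.sum(mass * vel**2) / (n_dof * k_B)  # equipartition estimate
    return vel * np.sqrt(target_temperature / temperature)


rng = np.random.default_rng(1)
v = rng.normal(size=(50, 3))
m = rng.uniform(1.0, 2.0, size=50)
v_scaled = scale_velocities(v, m, target_temperature=2.0, subtract_drift=True)
```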
- jaxdem.utils.set_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, seed: int | None = 0, k_B: float = 1.0) -> State
Randomize the velocities of a state according to a desired temperature.
- Parameters:
state (State) – Current simulation state.
target_temperature (float) – Desired target temperature.
can_rotate (bool) – Whether to include rigid body rotations.
subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).
seed (int, optional) – RNG seed.
k_B (float, optional) – Boltzmann constant (default is 1.0).
- jaxdem.utils.signed_angle(v1: Array, v2: Array) -> Array
Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \([-\pi, \pi)\).
Calculates the signed angle between two 2D vectors \(\vec{v}_1\) and \(\vec{v}_2\) using the dot product and the 2D cross product:
\[\begin{split}\hat{v}_1 &= \text{unit}(\vec{v}_1) \\ \hat{v}_2 &= \text{unit}(\vec{v}_2) \\ d &= \hat{v}_1 \cdot \hat{v}_2 \\ s &= \hat{v}_{1,x} \hat{v}_{2,y} - \hat{v}_{1,y} \hat{v}_{2,x} \\ \theta &= \text{atan2}(s, d)\end{split}\]
- Parameters:
v1 (jnp.ndarray) – First vector. Shape (…, 2).
v2 (jnp.ndarray) – Second vector. Shape (…, 2).
- Returns:
Signed angle in radians.
- Return type:
jnp.ndarray
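A direct NumPy transcription of the formula above, as an illustrative sketch rather than jaxdem's source (it does not reproduce jaxdem's zero-vector handling). Normalizing is not strictly needed for atan2, since both arguments scale by the same positive factor, but it mirrors the equations.

```python
import numpy as np


def signed_angle(v1, v2):
    """Sketch of theta = atan2(cross, dot) for unit 2-D vectors, shape (..., 2)."""
    u1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    u2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)
    d = np.sum(u1 * u2, axis=-1)                           # cos(theta)
    s = u1[..., 0] * u2[..., 1] - u1[..., 1] * u2[..., 0]  # sin(theta), 2-D cross product
    return np.arctan2(s, d)


ccw = signed_angle(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # +x to +y: counterclockwise
```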
- jaxdem.utils.signed_angle_x(v1: Array) -> Array
Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).
Calculates the signed angle from a 2D vector \(\vec{v}_1\) to the positive x-axis \((1, 0)\):
\[\theta = \text{atan2}(-v_{1,y}, v_{1,x})\]
- Parameters:
v1 (jnp.ndarray) – The input vector. Shape (…, 2).
- Returns:
Signed angle in radians.
- Return type:
jnp.ndarray
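Setting \(\vec{v}_2 = \hat{x}\) in the signed_angle formula gives \(s = -v_{1,y}\) and \(d = v_{1,x}\) (up to a common positive factor), which is exactly this expression. A one-line NumPy sketch, for illustration only:

```python
import numpy as np


def signed_angle_x(v1):
    # theta = atan2(-v_y, v_x): the angle that rotates v1 onto the +x axis
    v1 = np.asarray(v1, dtype=float)
    return np.arctan2(-v1[..., 1], v1[..., 0])


clockwise = signed_angle_x([0.0, 1.0])  # rotating +y onto +x is clockwise
```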
- jaxdem.utils.unit(v: Array) -> Array
Normalize vectors to unit vectors along the last axis.
If the vector is zero, the result is zero.
\[\begin{split}\hat{v} = \begin{cases} \frac{\vec{v}}{\|v\|} & \text{if } \|v\| > 0 \\ \vec{0} & \text{otherwise} \end{cases}\end{split}\]
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
Unit vector. Shape (…, D).
- Return type:
jax.Array
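The zero case matters under autodiff: computing v / norm(v) and masking the result afterwards still yields NaN gradients at v = 0. A common JAX pattern (the "double where") substitutes a safe denominator before dividing. The sketch below is one way to implement the documented behavior, not necessarily how jaxdem does it.

```python
import jax
import jax.numpy as jnp


def unit(v):
    """Zero-safe normalization along the last axis (illustrative sketch)."""
    n2 = jnp.sum(v * v, axis=-1, keepdims=True)
    nonzero = n2 > 0
    # Replace zero norms by 1 *before* dividing so neither the value nor its
    # gradient contains NaNs; the outer where then returns 0 for zero vectors
    safe_norm = jnp.sqrt(jnp.where(nonzero, n2, 1.0))
    return jnp.where(nonzero, v / safe_norm, 0.0)


u = unit(jnp.array([3.0, 4.0]))
grad_at_zero = jax.grad(lambda v: unit(v).sum())(jnp.zeros(3))
```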
- jaxdem.utils.unit_and_norm(v: Array) -> tuple[Array, Array]
Normalize vectors along the last axis and return both unit vectors and norms.
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
A tuple of (unit vectors, norms).
- Return type:
Tuple[jax.Array, jax.Array]
Modules
Utility functions to compute angles between vectors.
Utility functions for analyzing particle contacts and identifying rattlers.
Utility functions to assign radius dispersity.
Functions for calculating the dynamical matrix (Hessian of the potential energy w.r.t.
Protocols that interleave integration with periodic state/system rescaling.
Utility functions to handle environments and the LIDAR sensor.
Utility functions for creating Geometric Asperity particle states in 2D and 3D.
Utility functions to initialize states with particles arranged in a grid.
HDF5 save/load utilities (v2).
Jamming routines.
Utility functions to help with linear algebra.
Adapter for loading HDF5 data saved with the pre-merge-2-27-26 branch.
Asperity mesh generators for geometric-asperity (GA) particles.
Utility functions for calculating and changing the packing fraction.
Utility functions for creating states of GA particles (rigid/DP).
Quaternion math utilities.
Generates random, energy-minimized configurations of spheres in 2D or 3D.
Utility functions to randomly initialize states.
Utility functions to randomize particle orientations.
Utilities to generate step indices for trajectory logging.
Monte-Carlo-style sampling of the surface of a clump particle with a tracer clump.
Utility functions to compute thermodynamic quantities.