jaxdem.utils

Utility functions used to set up simulations and analyze the output.

class jaxdem.utils.Quaternion(w: Array, xyz: Array)

Bases: object

Quaternion representation for 2D and 3D particle orientations.

A quaternion \(q\) is represented as a scalar part \(w\) and a vector part \(\vec{v}_{xyz} = (x, y, z)\):

\[q = w + x\mathbf{i} + y\mathbf{j} + z\mathbf{k}\]

In 2D, the rotation axis is restricted to the z-axis: \(\vec{v}_{xyz} = (0, 0, z)\).

w

The scalar component of the quaternion. Shape is (…, N, 1).

Type:

jax.Array

xyz

The vector components of the quaternion. Shape is (…, N, 3).

Type:

jax.Array

static create(w: Array | ndarray | bool | number | bool | int | float | complex | None = None, xyz: Array | ndarray | bool | number | bool | int | float | complex | None = None) → Quaternion

Create a Quaternion instance.

Parameters:
  • w (ArrayLike, optional) – The scalar component. If None, defaults to ones.

  • xyz (ArrayLike, optional) – The vector component. If None, defaults to zeros.

Returns:

The created quaternion.

Return type:

Quaternion

static unit(q: Quaternion) → Quaternion

Normalize a quaternion to have unit norm.

\[q_{unit} = \frac{q}{\|q\|} = \frac{q}{\sqrt{w^2 + x^2 + y^2 + z^2}}\]
Parameters:

q (Quaternion) – The quaternion to normalize.

Returns:

The normalized unit quaternion.

Return type:

Quaternion

static conj(q: Quaternion) → Quaternion

Compute the conjugate of a quaternion.

\[q^* = w - x\mathbf{i} - y\mathbf{j} - z\mathbf{k}\]
Parameters:

q (Quaternion) – The quaternion.

Returns:

The conjugate quaternion.

Return type:

Quaternion

static inv(q: Quaternion) → Quaternion

Compute the inverse of a quaternion.

The inverse is the conjugate divided by the squared norm, which reduces to the conjugate for a unit quaternion:

\[q^{-1} = \frac{q^*}{\|q\|^2}\]
Parameters:

q (Quaternion) – The quaternion.

Returns:

The inverse quaternion.

Return type:

Quaternion
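To make the inverse formula concrete, here is a minimal sketch that stores a quaternion as a (w, xyz) pair, like the attributes above, and checks that \(q q^{-1}\) is the identity for a non-unit quaternion. The helpers qmul, qconj and qinv are hypothetical illustrations written with jax.numpy. They are not jaxdem's implementation.

```python
import jax.numpy as jnp


def qmul(w1, v1, w2, v2):
    """Hamilton product (w1, v1) * (w2, v2) with w of shape (..., 1), v of shape (..., 3)."""
    w = w1 * w2 - jnp.sum(v1 * v2, axis=-1, keepdims=True)
    v = w1 * v2 + w2 * v1 + jnp.cross(v1, v2)
    return w, v


def qconj(w, v):
    """Conjugate: flip the sign of the vector part."""
    return w, -v


def qinv(w, v):
    """Inverse q^* / |q|^2; equals the conjugate when |q| = 1."""
    norm2 = w**2 + jnp.sum(v * v, axis=-1, keepdims=True)
    wc, vc = qconj(w, v)
    return wc / norm2, vc / norm2


# A deliberately non-unit quaternion: q * q^{-1} should give (1, 0, 0, 0)
w = jnp.array([2.0])
v = jnp.array([1.0, -1.0, 0.5])
wi, vi = qinv(w, v)
w_id, v_id = qmul(w, v, wi, vi)
```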

static rotate(q: Quaternion, v: Array) → Array

Rotates a vector \(\vec{v}\) from the body reference frame to the lab reference frame.

In 3D, the rotation of a vector \(\vec{v}\) by a unit quaternion \(q = (w, \vec{q}_{xyz})\) is:

\[\vec{v}' = \vec{v} + 2 w (\vec{q}_{xyz} \times \vec{v}) + 2 (\vec{q}_{xyz} \times (\vec{q}_{xyz} \times \vec{v}))\]

In 2D, where rotation is restricted to the z-axis, the rotation by angle \(\theta\) (corresponding to quaternion components \(w = \cos(\theta/2)\) and \(q_z = \sin(\theta/2)\)) is:

\[\begin{split}x' &= x \cos(\theta) - y \sin(\theta) \\ y' &= x \sin(\theta) + y \cos(\theta)\end{split}\]
Parameters:
  • q (Quaternion) – The rotation quaternion.

  • v (jax.Array) – The vector to rotate. Shape is (…, dim).

Returns:

The rotated vector in the lab frame. Shape is (…, dim).

Return type:

jax.Array
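The sketch below shows that the 3D rotation formula reproduces the 2D rotation when the quaternion is restricted to the z-axis with \(w = \cos(\theta/2)\) and \(q_z = \sin(\theta/2)\). The standalone rotate function takes raw jax.numpy arrays rather than a Quaternion instance. It is an illustrative stand-in for Quaternion.rotate, not the library code.

```python
import jax.numpy as jnp


def rotate(w, xyz, v):
    """Rotate v from the body frame to the lab frame with a unit quaternion (w, xyz).

    Implements v' = v + 2 w (q_xyz x v) + 2 q_xyz x (q_xyz x v).
    """
    t = jnp.cross(xyz, v)
    return v + 2.0 * w * t + 2.0 * jnp.cross(xyz, t)


# Unit quaternion for a rotation by theta about the z-axis
theta = 0.7
w = jnp.array([jnp.cos(theta / 2)])
xyz = jnp.array([0.0, 0.0, jnp.sin(theta / 2)])
v = jnp.array([1.0, 2.0, 0.0])
v_rot = rotate(w, xyz, v)

# The same rotation written with the explicit 2D formula
expected = jnp.array([
    v[0] * jnp.cos(theta) - v[1] * jnp.sin(theta),
    v[0] * jnp.sin(theta) + v[1] * jnp.cos(theta),
    0.0,
])
```

Replacing xyz with -xyz in the same function gives the inverse rotation used by rotate_back.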

static rotate_back(q: Quaternion, v: Array) → Array

Rotates a vector \(\vec{v}\) from the lab reference frame to the body reference frame.

This performs the inverse rotation using the quaternion conjugate \(q^* = (w, -\vec{q}_{xyz})\):

\[\vec{v}' = \vec{v} - 2 w (\vec{q}_{xyz} \times \vec{v}) + 2 (\vec{q}_{xyz} \times (\vec{q}_{xyz} \times \vec{v}))\]
Parameters:
  • q (Quaternion) – The rotation quaternion.

  • v (jax.Array) – The vector to rotate back. Shape is (…, dim).

Returns:

The rotated vector in the body frame. Shape is (…, dim).

Return type:

jax.Array

jaxdem.utils.angle(v1: Array, v2: Array) → Array

Angle from v1 -> v2 in \([0, \pi]\).

Calculates the unsigned angle between two vectors \(\vec{v}_1\) and \(\vec{v}_2\) using a numerically stable half-angle formula:

\[\begin{split}\hat{v}_1 &= \text{unit}(\vec{v}_1) \\ \hat{v}_2 &= \text{unit}(\vec{v}_2) \\ y &= \|\hat{v}_1 - \hat{v}_2\| \\ x &= \|\hat{v}_1 + \hat{v}_2\| \\ \theta &= 2 \cdot \text{atan2}(y, x)\end{split}\]
Parameters:
  • v1 (jax.Array) – First vector. Shape (…, dim).

  • v2 (jax.Array) – Second vector. Shape (…, dim).

Returns:

Unsigned angle in radians.

Return type:

jax.Array
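A minimal jax.numpy sketch of the half-angle formula follows. Unlike arccos of the dot product, which loses precision near 0 and \(\pi\), atan2 of the two chord lengths stays well conditioned over the whole range. The function name angle_sketch is illustrative only.

```python
import jax.numpy as jnp


def angle_sketch(v1, v2):
    """Unsigned angle in [0, pi] via theta = 2 * atan2(|u1 - u2|, |u1 + u2|)."""
    u1 = v1 / jnp.linalg.norm(v1, axis=-1, keepdims=True)
    u2 = v2 / jnp.linalg.norm(v2, axis=-1, keepdims=True)
    y = jnp.linalg.norm(u1 - u2, axis=-1)
    x = jnp.linalg.norm(u1 + u2, axis=-1)
    return 2.0 * jnp.arctan2(y, x)


a = angle_sketch(jnp.array([1.0, 0.0]), jnp.array([0.0, 3.0]))   # orthogonal
b = angle_sketch(jnp.array([1.0, 0.0]), jnp.array([-2.0, 0.0]))  # antiparallel
```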

jaxdem.utils.angle_x(v1: Array) → Array

Angle from v1 -> \(\hat{x}\) in \([0, \pi]\).

Calculates the unsigned angle of a vector \(\vec{v}_1\) relative to the positive x-axis \((1, 0, \dots)\):

\[\begin{split}\hat{v}_1 &= \text{unit}(\vec{v}_1) \\ y &= \sqrt{2(1 - \hat{v}_{1,x})} \\ x &= \sqrt{2(1 + \hat{v}_{1,x})} \\ \theta &= 2 \cdot \text{atan2}(y, x)\end{split}\]
Parameters:

v1 (jax.Array) – The input vector. Shape (…, dim).

Returns:

Unsigned angle in radians.

Return type:

jax.Array

jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) → State

jaxdem.utils.compute_energy(state: State, system: System) → jax.Array

Compute the total mechanical energy of the system.

\[E_{total} = E_{pot, total} + E_{trans, total} + E_{rot, total}\]
Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing physics parameters and colliders.

Returns:

The total energy (scalar) of the system.

Return type:

jax.Array

jaxdem.utils.compute_packing_fraction(state: State, system: System) → jax.Array

jaxdem.utils.compute_particle_volume(state: State) → jax.Array

Return the total particle volume.

jaxdem.utils.compute_potential_energy(state: State, system: System) → jax.Array

Compute the total potential energy of the system. Contributions come from the collider’s force model and from any force-manager terms (such as gravity) that have an associated potential energy.

\[E_{pot, total} = \sum_{i} U(r_i)\]
Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing the collider.

Returns:

The scalar sum of potential energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_rotational_kinetic_energy(state: State) → jax.Array

Compute the total rotational kinetic energy of the system.

\[E_{rot, total} = \sum_{i} \frac{1}{2} \vec{\omega}_i^T I_i \vec{\omega}_i\]
Parameters:

state (State) – The current state of the system.

Returns:

The scalar sum of rotational kinetic energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_rotational_kinetic_energy_per_particle(state: State) → jax.Array

Compute the rotational kinetic energy per particle.

\[E_{rot} = \frac{1}{2} \vec{\omega}^T I \vec{\omega}\]

Notes

  • The energy of clump members is divided by the number of spheres in the clump.

Parameters:

state (State) – The current state of the system containing inertia, orientation, and angular velocity.

Returns:

An array containing the rotational kinetic energy for each particle.

Return type:

jax.Array

jaxdem.utils.compute_temperature(state: State, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) → float

Compute the temperature for a state.

Parameters:
  • state (State) – Current simulation state.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • k_B (float, optional) – Boltzmann constant (default is 1.0).

jaxdem.utils.compute_translational_kinetic_energy(state: State) → jax.Array

Compute the total translational kinetic energy of the system.

\[E_{trans, total} = \sum_{i} \frac{1}{2} m_i |v_i|^2\]
Parameters:

state (State) – The current state of the system.

Returns:

The scalar sum of translational kinetic energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_translational_kinetic_energy_per_particle(state: State) → jax.Array

Compute the translational kinetic energy per particle.

\[E_{trans} = \frac{1}{2} m |v|^2\]

Notes

  • The energy of clump members is divided by the number of spheres in the clump.

Parameters:

state (State) – The current state of the system containing particle masses and velocities.

Returns:

An array containing the translational kinetic energy for each particle.

Return type:

jax.Array
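The clump note can be illustrated with a small sketch. It rests on an assumption not stated above: every member sphere of a clump is taken to store the clump's total mass and centre-of-mass velocity. Under that assumption, dividing by the clump size makes the per-sphere values sum to the clump's kinetic energy counted once. The inputs mass, vel and clump_id are hypothetical plain arrays, not jaxdem State fields.

```python
import jax.numpy as jnp


def trans_ke_per_particle(mass, vel, clump_id):
    # Assumption: clump members carry the clump's mass and COM velocity,
    # so each member receives 1 / (clump size) of the clump's energy.
    ke = 0.5 * mass * jnp.sum(vel * vel, axis=-1)
    counts = jnp.bincount(clump_id, length=clump_id.shape[0])  # spheres per clump id
    return ke / counts[clump_id]


mass = jnp.array([3.0, 3.0, 3.0, 1.0])          # spheres 0-2 form one clump of mass 3
vel = jnp.array([[1.0, 0.0]] * 3 + [[0.0, 2.0]])
clump_id = jnp.array([0, 0, 0, 1])
ke = trans_ke_per_particle(mass, vel, clump_id)
```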

jaxdem.utils.count_clump_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) → tuple[State, System, jax.Array]

Count force-bearing clump-level neighbors per clump.

For every pair of clumps, sums the sphere-sphere contact forces into the clump-clump total and marks the pair as “in contact” iff the resulting total force has nonzero norm. Returns, per clump, the number of such neighbors. This matches the clump-pair convention used by compute_clump_pair_friction(): two clumps count as in contact when their net contact interaction is nonzero.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • contacts (jax.Array) – (N_clumps,) integer array of force-bearing clump-level neighbors per clump.

jaxdem.utils.count_vertex_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) → tuple[State, System, jax.Array]

Count force-bearing vertex-level contacts per clump.

For each clump, returns the number of sphere-sphere contacts with nonzero contact force that involve one of the clump’s vertex spheres. Each unique physical contact between clumps \(I\) and \(J\) increments both clump \(I\)’s and clump \(J\)’s count by one (the neighbor list lists the pair in both directions).

This is the contact-count quantity entering the Maxwell / isostaticity condition: the sum over clumps equals twice the number of distinct force-bearing vertex contacts, so the mean over clumps is the average coordination number \(Z\).

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • contacts (jax.Array) – (N_clumps,) integer array of force-bearing vertex contacts per clump.
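The counting rule can be sketched from a pair list like the one returned by get_pair_forces_and_ids, assuming each pair appears in both directions. clump_of_sphere is a hypothetical sphere-to-clump mapping. Skipping same-clump pairs is an extra assumption that costs nothing if the neighbor list already excludes them. The example shows that the sum over clumps is twice the number of distinct force-bearing contacts.

```python
import jax.numpy as jnp


def count_vertex_contacts_sketch(pair_ids, forces, clump_of_sphere, n_clumps):
    # pair_ids: (M, 2) sphere pairs listed as both (i, j) and (j, i)
    # forces:   (M, dim) pairwise contact forces
    i, j = pair_ids[:, 0], pair_ids[:, 1]
    active = (jnp.linalg.norm(forces, axis=-1) > 0) & (
        clump_of_sphere[i] != clump_of_sphere[j]  # assumption: ignore intra-clump pairs
    )
    # Each direction credits the clump that owns the first sphere once
    return jnp.zeros(n_clumps, dtype=jnp.int32).at[clump_of_sphere[i]].add(
        active.astype(jnp.int32)
    )


# Clump 0 = spheres {0, 1}, clump 1 = {2}, clump 2 = {3}; the 2-3 pair carries no force
pair_ids = jnp.array([[1, 2], [2, 1], [0, 3], [3, 0], [2, 3], [3, 2]])
forces = jnp.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0], [0.0, 0.0], [0.0, 0.0]])
clump_of_sphere = jnp.array([0, 0, 1, 2])
contacts = count_vertex_contacts_sketch(pair_ids, forces, clump_of_sphere, n_clumps=3)
Z = contacts.mean()  # average coordination number
```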

jaxdem.utils.cross(a: Array, b: Array) → Array

Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.

For 3D vectors (\(D=3\)), the result is a vector orthogonal to both ‘a’ and ‘b’:

\[\vec{c} = \vec{a} \times \vec{b} = (a_y b_z - a_z b_y) \mathbf{i} + (a_z b_x - a_x b_z) \mathbf{j} + (a_x b_y - a_y b_x) \mathbf{k}\]

For 2D vectors (\(D=2\)), the result is the signed z-component of the 3D cross product:

\[c = a_x b_y - a_y b_x\]
Parameters:
  • a (jax.Array) – First vector. Shape (…, D), where D is the dimension (2 or 3).

  • b (jax.Array) – Second vector. Shape (…, D).

Returns:

The cross product:

  • If D=3: shape is (…, 3).

  • If D=2: shape is (…, 1).

Return type:

jax.Array

Raises:

ValueError – If the last dimension is not 2 or 3, or if the dimensions of ‘a’ and ‘b’ do not match.
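A minimal sketch of the dimension dispatch described above follows. cross_sketch is an illustrative name, and it delegates the 3D case to jnp.cross. In 2D it keeps the trailing axis of size 1 so the output shape is (…, 1).

```python
import jax.numpy as jnp


def cross_sketch(a, b):
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("a and b must have the same last dimension")
    if a.shape[-1] == 3:
        return jnp.cross(a, b)
    if a.shape[-1] == 2:
        # Signed z-component a_x b_y - a_y b_x, kept as a size-1 trailing axis
        return (a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0])[..., None]
    raise ValueError("last dimension must be 2 or 3")


c2 = cross_sketch(jnp.array([[1.0, 0.0]]), jnp.array([[0.0, 1.0]]))
c3 = cross_sketch(jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0]))
```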

jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) → Array

Computes the cross product of an angular velocity vector (w) and a position vector (r), typically to obtain the tangential velocity v = w × r.

For 3D vectors, the standard 3D cross product is used:

\[\vec{v} = \vec{w} \times \vec{r}\]

For 2D vectors, angular velocity \(w\) is a scalar (z-component) and position \(\vec{r}\) is 2D:

\[\vec{v} = (-w \cdot r_y, \, w \cdot r_x)\]
Parameters:
  • w (jax.Array) – Angular velocity. Shape (…, 3) in 3D, or (…, 1) or (…) in 2D.

  • r (jax.Array) – Position vector. Shape (…, 3) or (…, 2).

Returns:

Tangential velocity. Shape matches r.

Return type:

jax.Array

Raises:

ValueError – If dimensions are incompatible.
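The 2D branch is the 3D cross product with \(\vec{w} = (0, 0, w)\) and \(\vec{r}\) embedded in the xy-plane. The sketch below checks this. omega_cross_r_2d is a hypothetical helper that accepts w with shape (…) or (…, 1).

```python
import jax.numpy as jnp


def omega_cross_r_2d(w, r):
    # w: scalar angular velocity about z, shape (...) or (..., 1); r: (..., 2)
    w = jnp.reshape(w, r.shape[:-1])
    return jnp.stack([-w * r[..., 1], w * r[..., 0]], axis=-1)


w = jnp.array([[2.0]])          # shape (1, 1)
r = jnp.array([[0.5, -1.0]])    # shape (1, 2)
v2d = omega_cross_r_2d(w, r)
# Same result from the full 3D product (0, 0, w) x (r_x, r_y, 0)
v3d = jnp.cross(jnp.array([0.0, 0.0, 2.0]), jnp.array([0.5, -1.0, 0.0]))
```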

jaxdem.utils.cross_lidar_2d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_bins: int, max_neighbors: int) → tuple[jax.Array, jax.Array, jax.Array]

2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

Computes all-pairs displacements from pos_a to pos_b, bins by azimuthal angle, and returns per-bin proximity and closest target IDs.

Parameters:
  • pos_a (jax.Array) – Sensor positions, shape (N_A, dim).

  • pos_b (jax.Array) – Target positions, shape (N_B, dim).

  • system (System) – System configuration.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_bins (int) – Number of angular bins spanning \([-\pi, \pi)\).

  • max_neighbors (int) – Unused. Kept for backward compatibility.

Returns:

(proximity, ids, overflow) where proximity and ids have shape (N_A, n_bins) and overflow is always False. Empty bins get ids = -1.

Return type:

Tuple[jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.

Examples

>>> prox, ids, overflow = cross_lidar_2d(agents, obstacles, system,
...                                      lidar_range=5.0, n_bins=36,
...                                      max_neighbors=64)

jaxdem.utils.cross_lidar_3d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int) → tuple[jax.Array, jax.Array, jax.Array]

3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

Computes all-pairs displacements from pos_a to pos_b, bins on a spherical grid, and returns per-bin proximity and closest target IDs.

Parameters:
  • pos_a (jax.Array) – Sensor positions, shape (N_A, 3).

  • pos_b (jax.Array) – Target positions, shape (N_B, 3).

  • system (System) – System configuration.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_azimuth (int) – Number of azimuthal bins.

  • n_elevation (int) – Number of elevation bins.

  • max_neighbors (int) – Unused. Kept for backward compatibility.

Returns:

(proximity, ids, overflow) where proximity and ids have shape (N_A, n_azimuth * n_elevation) and overflow is always False. Empty bins get ids = -1.

Return type:

Tuple[jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.

Examples

>>> prox, ids, overflow = cross_lidar_3d(agents, obstacles, system,
...                                      lidar_range=5.0, n_azimuth=36,
...                                      n_elevation=18, max_neighbors=64)

jaxdem.utils.decode_callable(path: str) → Callable[[...], Any]

Import a callable from a dotted path string.

jaxdem.utils.dot(a: Array, b: Array) → Array

Dot product of vectors along the last axis.

\[c = \vec{a} \cdot \vec{b} = \sum_{i} a_i b_i\]
Parameters:
  • a (jax.Array) – First vector. Shape (…, D).

  • b (jax.Array) – Second vector. Shape (…, D).

Returns:

The dot product. Shape (…).

Return type:

jax.Array

jaxdem.utils.encode_callable(fn: Callable[[...], Any]) → str

Return a dotted path like ‘jax._src.nn.functions.gelu’.
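A round trip through a dotted path can be sketched with the standard library alone. The path shape 'jax._src.nn.functions.gelu' suggests module name plus function name, so this sketch assumes fn.__module__ and fn.__qualname__. The split on the last dot only works for top-level callables, not for nested qualified names such as methods. Both helper names are hypothetical.

```python
import importlib
import math
from typing import Any, Callable


def encode_callable_sketch(fn: Callable[..., Any]) -> str:
    # Assumption: module path + qualified name, e.g. "math.sqrt"
    return f"{fn.__module__}.{fn.__qualname__}"


def decode_callable_sketch(path: str) -> Callable[..., Any]:
    # Split on the last dot: everything before it is the importable module
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


path = encode_callable_sketch(math.sqrt)
fn = decode_callable_sketch(path)
```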

jaxdem.utils.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, skip_frames: int = 1, **kw: Any) → tuple[Environment, jax.Array]

Advance the environment n steps using actions from model.

Parameters:
  • env (Environment) – Initial environment pytree (batchable).

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.

  • n (int) – Number of steps to perform.

  • skip_frames (int) – Number of additional frames to repeat the action.

  • **kw (Any) – Extra keyword arguments forwarded to model.

Returns:

Updated environment and the advanced random key.

Return type:

Tuple[Environment, jax.Array]

Examples

>>> env, key = env_step(env, model, key, n=10, objective=goal)

jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, skip_frames: int = 1, **kw: Any) → tuple[Environment, jax.Array, Environment]

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

Parameters:
  • env (Environment) – Initial environment pytree.

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.

  • n (int) – Number of chunks to roll out. Total internal steps = n * stride.

  • stride (int) – Steps per chunk between recorded snapshots.

  • skip_frames (int) – Number of additional frames to repeat the action.

  • **kw (Any) – Extra keyword arguments passed to model on every step.

Returns:

Final environment, advanced random key, and a stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.

Return type:

Tuple[Environment, jax.Array, Environment]

Examples

>>> env, key, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)

jaxdem.utils.generate_arclength_mesh(nv: int, N: int, dim: int = 2, aspect_ratio: Any = None, n_fine: int | None = None) → Array

2D mesh with equal arc-length spacing along the ellipse / circle perimeter.

This is the closed-form analogue of the converged (n_steps → ∞) Thomson ground state in 2D. On a circle the two coincide exactly: both give the regular nv-gon. On an ellipse, “uniform neighbor distance” (= equal arc-length spacing) is exactly the α → ∞ packing-problem limit of Riesz-energy minimization and a very close approximation to the α = 1 Thomson ground state. Use this when you’d otherwise run Thomson with a huge n_steps budget on a 2D particle; it arrives at essentially the same answer in one numerical integration, deterministically.

The algorithm inverts the arc-length parameterization of the ellipse (a cos t, b sin t) numerically. It integrates the arc-length integrand with the trapezoidal rule on a fine grid, then linearly interpolates the angles corresponding to nv equally spaced target arc lengths. Accuracy is controlled by n_fine (auto-sized to max(10_000, 200 * nv) by default).

Parameters:
  • nv (int) – Number of surface vertices. Must be >= 3.

  • N (int) – Number of bodies (all identical copies; per-body orientation is handled downstream).

  • dim (int) – Must be 2. In 3D “uniform neighbor distance” is generically the multi-point Thomson problem — use generate_thomson_mesh().

  • aspect_ratio (None, scalar, or (2,) array-like) – Ellipse semi-axes in the usual “normalize to max=1” convention. None gives a circle.

  • n_fine (int, optional) – Grid resolution for the arc-length integral. None auto-sizes.

Returns:

Shape (nv, 2) if N == 1 else (N, nv, 2).

Return type:

jax.Array
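The arc-length inversion described above can be sketched in a few lines of jax.numpy. This version takes the semi-axes a and b directly instead of an aspect_ratio argument, and returns a single body. Those simplifications, and the function name, are assumptions for illustration.

```python
import jax.numpy as jnp


def arclength_ellipse_points(nv, a=1.0, b=0.5, n_fine=None):
    """Place nv points at equal arc length on the ellipse (a cos t, b sin t)."""
    if n_fine is None:
        n_fine = max(10_000, 200 * nv)
    t = jnp.linspace(0.0, 2.0 * jnp.pi, n_fine + 1)
    speed = jnp.sqrt((a * jnp.sin(t)) ** 2 + (b * jnp.cos(t)) ** 2)  # |d(x, y)/dt|
    # Cumulative trapezoidal rule: s[k] is the arc length from t = 0 to t[k]
    seg = 0.5 * (speed[1:] + speed[:-1]) * jnp.diff(t)
    s = jnp.concatenate([jnp.zeros(1), jnp.cumsum(seg)])
    targets = jnp.arange(nv) * (s[-1] / nv)   # equally spaced arc lengths
    t_k = jnp.interp(targets, s, t)           # invert s(t) by linear interpolation
    return jnp.stack([a * jnp.cos(t_k), b * jnp.sin(t_k)], axis=-1)


pts = arclength_ellipse_points(8, a=1.0, b=0.5)
pc = arclength_ellipse_points(6, a=1.0, b=1.0)  # circle: regular hexagon
```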

jaxdem.utils.generate_faceted_mesh(nv: int, N: int, dim: int, n_facets: int = 6, aspect_ratio: Any = None) → Array

Regular n-gon (2D) or icosahedron (3D) with vertex + surface-filler asperities.

Unlike the smooth-sphere family (thomson / icosphere / fibonacci), this mesh keeps the particle genuinely faceted: the shape is a polygon or polyhedron with sharp vertices and flat faces, and the asperities sit on those flat features rather than being projected to a circumscribed sphere.

Asperity layout

  • Vertex asperities are always placed at the corners of the shape (n_facets in 2D, 12 for the icosahedron in 3D).

  • The remaining nv - n_vertices asperities are distributed uniformly across the surface primitives — edges in 2D, triangular face interiors in 3D — to achieve an (approximately) uniform surface density. If nv - n_vertices is not divisible by the number of primitives, the first few get one extra asperity so the total exactly matches nv.

In 2D, edge-interior asperities are evenly spaced along each edge (excluding the endpoints, which are already vertex asperities).

In 3D, face-interior asperities are quasi-uniformly sampled inside each triangular face via a barycentric-coordinate remap. They stay on the flat face — not projected to the circumscribing sphere — so the particle’s faceted character is preserved (asperities on faces sit closer to the center than vertex asperities).

Parameters:
  • nv (int) – Total number of asperities. Must be >= n_facets in 2D and >= 12 in 3D. Larger nv → higher surface density.

  • N (int) – Number of bodies (all identical copies; per-body orientation is handled downstream).

  • dim (int) – Spatial dimension (2 or 3).

  • n_facets (int) – 2D only: number of polygon sides / vertices. Must be >= 3. Ignored in 3D (the icosahedron has 12 vertices / 20 faces fixed).

  • aspect_ratio (None, scalar, or (dim,) array-like) – Axis stretch in the usual “normalize to max=1” convention.

Returns:

Shape (nv, dim) if N == 1 else (N, nv, dim). Vertex asperities appear first in the array, followed by edge/face fillers grouped by primitive.

Return type:

jax.Array
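The 2D asperity layout can be sketched as follows. Corners go first, then evenly spaced edge fillers with the endpoints excluded. The first few edges absorb the remainder when the fillers do not divide evenly. Placing the corners on the unit circle starting at angle 0, and ignoring aspect_ratio, are assumptions of this sketch. faceted_polygon_2d is a hypothetical name.

```python
import jax.numpy as jnp


def faceted_polygon_2d(nv, n_facets=6):
    """Corners of a regular n_facets-gon followed by edge-interior fillers."""
    if n_facets < 3 or nv < n_facets:
        raise ValueError("need n_facets >= 3 and nv >= n_facets")
    ang = 2.0 * jnp.pi * jnp.arange(n_facets) / n_facets
    corners = jnp.stack([jnp.cos(ang), jnp.sin(ang)], axis=-1)  # vertex asperities
    base, rem = divmod(nv - n_facets, n_facets)
    fillers = []
    for e in range(n_facets):
        k = base + (1 if e < rem else 0)        # the first `rem` edges get one extra
        p0, p1 = corners[e], corners[(e + 1) % n_facets]
        s = jnp.arange(1, k + 1) / (k + 1)      # strictly inside the edge
        fillers.append(p0 + s[:, None] * (p1 - p0))
    return jnp.concatenate([corners] + fillers, axis=0)


pts = faceted_polygon_2d(10, n_facets=4)  # square (diamond) with 6 edge fillers
```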

jaxdem.utils.generate_fibonacci_sphere_mesh(nv: int, N: int, dim: int, aspect_ratio: Any = None) → Array

Generate N Fibonacci-sphere meshes (3D) or circles (2D).

In 3D the points are laid out by a golden-angle spiral (the sunflower / Fibonacci lattice): z is stratified in (-1, 1) and the azimuth advances by the golden angle on each step. The result is a near-optimal, deterministic, low-discrepancy covering of the sphere for any nv >= 1. It’s the “I want uniform points fast” default and a drop-in alternative to Thomson that skips the minimization.

In 2D the output is nv evenly-spaced points on the unit circle (there’s no non-trivial 1D analogue of the spiral).

The mesh is deterministic, so all N bodies are identical copies; per-body random orientation is typically applied downstream by distribute_bodies().
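A minimal sketch of the golden-angle spiral for a single 3D body follows. z is stratified at the centres of nv equal slabs in (-1, 1), and the azimuth advances by the golden angle \(\pi(3 - \sqrt{5})\). The exact stratification offset jaxdem uses is not stated above, so the midpoint rule here is an assumption.

```python
import jax.numpy as jnp


def fibonacci_sphere(nv):
    golden_angle = jnp.pi * (3.0 - jnp.sqrt(5.0))  # about 2.39996 rad
    k = jnp.arange(nv)
    z = 1.0 - (2.0 * k + 1.0) / nv                 # slab midpoints, strictly inside (-1, 1)
    rho = jnp.sqrt(1.0 - z**2)                     # radius of the circle at height z
    phi = golden_angle * k
    return jnp.stack([rho * jnp.cos(phi), rho * jnp.sin(phi), z], axis=-1)


pts = fibonacci_sphere(200)
```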

jaxdem.utils.generate_helix_mesh(nv: int, N: int, dim: int, n_turns: float = 3.0, helix_radius: float = 0.3, aspect_ratio: Any = None) → Array

Generate N helical meshes (3D) or Archimedean spirals (2D).

In 3D the points trace a right-handed helix along the z axis: nv points evenly spaced in the arc parameter, making n_turns full turns from z = -1 to z = 1 on a circle of radius helix_radius. This gives chiral, rod-like bodies with a controllable pitch; good for studies of enantiomeric packing or helical-fiber clumps.

In 2D the helix degenerates to an Archimedean spiral centered at the origin, with nv points going from near the origin out to the unit circle over n_turns turns.
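For the 3D case, a sketch of the helix parameterization follows. The angle grows linearly with z, which makes the helix right-handed. Because a helix has constant speed in its parameter, equal parameter steps are also equal arc-length steps. The function name and the single-body output are assumptions.

```python
import jax.numpy as jnp


def helix_points(nv, n_turns=3.0, helix_radius=0.3):
    s = jnp.linspace(0.0, 1.0, nv)           # arc parameter in [0, 1]
    theta = 2.0 * jnp.pi * n_turns * s       # increases with z: right-handed
    z = -1.0 + 2.0 * s                       # from z = -1 to z = 1
    return jnp.stack(
        [helix_radius * jnp.cos(theta), helix_radius * jnp.sin(theta), z], axis=-1
    )


pts = helix_points(50)
```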

jaxdem.utils.generate_icosphere_mesh(nv: int, N: int, dim: int, aspect_ratio: Any = None) → Array

Generate N icosphere meshes (3D) or regular polygons (2D).

In 3D, nv must be 10 * frequency**2 + 2 for some integer frequency >= 1 (i.e. one of {12, 42, 92, 162, 252, ...}). Powers of two use recursive midpoint subdivision; other frequencies use direct triangular geodesic subdivision. Use generate_fibonacci_sphere_mesh() if you need an arbitrary vertex count on a sphere.

In 2D, nv can be any integer and the output is a regular nv-gon on the unit circle.

The mesh is deterministic, so all N bodies are identical copies; per-body random orientation is typically applied downstream by distribute_bodies().
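The constraint nv = 10 * frequency**2 + 2 is easy to check before calling the generator. The hypothetical helpers below recover the frequency and test whether it is a power of two, which is the case the text says uses recursive midpoint subdivision.

```python
import math


def icosphere_frequency(nv):
    """Return f with nv == 10 * f**2 + 2, or raise ValueError."""
    f = math.isqrt((nv - 2) // 10) if nv >= 12 else 0
    if f < 1 or 10 * f * f + 2 != nv:
        raise ValueError(f"nv={nv} is not of the form 10 * f**2 + 2")
    return f


def uses_midpoint_subdivision(f):
    # Powers of two (1, 2, 4, 8, ...) are reachable by repeated midpoint splits
    return (f & (f - 1)) == 0


valid = [10 * f * f + 2 for f in range(1, 6)]
```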

jaxdem.utils.generate_thomson_mesh(nv: int, N: int, dim: int, alpha: float = 1.0, lr: float = 0.01, steps: int = 1000, aspect_ratio: Any = None, use_uniform_sampling: bool = True, batch_size: int | None = None, seed: int | None = None) → tuple[Array, Array]

Generate and minimize charges constrained to a hyper-ellipsoid surface.
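The one-line summary suggests a gradient-based minimization, given the alpha, lr and steps parameters. The sketch below shows one plausible reading: projected gradient descent on the Riesz energy \(\sum_{i<j} |x_i - x_j|^{-\alpha}\), with every point pulled back onto the surface after each step. It uses the unit sphere instead of a general hyper-ellipsoid. It is an assumption about the technique, not jaxdem's algorithm.

```python
import jax
import jax.numpy as jnp


def riesz_energy(x, alpha=1.0):
    """Sum over distinct pairs of 1 / |x_i - x_j|**alpha (alpha = 1 is Coulomb)."""
    diff = x[:, None, :] - x[None, :, :]
    d2 = jnp.sum(diff**2, axis=-1)
    iu = jnp.triu_indices(x.shape[0], k=1)   # skip the zero diagonal
    return jnp.sum(d2[iu] ** (-alpha / 2))


def project(x):
    """Pull every point back onto the unit sphere."""
    return x / jnp.linalg.norm(x, axis=-1, keepdims=True)


@jax.jit
def step(x, lr):
    # One projected gradient-descent step on the charge energy
    return project(x - lr * jax.grad(riesz_energy)(x))


# Slightly puckered square on the equator; descent should move toward a tetrahedron
x = project(jnp.array([[1.0, 0.0, 0.1], [0.0, 1.0, -0.1],
                       [-1.0, 0.0, 0.1], [0.0, -1.0, -0.1]]))
e0 = riesz_energy(x)
for _ in range(500):
    x = step(x, 0.01)
e1 = riesz_energy(x)
```

For four charges the global minimum is the regular tetrahedron, with energy \(6/\sqrt{8/3} \approx 3.674\).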

jaxdem.utils.generate_torus_mesh(nv: int, N: int, dim: int = 3, tube_ratio: float = 0.3, aspect_ratio: Any = None) → Array

Generate N torus surface meshes with nv quasi-uniform points each.

The torus is parameterized by two radii: the major radius R (center of the tube → center of the torus) and the minor radius r (tube half-thickness). tube_ratio sets r directly under the “longest axis has extent 1” convention: r = tube_ratio and R = 1 - tube_ratio, so the torus fits in x, y ∈ [-1, 1] and z ∈ [-r, r].

Points are placed by stratified angular sampling around the major axis (theta, evenly spaced) paired with a golden-ratio quasi-random phase around the tube (phi). This gives good 2D coverage of the torus surface for any nv, with a slight over-representation of the inner rim relative to the outer rim (exact-uniform sampling would require a (R + r cos(phi)) area weighting, which we skip for simplicity).

Useful for non-convex, genus-1 particles — e.g. studies of interlocking or linking in packings where non-convexity matters.
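A sketch of the sampling scheme described above follows. theta is evenly stratified around the major axis, and phi takes a golden-ratio quasi-random phase around the tube, with no area weighting. Using the fractional golden ratio \((\sqrt{5}-1)/2\) as the phase increment is an assumption, as is the function name.

```python
import jax.numpy as jnp


def torus_points(nv, tube_ratio=0.3):
    r = tube_ratio
    R = 1.0 - tube_ratio                          # keeps x, y within [-1, 1]
    k = jnp.arange(nv)
    theta = 2.0 * jnp.pi * k / nv                 # stratified around the major axis
    golden = (jnp.sqrt(5.0) - 1.0) / 2.0          # fractional golden ratio, about 0.618
    phi = 2.0 * jnp.pi * jnp.mod(k * golden, 1.0) # quasi-random phase around the tube
    rho = R + r * jnp.cos(phi)                    # distance from the z-axis
    return jnp.stack([rho * jnp.cos(theta), rho * jnp.sin(theta), r * jnp.sin(phi)], axis=-1)


pts = torus_points(300)
```

Every point satisfies the implicit torus equation \((\sqrt{x^2 + y^2} - R)^2 + z^2 = r^2\).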

jaxdem.utils.get_clump_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None, check_contact_rank: bool = False, contact_rank_tol: float | None = None) → tuple[State, System, jax.Array, jax.Array]

Identify rattler clumps by iteratively removing under-coordinated clumps.

A clump is a rattler if its total vertex-contact count is below the coordination threshold zc. Optionally, clumps whose contacts do not span their rigid-body generalized force space are also treated as rattlers.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

  • zc (int, optional) – Minimum contact count. Defaults to dim + angular_dof + 1 — the mechanical stability threshold for a rigid body. A clump with dim + angular_dof or fewer force-bearing vertex contacts is unstable (the tangential softening of each contact at finite overlap can give the sub-hessian a negative eigenvalue), so dim + angular_dof + 1 non-degenerate contacts are needed for a positive-definite local hessian.

  • check_contact_rank (bool, optional) – If True, also remove clumps whose active force-bearing contacts have generalized force rank below dim + angular_dof.

  • contact_rank_tol (float, optional) – Absolute tolerance passed to jax.numpy.linalg.matrix_rank for the optional generalized force rank check.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • rattler_ids (jax.Array) – 1-D array of rattler clump IDs.

  • non_rattler_ids (jax.Array) – 1-D array of non-rattler clump IDs.

jaxdem.utils.get_pair_forces_and_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) → tuple[State, System, jax.Array, jax.Array]

Compute pairwise contact forces and their associated particle IDs.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition containing the collider and force model.

  • cutoff (float, optional) – Neighbor search cutoff distance. Defaults to 3 * max(rad).

  • max_neighbors (int, optional) – Maximum number of neighbors per particle (default 100).

Returns:

  • state (State) – Potentially updated state (after neighbor-list rebuild).

  • system (System) – Potentially updated system.

  • pair_ids (jax.Array) – (M, 2) array of (i, j) sphere index pairs.

  • forces (jax.Array) – (M, dim) array of pairwise force vectors, one per pair.

jaxdem.utils.get_sphere_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None, check_contact_rank: bool = False, contact_rank_tol: float | None = None) → tuple[State, System, jax.Array, jax.Array]

Identify rattler spheres by iteratively removing under-coordinated particles.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

  • zc (int, optional) – Minimum contact count. Defaults to dim + 1 — the mechanical stability threshold for a point particle. A sphere with dim or fewer force-bearing contacts is unstable: the tangential softening of each contact at finite overlap can give the sub-hessian a negative eigenvalue, so dim + 1 non-degenerate contacts are needed for a positive-definite local hessian.

  • check_contact_rank (bool, optional) – If True, also remove particles whose active force-bearing contacts have force-direction rank below dim.

  • contact_rank_tol (float, optional) – Absolute tolerance passed to jax.numpy.linalg.matrix_rank for the optional force-direction rank check.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • rattler_ids (jax.Array) – 1-D array of rattler sphere indices.

  • non_rattler_ids (jax.Array) – 1-D array of non-rattler sphere indices.
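Iterative removal matters because deleting one rattler can drop a neighbor below zc. The sketch below shows this cascade on a pair list in the layout of get_pair_forces_and_ids, with both directions listed and a contact counted as force-bearing when its force is nonzero. Only the contact-count criterion is shown; the optional rank check is omitted. sphere_rattlers is a hypothetical name.

```python
import jax.numpy as jnp


def sphere_rattlers(pair_ids, forces, n, zc):
    """Iteratively flag spheres with fewer than zc force-bearing contacts."""
    i, j = pair_ids[:, 0], pair_ids[:, 1]
    bearing = jnp.linalg.norm(forces, axis=-1) > 0
    rattler = jnp.zeros(n, dtype=bool)
    while True:
        # Only contacts between two current non-rattlers count
        active = bearing & ~rattler[i] & ~rattler[j]
        z = jnp.zeros(n, dtype=jnp.int32).at[i].add(active.astype(jnp.int32))
        new = ~rattler & (z < zc)
        if not bool(jnp.any(new)):
            break
        rattler = rattler | new  # removing these may expose new rattlers
    return jnp.nonzero(rattler)[0], jnp.nonzero(~rattler)[0]


# 2D example (zc = dim + 1 = 3): spheres 0-3 form a rigid core, 4 and 5 hang off it
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3),
         (0, 4), (4, 5), (5, 1), (5, 2), (3, 5)]
f = [1.0] * 10 + [0.0]  # the (3, 5) pair carries no force
pair_ids = jnp.array(edges + [(b, a) for a, b in edges])
forces = jnp.array(f + f)[:, None] * jnp.array([[1.0, 0.0]])
rattlers, non_rattlers = sphere_rattlers(pair_ids, forces, n=6, zc=3)
```

Sphere 5 starts with three force-bearing contacts. It only becomes a rattler after sphere 4 is removed in the first pass.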

jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) → State

Create a state where particles sit on a rectangular lattice.

Random values can be sampled for particle radii, masses and velocities by specifying *_range arguments, which are interpreted as (min, max) bounds for a uniform distribution. When a range is not provided, the corresponding radius or mass argument is used for all particles, and the velocity components are sampled in [-1, 1].

Parameters:
  • n_per_axis (tuple[int]) – Number of spheres along each axis.

  • spacing (tuple[float] | float) – Centre-to-centre distance; scalar is broadcast to every axis.

  • radius (float) – Radius shared by all particles when radius_range is not provided.

  • mass (float) – Mass shared by all particles when mass_range is not provided.

  • jitter (float) – Add a uniform random offset in the range [-jitter, +jitter] for non-perfect grids (useful to break symmetry).

  • vel_range (ArrayLike | None) – (min, max) bounds for the velocity components.

  • radius_range (ArrayLike | None) – (min, max) bounds for the particle radii.

  • mass_range (ArrayLike | None) – (min, max) bounds for the particle masses.

  • seed (int) – Integer seed used when key is not supplied.

  • key (PRNG key, optional) – Controls randomness. If None a key will be created from seed.

Return type:

State
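The lattice placement can be sketched with jnp.meshgrid. Scalar spacing is broadcast to every axis, and an optional uniform jitter is added. Starting the lattice at the origin and the name grid_positions are assumptions. Radii, masses and velocities are left out.

```python
import jax
import jax.numpy as jnp


def grid_positions(n_per_axis, spacing, jitter=0.0, key=None):
    # Scalar spacing is broadcast to every axis
    spacing = jnp.broadcast_to(jnp.asarray(spacing, dtype=jnp.float32), (len(n_per_axis),))
    axes = [jnp.arange(n) * s for n, s in zip(n_per_axis, spacing)]
    mesh = jnp.meshgrid(*axes, indexing="ij")
    pos = jnp.stack([m.ravel() for m in mesh], axis=-1)  # (prod(n_per_axis), dim)
    if jitter > 0.0:
        pos = pos + jax.random.uniform(key, pos.shape, minval=-jitter, maxval=jitter)
    return pos


pos = grid_positions((3, 2), spacing=2.0)
pos_j = grid_positions((3, 2), spacing=2.0, jitter=0.1, key=jax.random.PRNGKey(0))
```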

jaxdem.utils.lidar_2d(state: State, system: System, lidar_range: float, n_bins: int, max_neighbors: int, sense_edges: bool = False) → tuple[State, System, jax.Array, jax.Array, jax.Array]

2-D LIDAR proximity readings and neighbor IDs.

For every particle in state the displacement vectors to all other particles are projected onto the \(xy\)-plane and binned by azimuthal angle into n_bins uniform sectors spanning \([-\pi, \pi)\). Each bin stores the proximity value and the index of the closest neighbor in that sector:

\[p_k = \max(0,\; r_{\max} - d_{\min,k})\]

This works identically for 2-D and 3-D position data; in the 3-D case the \(z\)-component of the displacement is simply ignored during binning while the full Euclidean distance is used for proximity.

Parameters:
  • state (State) – Simulation state (positions, radii, etc.).

  • system (System) – System configuration including domain.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_bins (int) – Number of angular bins (rays) spanning \([-\pi, \pi)\).

  • max_neighbors (int) – Unused. Kept for backward compatibility.

  • sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Only meaningful for bounded domains. Default is False.

Returns:

(state, system, proximity, ids, overflow) where state and system are unchanged, proximity and ids have shape (N, n_bins), and overflow is always False. Bins with no detection have ids set to the particle’s own index.

Return type:

Tuple[State, System, jax.Array, jax.Array, jax.Array]

Notes

This function computes all-pairs displacements directly from state.pos and does not invoke the collider. The returned ids are indices into state.pos in whatever order it has at call time, so results are correct regardless of whether a cell-list collider has reordered the state.

Examples

>>> from jaxdem.utils import lidar_2d
>>> # state and system are an existing State / System pair
>>> state, system, prox, ids, overflow = lidar_2d(state, system,
...     lidar_range=5.0, n_bins=36, max_neighbors=64)

jaxdem.utils.lidar_3d(state: State, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int, sense_edges: bool = False) → tuple[State, System, jax.Array, jax.Array, jax.Array]

3-D LIDAR proximity readings and neighbor IDs.

Similar to lidar_2d() but bins neighbors on a spherical grid defined by n_azimuth azimuthal sectors in \([-\pi, \pi)\) and n_elevation elevation bands in \([-\pi/2, \pi/2]\). The returned proximity and ID arrays have shape (N, n_azimuth * n_elevation) with flat indexing az * n_elevation + el.
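The flat spherical-bin index az * n_elevation + el can be illustrated with a small JAX helper. This is a sketch under assumptions: spherical_bin_index is a hypothetical name, elevation is taken as atan2(z, sqrt(x² + y²)), and a direction exactly at +π/2 is clipped into the top band; jaxdem's own edge handling may differ.

```python
import jax.numpy as jnp


def spherical_bin_index(disp, n_azimuth, n_elevation):
    """Hypothetical flat bin index for 3-D displacement vectors, shape (..., 3)."""
    x, y, z = disp[..., 0], disp[..., 1], disp[..., 2]
    az = jnp.arctan2(y, x)                  # azimuth in [-pi, pi]
    el = jnp.arctan2(z, jnp.hypot(x, y))    # elevation in [-pi/2, pi/2]
    az_bin = jnp.floor((az + jnp.pi) / (2 * jnp.pi) * n_azimuth).astype(jnp.int32) % n_azimuth
    el_bin = jnp.floor((el + jnp.pi / 2) / jnp.pi * n_elevation).astype(jnp.int32)
    el_bin = jnp.clip(el_bin, 0, n_elevation - 1)  # el == +pi/2 goes to the top band
    return az_bin * n_elevation + el_bin           # flat index az * n_elevation + el
```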

Parameters:
  • state (State) – Simulation state.

  • system (System) – System configuration including domain.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_azimuth (int) – Number of azimuthal bins.

  • n_elevation (int) – Number of elevation bins.

  • max_neighbors (int) – Unused. Kept for backward compatibility.

  • sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Default is False.

Returns:

(state, system, proximity, ids, overflow) where state and system are unchanged, proximity and ids have shape (N, n_azimuth * n_elevation), and overflow is always False.

Return type:

Tuple[State, System, jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids index into state.pos in its current order, so results are correct regardless of collider-induced reordering.

Examples

>>> from jaxdem.utils import lidar_3d
>>> # state and system are an existing State / System pair
>>> state, system, prox, ids, overflow = lidar_3d(state, system,
...     lidar_range=5.0, n_azimuth=36, n_elevation=18, max_neighbors=64)

jaxdem.utils.load_legacy_dp(path: str, ref_pos: jax.Array | None = None, dim: int = 3) → DeformableParticleModel

Load an old DeformableParticleContainer h5 file and return a new DeformableParticleModel.

Parameters:
  • path (str) – Path to the .h5 file containing the saved DP container.

  • ref_pos (jax.Array, optional) – Reference vertex positions, shape (N, dim). Required for 3D to compute the new w_b bending normalization. You can obtain this from the legacy state: ref_pos = state.pos.

  • dim (int) – Spatial dimension (2 or 3). Needed to choose the correct w_b computation.

Returns:

A new model instance with fields mapped from the old container.

Return type:

DeformableParticleModel

jaxdem.utils.load_legacy_simulation(state_path: str, system_path: str, dp_path: str | None = None) → tuple[State, System]

Load state, system, and (optionally) a deformable-particle container from old-format h5 files and wire them into a ready-to-use (State, System) pair.

When dp_path is given, the DP model is attached to the system via system.bonded_force_model and its force/energy functions are registered in the force manager.

Parameters:
  • state_path (str) – Path to the legacy State h5 file.

  • system_path (str) – Path to the legacy System h5 file.

  • dp_path (str, optional) – Path to the legacy DeformableParticleContainer h5 file.

Returns:

  • state (State) – The loaded state with current field names.

  • system (System) – The loaded system, with bonded forces wired up if dp_path was given.

Example

from jaxdem.utils.load_legacy import load_legacy_simulation

state, system = load_legacy_simulation(
    "old_data/state.h5",
    "old_data/system.h5",
    dp_path="old_data/dp.h5",
)

jaxdem.utils.load_legacy_state(path: str) → State

Load a State saved with the old field naming convention (angVel, clump_ID, deformable_ID, unique_ID).

Parameters:

path (str) – Path to the .h5 file containing the saved State.

Returns:

A new State constructed with the current field names.

Return type:

State

jaxdem.utils.load_legacy_system(path: str, state_shape: tuple[int, ...] | None = None) → System

Load a System saved with the old schema (no bonded_force_model or interact_same_bond_id fields).

The current System.create factory is used to produce a valid skeleton; scalar fields (dt, time, step_count, key) and nested component dataclasses that still exist (collider, domain, force_model, mat_table, force_manager, integrators) are overwritten from the file where the schemas still match.

Parameters:
  • path (str) – Path to the .h5 file containing the saved System.

  • state_shape (tuple of int, optional) – Shape hint (N, dim) passed to System.create to build default components. If None, inferred from the stored force_manager/external_force or force_manager/external_force_com dataset.

Returns:

A new System instance populated with as much data from the file as possible. bonded_force_model defaults to None and interact_same_bond_id defaults to False.

Return type:

System

jaxdem.utils.make_save_steps_linear(*, num_steps: int, save_freq: int, include_step0: bool = True) → ndarray
jaxdem.utils.make_save_steps_pseudolog(*, num_steps: int, reset_save_decade: int, min_save_decade: int, decade: int = 10, include_step0: bool = True, cap: int | None = None) → ndarray

Pseudo-log schedule compatible with the BaseLogGroup logic.

Parameters are interpreted on the integer timestep grid 0..num_steps (inclusive).

jaxdem.utils.norm(v: Array) → Array

Norm (magnitude) of vectors along the last axis.

\[\|v\| = \sqrt{\sum_{i} v_i^2}\]
Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

The norm. Shape (…).

Return type:

jax.Array

jaxdem.utils.norm2(v: Array) → Array

Squared norm of vectors along the last axis.

\[\|v\|^2 = \vec{v} \cdot \vec{v} = \sum_{i} v_i^2\]
Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

The squared norm. Shape (…).

Return type:

jax.Array

jaxdem.utils.quasistatic_compress_to_packing_fraction(state: State, system: System, target_phi: float, *, step: float = 0.001, phi_tolerance: float = 1e-10, pe_tol: float = 1e-16, pe_diff_tol: float = 1e-16, max_n_min_steps_per_outer: int = 1000000, max_n_outer_steps: int = 1000000, progress: bool = False) → tuple[State, System, jax.Array, jax.Array]

Quasi-statically compress (or decompress) toward target_phi.

Alternates scale_to_packing_fraction() with minimize() in steps no larger than step in packing fraction. The final step is truncated so the target is hit exactly (within phi_tolerance). Works in both directions: if target_phi > current_phi the box shrinks and particles are pushed closer; if target_phi < current_phi the box grows and the system relaxes.

The state is minimized once up front, so a non-equilibrium input is safe. Above the jamming point the minimizer may exit with residual potential energy (PE); the final PE is returned so the caller can detect this.
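The packing-fraction schedule described above (increments no larger than step, with the last one truncated to land on target_phi) can be sketched in pure Python. This is a hypothetical helper, not the jaxdem routine: it only produces the sequence of intermediate packing fractions and marks with a comment where the real function would rescale the box and call minimize().

```python
def phi_schedule_sketch(phi_start, target_phi, step, phi_tolerance=1e-10,
                        max_n_outer_steps=1_000_000):
    """Hypothetical: intermediate packing fractions for a quasistatic protocol."""
    phis = []
    phi = phi_start
    for _ in range(max_n_outer_steps):          # hard cap, like max_n_outer_steps
        remaining = target_phi - phi
        if abs(remaining) <= phi_tolerance:
            break
        # Clamp the increment to [-step, step]; the final one is truncated to hit the target.
        phi += max(-step, min(step, remaining))
        # (the real routine would rescale the box to `phi` and minimize here)
        phis.append(phi)
    return phis
```

The same loop handles compression and decompression, because the sign of the remaining difference decides the direction.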

Parameters:
  • state – Current state. Any domain type that scale_to_packing_fraction() supports is allowed.

  • system – Current system. Any domain type that scale_to_packing_fraction() supports is allowed.

  • target_phi – Target packing fraction.

  • step – Maximum magnitude of the per-outer-step increment in phi. Smaller values are more quasistatic (costlier); 1e-3 is a reasonable default for dense compressions.

  • phi_tolerance – Absolute tolerance on the terminal packing fraction.

  • pe_tol – Minimizer convergence tolerances.

  • pe_diff_tol – Minimizer convergence tolerances.

  • max_n_min_steps_per_outer – FIRE iterations allowed per minimization (per outer step).

  • max_n_outer_steps – Hard cap on outer iterations (safety net).

  • progress – If True and tqdm is importable, wraps the outer loop in a progress bar. Otherwise silent.

Returns:

(state, system, final_phi, final_pe), where final_phi is compute_packing_fraction(state, system) at exit and final_pe is the PE after the last minimization.

Return type:

Tuple[State, System, jax.Array, jax.Array]

jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) → State

Generate N particles placed uniformly at random in an axis-aligned box. Overlaps between particles are not checked.

Parameters:
  • N – Number of particles.

  • dim – Spatial dimension (2 or 3).

  • box_size – Edge lengths of the domain.

  • box_anchor – Coordinate of the lower box corner.

  • radius_range – min and max values that the radius can take.

  • mass_range – min and max values that the mass can take.

  • vel_range – min and max values that the velocity components can take.

  • seed – Integer for reproducibility.

Returns:

A fully-initialised State instance.

Return type:

State
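The sampling described for random_state can be sketched directly with jax.random. This is a hypothetical stand-in that returns raw arrays rather than a State and requires every range explicitly (jaxdem's defaults for None are not documented here); as in the description, no overlap check is performed.

```python
import jax
import jax.numpy as jnp


def random_state_sketch(key, N, dim, box_size, box_anchor, radius_range, vel_range):
    """Hypothetical: uniform positions, radii, and velocity components (no overlap check)."""
    k_pos, k_rad, k_vel = jax.random.split(key, 3)
    box_size = jnp.asarray(box_size, dtype=jnp.float32)
    box_anchor = jnp.asarray(box_anchor, dtype=jnp.float32)
    # Each coordinate is anchor + size * U[0, 1); particles may overlap.
    pos = box_anchor + box_size * jax.random.uniform(k_pos, (N, dim))
    rad = jax.random.uniform(k_rad, (N,), minval=radius_range[0], maxval=radius_range[1])
    vel = jax.random.uniform(k_vel, (N, dim), minval=vel_range[0], maxval=vel_range[1])
    return pos, rad, vel
```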

jaxdem.utils.randomize_orientations(state: State, key: jax.Array) → State

Randomize orientations for clumps (particles with repeated state.clump_id), leaving spheres unchanged.
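One way to realize "randomize clumps, leave spheres unchanged" is sketched below for the 3-D case. It is an assumption-laden illustration, not jaxdem's code: randomize_clump_orientations_sketch is hypothetical, it draws a uniformly random unit quaternion per clump by normalizing a 4-D Gaussian sample, assumes clump ids lie in [0, N), and uses the (w, xyz) layout with shapes (N, 1) and (N, 3) from the Quaternion class. A 2-D version would instead draw a single angle about the z-axis.

```python
import jax
import jax.numpy as jnp


def randomize_clump_orientations_sketch(key, clump_id, w, xyz):
    """Hypothetical: one random unit quaternion per clump; spheres keep (w, xyz)."""
    n_ids = clump_id.shape[0]                            # clump ids assumed in [0, N)
    g = jax.random.normal(key, (n_ids, 4))
    q = g / jnp.linalg.norm(g, axis=-1, keepdims=True)   # uniform on the unit 3-sphere
    q = q[clump_id]                                      # same rotation for all members
    counts = jnp.bincount(clump_id, length=n_ids)
    is_clump = (counts[clump_id] > 1)[:, None]           # repeated id -> clump
    new_w = jnp.where(is_clump, q[:, :1], w)
    new_xyz = jnp.where(is_clump, q[:, 1:], xyz)
    return new_w, new_xyz
```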

jaxdem.utils.remove_rattlers(state: State, system: System, rattler_clump_ids: jax.Array) → tuple[State, System]

Remove all spheres belonging to rattler clumps and rebuild a matching system.

The state’s rattler spheres are dropped and its clump_id / bond_id / unique_id arrays are re-indexed. The returned system is a dataclasses.replace() copy of the input, so every field (domain, mat_table, integrators, user hooks, dt, time, and any future additions to System) is preserved by default. Only the state-size-dependent fields are refreshed:

  • collider is rebuilt via its Create() method for stateful colliders (NeighborList, cell lists, sweep-and-prune) and passed through unchanged for stateless ones (naive). Create’s config kwargs are recovered from the current collider via introspection (see _refresh_collider()).

  • force_manager is rebuilt so that its per-particle buffers (external_force, external_force_com, external_torque) are sized for the reduced state. gravity, force_functions, energy_functions, and is_com_force are preserved.
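The drop-and-re-index step for clump_id can be sketched as follows. This assumes re-indexing means compacting the surviving ids to a contiguous 0..M-1 range while preserving their order; drop_and_reindex_sketch is hypothetical and, because boolean-mask indexing and jnp.unique produce data-dependent shapes, it runs eagerly rather than under jit.

```python
import jax.numpy as jnp


def drop_and_reindex_sketch(clump_id, rattler_clump_ids):
    """Hypothetical: mask out rattler spheres and compact the surviving clump ids."""
    keep = ~jnp.isin(clump_id, rattler_clump_ids)
    # return_inverse maps each surviving id to its rank among the unique survivors,
    # giving contiguous ids 0..M-1 while preserving relative order.
    _, new_ids = jnp.unique(clump_id[keep], return_inverse=True)
    return keep, new_ids
```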

Parameters:
  • state (State) – Current simulation state.

  • system (System) – Current system; all fields are carried over into the rebuilt system except the state-size-dependent ones listed above.

  • rattler_clump_ids (jax.Array) – 1-D array of clump IDs to remove.

Returns:

  • state (State) – New state with rattler spheres removed and IDs re-indexed.

  • system (System) – New system with matching state shape.

Notes

DP / bonded force models. When system.bonded_force_model is a DeformableParticleModel, its topology arrays (elements, edges, element_adjacency, …) reference vertices by unique_id. remove_rattlers() re-indexes unique_id in the state but does not remap the bonded-model topology, because the correct behavior is ambiguous (an element that partially straddles removed vertices could be dropped, re-triangulated, or flagged). A warning is emitted when a bonded model is present; users with DP systems should handle the topology remap manually.

Custom collider settings. Any collider Create-kwarg whose name is not a field on the current collider (e.g. number_density and safety_factor on NeighborList) gets Create’s default value, not the value originally used. If you need to preserve such settings, rebuild the system yourself.

jaxdem.utils.run_packing_fraction_protocol(state: State, system: System, *, strides: jax.Array, phi_at_frames: jax.Array, unroll: int = 2) → tuple[State, System, tuple[State, System]]

Run integration interleaved with a scheduled box rescale, saving one frame per rescale event.

For each frame i in range(K):

  1. Advance strides[i] integration steps via System.step().

  2. Rescale the periodic box to phi_at_frames[i] via scale_to_packing_fraction().

  3. Record (state, system) as the frame.

All dynamics (pairwise forces, bonded forces, thermostat integrators, neighbor-list rebuilds, etc.) are delegated to system.step. Temperature control, if desired, is therefore configured at System.create time by choosing linear_integrator_type="verlet_rescaling" (deterministic velocity rescaling) or "langevin" (stochastic). This function runs whatever integrator is on the System and adds the box-rescale schedule on top.
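The three-step frame loop above maps naturally onto jax.lax.scan with a variable-length inner loop. In this sketch a scalar stands in for (state, system), and toy_step and toy_rescale are hypothetical placeholders for System.step() and scale_to_packing_fraction(); only the control flow (integrate strides[i] steps, rescale, record the frame) mirrors the description.

```python
import jax
import jax.numpy as jnp


def toy_step(x):
    # Placeholder for System.step(): a simple damped update.
    return 0.5 * x


def toy_rescale(x, phi):
    # Placeholder for scale_to_packing_fraction(): hypothetical scaling by phi.
    return x * phi


def protocol_sketch(x0, strides, phi_at_frames, unroll=2):
    def frame(x, inputs):
        n_steps, phi = inputs
        # 1. Advance n_steps integration steps (traced bound -> while loop).
        x = jax.lax.fori_loop(0, n_steps, lambda _, y: toy_step(y), x)
        # 2. Rescale to this frame's target.
        x = toy_rescale(x, phi)
        # 3. Record the frame: (new carry, saved output).
        return x, x
    return jax.lax.scan(frame, x0, (strides, phi_at_frames), unroll=unroll)
```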

Parameters:
  • state – Initial state.

  • system – Initial system; its integrators, collider, and force model determine what happens between rescales.

  • strides – 1D integer array of per-frame integration step counts. Length K sets the number of frames.

  • phi_at_frames – 1D float array of target packing fractions, one per frame, applied after the frame’s integration strides. Must have the same length as strides.

  • unroll – Unroll factor for the outer jax.lax.scan() (same semantics as System.trajectory_rollout()).

Returns:

Final state/system and the per-frame trajectory, stacked along leading axis K (the same layout as trajectory_rollout’s default save_fn).

Return type:

(state, system, (traj_state, traj_system))

Notes

jaxdem.utils.scale_to_packing_fraction(state: State, system: System, new_packing_fraction: float) → tuple[State, System]
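Since the packing fraction is total particle volume divided by box volume, and box volume scales as \(L^{d}\), reaching a new packing fraction at fixed particle sizes requires scaling every box length by \((\phi_{\text{old}}/\phi_{\text{new}})^{1/d}\). The sketch below applies that factor to the box and, as an assumption, rescales positions affinely about an origin-anchored box; scale_box_sketch is hypothetical and says nothing about how jaxdem handles other domain types.

```python
import jax.numpy as jnp


def scale_box_sketch(pos, box_size, phi_old, phi_new):
    """Hypothetical affine rescale of an origin-anchored box to a new packing fraction."""
    dim = pos.shape[-1]
    # phi ~ 1 / L**dim at fixed particle volume, so L scales by (phi_old / phi_new)**(1/dim).
    factor = (phi_old / phi_new) ** (1.0 / dim)
    return pos * factor, box_size * factor
```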
jaxdem.utils.scale_to_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) → State

Scale the velocities of a state to a desired temperature.

Parameters:
  • state (State) – Current simulation state.

  • target_temperature (float) – Desired target temperature.

  • can_rotate (bool) – Whether to include rigid-body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • k_B (float, optional) – Boltzmann constant (default is 1.0).
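Velocity rescaling multiplies every velocity by \(\sqrt{T_{\text{target}}/T}\), where the current temperature follows from equipartition, \(\sum_i m_i v_i^2 = n_{\text{dof}}\, k_B T\). The sketch below covers only the translational case (can_rotate=False) and assumes \(n_{\text{dof}} = N d\), reduced to \((N-1)d\) when the center-of-mass drift is removed; rescale_velocities_sketch is hypothetical, and jaxdem's degree-of-freedom counting may differ.

```python
import jax.numpy as jnp


def rescale_velocities_sketch(vel, mass, target_temperature, subtract_drift, k_B=1.0):
    """Hypothetical translational-only rescale (ignores rotations)."""
    N, dim = vel.shape
    if subtract_drift:
        # Remove center-of-mass velocity; this removes `dim` degrees of freedom.
        v_com = jnp.sum(mass[:, None] * vel, axis=0) / jnp.sum(mass)
        vel = vel - v_com
        dof = dim * (N - 1)
    else:
        dof = dim * N
    # Equipartition: sum(m v^2) = dof * k_B * T
    T = jnp.sum(mass[:, None] * vel ** 2) / (dof * k_B)
    return vel * jnp.sqrt(target_temperature / T)
```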

jaxdem.utils.set_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, seed: int | None = 0, k_B: float = 1.0) → State

Randomize the velocities of a state according to a desired temperature.

Parameters:
  • state (State) – Current simulation state.

  • target_temperature (float) – Desired target temperature.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • seed (int, optional) – RNG seed.

  • k_B (float, optional) – Boltzmann constant (default is 1.0).
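Randomizing velocities "according to a temperature" is commonly done by drawing each translational component from the Maxwell-Boltzmann distribution, a Gaussian with standard deviation \(\sqrt{k_B T / m_i}\). The sketch below shows only that draw; maxwell_boltzmann_sketch is hypothetical, takes a JAX key instead of seed, and ignores rotations and drift removal, which set_temperature handles through can_rotate and subtract_drift.

```python
import jax
import jax.numpy as jnp


def maxwell_boltzmann_sketch(key, mass, dim, target_temperature, k_B=1.0):
    """Hypothetical: translational velocities drawn at temperature T."""
    sigma = jnp.sqrt(k_B * target_temperature / mass)[:, None]  # per-particle std dev
    return sigma * jax.random.normal(key, (mass.shape[0], dim))
```

A finite sample only matches the target on average; following the draw with an exact rescale (as scale_to_temperature does) pins the temperature.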

jaxdem.utils.signed_angle(v1: Array, v2: Array) → Array

Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \((-\pi, \pi]\).

Calculates the signed angle between two 2D vectors \(\vec{v}_1\) and \(\vec{v}_2\) using the dot product and the 2D cross product:

\[\begin{split}\hat{v}_1 &= \text{unit}(\vec{v}_1) \\ \hat{v}_2 &= \text{unit}(\vec{v}_2) \\ d &= \hat{v}_1 \cdot \hat{v}_2 \\ s &= \hat{v}_{1,x} \hat{v}_{2,y} - \hat{v}_{1,y} \hat{v}_{2,x} \\ \theta &= \text{atan2}(s, d)\end{split}\]
Parameters:
  • v1 (jnp.ndarray) – First vector. Shape (…, 2).

  • v2 (jnp.ndarray) – Second vector. Shape (…, 2).

Returns:

Signed angle in radians.

Return type:

jnp.ndarray
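The formula above translates line by line into JAX. This is an illustrative sketch (signed_angle_sketch is a hypothetical name) that assumes both inputs are nonzero; since atan2 is unchanged when both arguments are scaled by the same positive factor, the normalization does not affect the result and is kept only to mirror the equations.

```python
import jax.numpy as jnp


def signed_angle_sketch(v1, v2):
    """Hypothetical: signed angle from v1 to v2 about +z, shapes (..., 2)."""
    u1 = v1 / jnp.linalg.norm(v1, axis=-1, keepdims=True)  # assumes v1 != 0
    u2 = v2 / jnp.linalg.norm(v2, axis=-1, keepdims=True)  # assumes v2 != 0
    d = jnp.sum(u1 * u2, axis=-1)                           # cosine term (dot product)
    s = u1[..., 0] * u2[..., 1] - u1[..., 1] * u2[..., 0]   # sine term (2-D cross product)
    return jnp.arctan2(s, d)
```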

jaxdem.utils.signed_angle_x(v1: Array) → Array

Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).

Calculates the signed angle of a 2D vector \(\vec{v}_1\) relative to the positive x-axis \((1, 0)\):

\[\theta = \text{atan2}(-v_{1,y}, v_{1,x})\]
Parameters:

v1 (jnp.ndarray) – The input vector. Shape (…, 2).

Returns:

Signed angle in radians.

Return type:

jnp.ndarray
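Substituting \(\vec{v}_2 = (1, 0)\) into the general signed-angle formula gives \(s = -v_{1,y}\) and \(d = v_{1,x}\), which is exactly the expression above. A one-line JAX sketch (the name signed_angle_x_sketch is hypothetical) makes the sign convention concrete: a vector pointing along +y needs a clockwise rotation of \(-\pi/2\) to reach \(\hat{x}\).

```python
import jax.numpy as jnp


def signed_angle_x_sketch(v1):
    """Hypothetical: signed angle that carries v1 onto +x, shape (..., 2)."""
    return jnp.arctan2(-v1[..., 1], v1[..., 0])
```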

jaxdem.utils.unit(v: Array) → Array

Normalize vectors to unit vectors along the last axis.

If the vector is zero, the result is zero.

\[\begin{split}\hat{v} = \begin{cases} \frac{\vec{v}}{\|v\|} & \text{if } \|v\| > 0 \\ \vec{0} & \text{otherwise} \end{cases}\end{split}\]
Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

Unit vector. Shape (…, D).

Return type:

jax.Array

jaxdem.utils.unit_and_norm(v: Array) → tuple[Array, Array]

Normalize vectors along the last axis and return both unit vectors and norms.

Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

A tuple of (unit vectors, norms).

Return type:

Tuple[jax.Array, jax.Array]
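The zero-vector rule from unit() combined with the norm from norm() can be sketched as a single function. This is a hypothetical illustration, not the jaxdem source: dividing by a "safe" norm (replaced by 1 where the norm is 0) avoids a 0/0 in the forward pass, and the outer jnp.where then returns the zero vector as specified.

```python
import jax.numpy as jnp


def unit_and_norm_sketch(v):
    """Hypothetical zero-safe normalization along the last axis."""
    n = jnp.sqrt(jnp.sum(v * v, axis=-1, keepdims=True))  # ||v||, kept for broadcasting
    safe_n = jnp.where(n > 0, n, 1.0)                     # avoid 0/0 for zero vectors
    u = jnp.where(n > 0, v / safe_n, 0.0)                 # zero vector -> zero unit vector
    return u, n[..., 0]
```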

Modules

angles

Utility functions to compute angles between vectors.

clumps

contacts

Utility functions for analyzing particle contacts and identifying rattlers.

dispersity

Utility functions to assign radius dispersity.

dynamicalMatrix

Functions for calculating the dynamical matrix (hessian of the potential energy w.r.t.

dynamicsRoutines

Protocols that interleave integration with periodic state/system rescaling.

environment

Utility functions to handle environments and LIDAR sensor.

geometricAsperityCreation

Utility functions for creating Geometric Asperity particle states in 2D and 3D.

gridState

Utility functions to initialize states with particles arranged in a grid.

h5

HDF5 save/load utilities (v2).

jamming

Jamming routines.

linalg

Utility functions to help with linear algebra.

load_legacy

Adapter for loading HDF5 data saved with the pre-merge-2-27-26 branch.

meshes

Asperity mesh generators for geometric-asperity (GA) particles.

packingUtils

Utility functions for calculating and changing the packing fraction.

particleCreation

Utility functions for creating states of GA particles (rigid/DP).

quaternion

Quaternion math utilities.

randomSphereConfiguration

Generates random, energy-minimized configurations of spheres in 2D or 3D.

randomState

Utility functions to randomly initialize states.

randomizeOrientations

Utility functions to randomize particle orientations.

rollout_schedules

Utilities to generate step indices for trajectory logging.

serialization

surfaceProperties

Monte-Carlo-style sampling of the surface of a clump particle with a tracer clump.

thermal

Utility functions to compute thermodynamic quantities.