jaxdem.utils

Utility functions used to set up simulations and analyze the output.

class jaxdem.utils.Quaternion(w: Array, xyz: Array)

Bases: object

Quaternion representing particle orientation (body frame to lab frame).

w: Array
xyz: Array
static create(w: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, xyz: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None) → Quaternion
static unit(q: Quaternion) → Quaternion
static conj(q: Quaternion) → Quaternion
static inv(q: Quaternion) → Quaternion
static rotate(q: Quaternion, v: Array) → Array

Rotates a vector v from the body reference frame to the lab reference frame.

static rotate_back(q: Quaternion, v: Array) → Array

Rotates a vector v from the lab reference frame to the body reference frame.
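The body-to-lab and lab-to-body rotations can be sketched in NumPy as below. This is a minimal illustration, not the jaxdem implementation: it assumes a unit quaternion with scalar part w and vector part xyz, the Hamilton convention v_lab = q v_body q*, and the hypothetical helper names quat_rotate and quat_rotate_back.

```python
import numpy as np

def quat_rotate(w: float, xyz: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Body -> lab rotation of v by the unit quaternion (w, xyz)."""
    # Expanded form of q v q*: v + 2w (xyz x v) + 2 xyz x (xyz x v)
    t = 2.0 * np.cross(xyz, v)
    return v + w * t + np.cross(xyz, t)

def quat_rotate_back(w: float, xyz: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Lab -> body rotation: apply the conjugate quaternion (w, -xyz)."""
    return quat_rotate(w, -xyz, v)

# 90-degree rotation about z: the body x-axis maps to the lab y-axis
s = np.sqrt(0.5)
q_w, q_xyz = s, np.array([0.0, 0.0, s])
v_lab = quat_rotate(q_w, q_xyz, np.array([1.0, 0.0, 0.0]))
```

Rotating with the conjugate undoes the rotation for a unit quaternion, which is why rotate_back needs no separate formula.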

jaxdem.utils.angle(v1: Array, v2: Array) → Array

Angle from v1 -> v2 in \([0, \pi]\).

jaxdem.utils.angle_x(v1: Array) → Array

Angle from v1 -> \(\hat{x}\) in \([0, \pi]\).
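A NumPy sketch of one standard way to compute these unsigned angles, the arccos of the normalized dot product, is shown below. The function bodies are illustrative assumptions rather than jaxdem's code, and zero-length vectors are not handled.

```python
import numpy as np

def angle(v1, v2):
    """Unsigned angle between v1 and v2 along the last axis, in [0, pi]."""
    v1, v2 = np.asarray(v1, float), np.asarray(v2, float)
    cos = np.sum(v1 * v2, axis=-1) / (
        np.linalg.norm(v1, axis=-1) * np.linalg.norm(v2, axis=-1)
    )
    # Clipping guards against round-off pushing cos slightly outside [-1, 1]
    return np.arccos(np.clip(cos, -1.0, 1.0))

def angle_x(v1):
    """Angle between v1 and the x unit vector of matching dimension."""
    v1 = np.asarray(v1, float)
    return angle(v1, np.eye(v1.shape[-1])[0])
```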

jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) → State
jaxdem.utils.compute_energy(state: State, system: System) → jax.Array

Compute the total mechanical energy of the system.

\[E_{total} = E_{pot, total} + E_{trans, total} + E_{rot, total}\]
Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing physics parameters and colliders.

Returns:

The total energy (scalar) of the system.

Return type:

jax.Array

jaxdem.utils.compute_packing_fraction(state: State, system: System) → jax.Array
jaxdem.utils.compute_particle_volume(state: State) → jax.Array

Return the total particle volume.

jaxdem.utils.compute_potential_energy(state: State, system: System) → jax.Array

Compute the total potential energy of the system. The energy is computed from the collider's force models, plus gravity and any force-manager force functions that have an associated potential energy.

\[E_{pot, total} = \sum_{i} U(r_i)\]
Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing the collider.

Returns:

The scalar sum of potential energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_rotational_kinetic_energy(state: State) → jax.Array

Compute the total rotational kinetic energy of the system.

\[E_{rot, total} = \sum_{i} \frac{1}{2} \vec{\omega}_i^T I_i \vec{\omega}_i\]
Parameters:

state (State) – The current state of the system.

Returns:

The scalar sum of rotational kinetic energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_rotational_kinetic_energy_per_particle(state: State) → jax.Array

Compute the rotational kinetic energy per particle.

\[E_{rot} = \frac{1}{2} \vec{\omega}^T I \vec{\omega}\]

Notes

  • The energy of clump members is divided by the number of spheres in the clump, so summing over all spheres counts each clump's energy once.

Parameters:

state (State) – The current state of the system containing inertia, orientation, and angular velocity.

Returns:

An array containing the rotational kinetic energy for each particle.

Return type:

jax.Array

jaxdem.utils.compute_temperature(state: State, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) → float

Compute the temperature for a state.

Parameters:
  • state (State) – Current simulation state.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • k_B (float, optional) – Boltzmann constant (default is 1.0).
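For orientation, a translational-only temperature estimate can be sketched from equipartition, k_B T = 2 E_kin / N_dof. The degree-of-freedom bookkeeping below (N·dim, minus dim when drift is removed) and the omission of can_rotate are assumptions; the hypothetical temperature_sketch is not jaxdem's compute_temperature.

```python
import numpy as np

def temperature_sketch(mass, vel, k_B=1.0, subtract_drift=True):
    """Equipartition temperature from translational motion only (rotations ignored)."""
    mass = np.asarray(mass, float)        # (N,)
    vel = np.asarray(vel, float)          # (N, dim)
    n, dim = vel.shape
    if subtract_drift:
        # Remove the mass-weighted center-of-mass velocity
        vel = vel - (mass[:, None] * vel).sum(axis=0) / mass.sum()
    ke = 0.5 * np.sum(mass[:, None] * vel**2)
    # Removing the drift fixes dim momentum components, costing dim degrees of freedom
    n_dof = n * dim - (dim if subtract_drift else 0)
    return 2.0 * ke / (n_dof * k_B)
```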

jaxdem.utils.compute_translational_kinetic_energy(state: State) → jax.Array

Compute the total translational kinetic energy of the system.

\[E_{trans, total} = \sum_{i} \frac{1}{2} m_i |v_i|^2\]
Parameters:

state (State) – The current state of the system.

Returns:

The scalar sum of translational kinetic energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_translational_kinetic_energy_per_particle(state: State) → jax.Array

Compute the translational kinetic energy per particle.

\[E_{trans} = \frac{1}{2} m |v|^2\]

Notes

  • The energy of clump members is divided by the number of spheres in the clump, so summing over all spheres counts each clump's energy once.

Parameters:

state (State) – The current state of the system containing particle masses and velocities.

Returns:

An array containing the translational kinetic energy for each particle.

Return type:

jax.Array
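The clump normalization described in the Notes can be illustrated with the NumPy sketch below. It assumes, hypothetically, that every clump member stores the clump's total mass and center-of-mass velocity and that clump IDs are contiguous integers starting at 0; under that assumption, dividing by the member count makes the per-sphere values sum to one energy per clump.

```python
import numpy as np

def trans_ke_per_particle(mass, vel, clump_id):
    """Per-sphere translational KE, shared equally among clump members."""
    mass = np.asarray(mass, float)
    vel = np.asarray(vel, float)
    clump_id = np.asarray(clump_id)
    ke = 0.5 * mass * np.sum(vel**2, axis=-1)     # (1/2) m |v|^2 per sphere
    counts = np.bincount(clump_id)[clump_id]      # members in each sphere's clump
    return ke / counts
```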

jaxdem.utils.control_nvt_density(state: State, system: System, *, n: int, rescale_every: int, temperature_target: float | None = None, temperature_delta: float | None = None, packing_fraction_target: float | None = None, packing_fraction_delta: float | None = None, can_rotate: bool = True, subtract_drift: bool = True, k_B: float = 1.0, temperature_schedule: ScheduleFn | None = None, density_schedule: ScheduleFn | None = None, pf_min: float = 1e-12, init_temp_seed: int = 0, unroll: int = 2) → tuple[State, System]

Runs a protocol for n integration steps, applying (optional) NVT rescaling and/or density rescaling whenever system.step_count is divisible by rescale_every.

Notes

  • rescale_every is in integration steps (System.step_count units).

  • Provide either a target or a delta for each controlled quantity (temperature, packing fraction), or neither to leave that quantity uncontrolled.

  • temperature_schedule / density_schedule must be JIT-static (passed as static_argnames).
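A single rescale event can be sketched as follows. The hypothetical rescale_event helper assumes velocity rescaling by sqrt(T_target / T), an affine rescale of positions and box about the origin with radii held fixed, and no schedule or delta handling; it illustrates the arithmetic, not the control_nvt_density() implementation.

```python
import numpy as np

def rescale_event(step_count, rescale_every, vel, T_current, T_target,
                  pos, box_size, pf_current, pf_target):
    """Apply one velocity and density rescale if this step is a rescale step."""
    if step_count % rescale_every != 0:
        return vel, pos, box_size
    # Kinetic temperature is quadratic in velocity, so scale by sqrt(T_target / T)
    vel = vel * np.sqrt(T_target / T_current)
    # Packing fraction ~ 1 / L**dim at fixed particle volume, so
    # lengths scale by (pf_current / pf_target) ** (1 / dim)
    dim = pos.shape[-1]
    s = (pf_current / pf_target) ** (1.0 / dim)
    return vel, pos * s, box_size * s
```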

jaxdem.utils.control_nvt_density_rollout(state: State, system: System, *, n: int, stride: int = 1, rescale_every: int = 1, temperature_target: float | None = None, temperature_delta: float | None = None, packing_fraction_target: float | None = None, packing_fraction_delta: float | None = None, can_rotate: bool = True, subtract_drift: bool = True, k_B: float = 1.0, temperature_schedule: ScheduleFn | None = None, density_schedule: ScheduleFn | None = None, pf_min: float = 1e-12, init_temp_seed: int = 0, unroll: int = 2) → tuple[State, System, tuple[State, System]]

Rollout variant of control_nvt_density() (analogous to System.trajectory_rollout), with schedules that remain globally consistent across the whole rollout.

jaxdem.utils.count_clump_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) → tuple[State, System, jax.Array]

Count unique clump-level contacts per clump.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • contacts (jax.Array) – (N_clumps,) array of unique clump contact counts per clump.

jaxdem.utils.count_vertex_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) → tuple[State, System, jax.Array]

Count vertex-level contacts per clump.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • contacts (jax.Array) – (N_clumps,) array of vertex contact counts per clump.
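Given a list of contacting sphere pairs, the vertex-level and clump-level counts can be sketched in NumPy as below. This assumes each contact appears once in pair_ids, that contacts between spheres of the same clump are ignored, and that clump IDs run from 0 to n_clumps - 1; clump_contact_counts is a hypothetical name.

```python
import numpy as np

def clump_contact_counts(pair_ids, clump_id, n_clumps):
    """Return (vertex_counts, unique_clump_counts), each of shape (n_clumps,)."""
    pair_ids, clump_id = np.asarray(pair_ids), np.asarray(clump_id)
    ci, cj = clump_id[pair_ids[:, 0]], clump_id[pair_ids[:, 1]]
    inter = ci != cj                      # drop contacts inside a single clump
    ci, cj = ci[inter], cj[inter]
    # Vertex level: each sphere-sphere contact counts once for both clumps
    vertex = np.bincount(np.concatenate([ci, cj]), minlength=n_clumps)
    # Clump level: collapse repeated contacts between the same two clumps
    clump_pairs = np.unique(np.sort(np.stack([ci, cj], axis=1), axis=1), axis=0)
    unique = np.bincount(clump_pairs.ravel(), minlength=n_clumps)
    return vertex, unique
```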

jaxdem.utils.cross(a: Array, b: Array) → Array

Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.

For 3D vectors (D=3), the result is a vector orthogonal to both ‘a’ and ‘b’. For 2D vectors (D=2), the result is the z-component of the 3D cross product obtained by appending a zero third component to each vector; it equals the signed area of the parallelogram spanned by the vectors.

Parameters:
  • a (jax.Array) – Array of shape (..., D), where D is the dimension (2 or 3).

  • b (jax.Array) – Array of shape (..., D), where D must match the dimension of a.

Returns:

A JAX Array holding the cross product, with shape (…, 3) if D=3, or shape (…, 1) if D=2 (the scalar z-component wrapped in an array).

Raises:

ValueError – If the last dimension (D) is not 2 or 3, or if the last dimensions of ‘a’ and ‘b’ do not match.
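A NumPy sketch of the documented shape behavior follows. It computes the 2D case explicitly instead of calling np.cross, whose support for 2-element vectors is deprecated in NumPy 2.0; the function is illustrative, not jaxdem's implementation.

```python
import numpy as np

def cross(a, b):
    """Cross product along the last axis; 2D inputs give a (..., 1) result."""
    a, b = np.asarray(a, float), np.asarray(b, float)
    if a.shape[-1] != b.shape[-1] or a.shape[-1] not in (2, 3):
        raise ValueError("last dimension must be 2 or 3 and match between a and b")
    if a.shape[-1] == 3:
        return np.cross(a, b)
    # z-component of (a_x, a_y, 0) x (b_x, b_y, 0), kept in a trailing axis
    return (a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0])[..., None]
```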

jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) → Array

Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to find the tangential velocity v = w x r.

This function handles two scenarios based on the dimension of ‘r’:

  1. 3D case (r.shape[-1] == 3): w must be a 3D vector (w.shape[-1] == 3), and the standard 3D cross product w x r is computed.

  2. 2D case (r.shape[-1] == 2): w is treated as a scalar (the z-component of angular velocity, w_z). The computation is equivalent to (0, 0, w_z) x (r_x, r_y, 0), and the result is the 2D tangential velocity vector (v_x, v_y) in the xy-plane.

Parameters:
  • w (jax.Array) – Angular velocity, of shape (..., 3) in the 3D case and (..., 1) or (...) in the 2D case.

  • r (jax.Array) – Position vector, of shape (..., 3) or (..., 2).

Returns:

A JAX Array holding the tangential velocity w x r, with shape (…, 3) if r is 3D and (…, 2) if r is 2D.

Raises:

ValueError – If r is not 2D or 3D, or if the dimensions of w and r are incompatible.
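The 2D case reduces to v = w_z (-r_y, r_x). A NumPy sketch of both cases is below; the name cross_w_r is hypothetical, and the rule for distinguishing a (..., 1) angular velocity from a (...) one is an assumption.

```python
import numpy as np

def cross_w_r(w, r):
    """Tangential velocity w x r for 3D vectors or 2D positions with scalar w_z."""
    w, r = np.asarray(w, float), np.asarray(r, float)
    if r.shape[-1] == 3:
        return np.cross(w, r)
    if r.shape[-1] == 2:
        # Treat a trailing size-1 axis on w as the w_z component
        wz = w[..., 0] if (w.ndim == r.ndim and w.shape[-1] == 1) else w
        # (0, 0, w_z) x (r_x, r_y, 0) = (-w_z r_y, w_z r_x, 0)
        return np.stack([-wz * r[..., 1], wz * r[..., 0]], axis=-1)
    raise ValueError("r must be 2D or 3D")
```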

jaxdem.utils.cross_lidar_2d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_bins: int, max_neighbors: int) → tuple[jax.Array, jax.Array, jax.Array]

2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

Computes all-pairs displacements from pos_a to pos_b, bins by azimuthal angle, and returns per-bin proximity and closest target IDs.

Parameters:
  • pos_a (jax.Array) – Sensor positions, shape (N_A, dim).

  • pos_b (jax.Array) – Target positions, shape (N_B, dim).

  • system (System) – System configuration.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_bins (int) – Number of angular bins spanning \([-\pi, \pi)\).

  • max_neighbors (int) – Unused. Kept for backward compatibility.

Returns:

(proximity, ids, overflow) where proximity and ids have shape (N_A, n_bins) and overflow is always False. Empty bins get ids = -1.

Return type:

Tuple[jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.

Examples

>>> prox, ids, overflow = cross_lidar_2d(agents, obstacles, system,
...                                      lidar_range=5.0, n_bins=36,
...                                      max_neighbors=64)
jaxdem.utils.cross_lidar_3d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int) → tuple[jax.Array, jax.Array, jax.Array]

3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

Computes all-pairs displacements from pos_a to pos_b, bins on a spherical grid, and returns per-bin proximity and closest target IDs.

Parameters:
  • pos_a (jax.Array) – Sensor positions, shape (N_A, 3).

  • pos_b (jax.Array) – Target positions, shape (N_B, 3).

  • system (System) – System configuration.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_azimuth (int) – Number of azimuthal bins.

  • n_elevation (int) – Number of elevation bins.

  • max_neighbors (int) – Unused. Kept for backward compatibility.

Returns:

(proximity, ids, overflow) where proximity and ids have shape (N_A, n_azimuth * n_elevation) and overflow is always False. Empty bins get ids = -1.

Return type:

Tuple[jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.

Examples

>>> prox, ids, overflow = cross_lidar_3d(agents, obstacles, system,
...                                      lidar_range=5.0, n_azimuth=36,
...                                      n_elevation=18, max_neighbors=64)
jaxdem.utils.decode_callable(path: str) → Callable[[...], Any]

Import a callable from a dotted path string.

jaxdem.utils.dot(a: Array, b: Array) → Array

Dot product of vectors along the last axis.

a, b: (…, D) returns: (…), the dot product.

jaxdem.utils.encode_callable(fn: Callable[[...], Any]) → str

Return the dotted import path of fn, e.g. ‘jax._src.nn.functions.gelu’.
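One way to implement this pair of functions with the standard library is sketched below. It assumes paths are built from fn.__module__ and fn.__qualname__, which may not match jaxdem's exact encoding, and it resolves the longest importable module prefix before walking the remaining attributes.

```python
import importlib
import json
import math
from typing import Any, Callable

def encode_callable(fn: Callable[..., Any]) -> str:
    """Dotted path built from the defining module and qualified name."""
    return f"{fn.__module__}.{fn.__qualname__}"

def decode_callable(path: str) -> Callable[..., Any]:
    """Import the longest module prefix of path, then walk the remaining attributes."""
    parts = path.split(".")
    for i in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue  # this prefix is not a module; try a shorter one
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot resolve {path!r}")

# Round trip through the string form
path = encode_callable(json.dumps)
fn = decode_callable(path)
```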

jaxdem.utils.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, **kw: Any) → tuple[Environment, jax.Array]

Advance the environment n steps using actions from model.

Parameters:
  • env (Environment) – Initial environment pytree (batchable).

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.

  • n (int) – Number of steps to perform.

  • **kw (Any) – Extra keyword arguments forwarded to model.

Returns:

Updated environment and the advanced random key.

Return type:

Tuple[Environment, jax.Array]

Examples

>>> env, key = env_step(env, model, key, n=10, objective=goal)
jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, **kw: Any) → tuple[Environment, jax.Array, Environment]

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

Parameters:
  • env (Environment) – Initial environment pytree.

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.

  • n (int) – Number of chunks to roll out. Total internal steps = n * stride.

  • stride (int) – Steps per chunk between recorded snapshots.

  • **kw (Any) – Extra keyword arguments passed to model on every step.

Returns:

Final environment, advanced random key, and a stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.

Return type:

Tuple[Environment, jax.Array, Environment]

Examples

>>> env, key, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
jaxdem.utils.get_clump_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None) → tuple[State, System, jax.Array, jax.Array]

Identify rattler clumps by iteratively removing under-coordinated clumps.

A clump is a rattler if its total vertex-contact count is below the coordination threshold zc.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

  • zc (int, optional) – Minimum contact count. Defaults to dim + angular_dof + 1.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • rattler_ids (jax.Array) – 1-D array of rattler clump IDs.

  • non_rattler_ids (jax.Array) – 1-D array of non-rattler clump IDs.

jaxdem.utils.get_pair_forces_and_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) → tuple[State, System, jax.Array, jax.Array]

Compute pairwise contact forces and their associated particle IDs.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition containing the collider and force model.

  • cutoff (float, optional) – Neighbor search cutoff distance. Defaults to 3 * max(rad).

  • max_neighbors (int, optional) – Maximum number of neighbors per particle (default 100).

Returns:

  • state (State) – Potentially updated state (after neighbor-list rebuild).

  • system (System) – Potentially updated system.

  • pair_ids (jax.Array) – (M, 2) array of (i, j) sphere index pairs.

  • forces (jax.Array) – (M, dim) array of pairwise force vectors, one per pair.
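Per-particle net contact forces can be accumulated from the returned pairs as sketched below. The sign convention is an assumption: forces[k] is taken to act on pair_ids[k, 0], with the reaction -forces[k] on pair_ids[k, 1] and each pair listed once. Check it against your force model before relying on it; net_forces is a hypothetical helper.

```python
import numpy as np

def net_forces(pair_ids, forces, n_particles):
    """Sum pairwise forces into per-particle totals using Newton's third law."""
    net = np.zeros((n_particles, forces.shape[1]))
    # np.add.at accumulates correctly when an index appears in several pairs
    np.add.at(net, pair_ids[:, 0], forces)
    np.add.at(net, pair_ids[:, 1], -forces)
    return net
```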

jaxdem.utils.get_sphere_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None) → tuple[State, System, jax.Array, jax.Array]

Identify rattler spheres by iteratively removing under-coordinated particles.

Parameters:
  • state (State) – Current simulation state.

  • system (System) – System definition.

  • cutoff (float, optional) – Neighbor search cutoff distance.

  • max_neighbors (int, optional) – Maximum number of neighbors per particle.

  • zc (int, optional) – Minimum contact count. Defaults to dim + 1.

Returns:

  • state (State) – Potentially updated state.

  • system (System) – Potentially updated system.

  • rattler_ids (jax.Array) – 1-D array of rattler sphere indices.

  • non_rattler_ids (jax.Array) – 1-D array of non-rattler sphere indices.
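The iterative pruning can be sketched in NumPy as follows, assuming the contacting pairs are already known and listed once each; prune_rattlers is a hypothetical name, and zc is passed explicitly rather than defaulting to dim + 1. Removing one rattler can drop a neighbor below zc, which is why the loop repeats until nothing changes.

```python
import numpy as np

def prune_rattlers(pair_ids, n, zc):
    """Return (rattler_ids, non_rattler_ids) by repeatedly removing low-z particles."""
    pair_ids = np.asarray(pair_ids)
    alive = np.ones(n, dtype=bool)
    while True:
        # Only contacts between surviving particles count toward coordination
        keep = alive[pair_ids[:, 0]] & alive[pair_ids[:, 1]]
        z = np.bincount(pair_ids[keep].ravel(), minlength=n)
        newly = alive & (z < zc)
        if not newly.any():
            break
        alive &= ~newly
    return np.flatnonzero(~alive), np.flatnonzero(alive)
```

In the test pairs, particle 4 is removed first, which leaves particle 3 under-coordinated on the next pass.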

jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) → State

Create a state where particles sit on a rectangular lattice.

Random values can be sampled for particle radii, masses and velocities by specifying *_range arguments, which are interpreted as (min, max) bounds for a uniform distribution. When a range is not provided, the corresponding radius or mass argument is used for all particles and the velocity components are sampled in [-1, 1].

Parameters:
  • n_per_axis (tuple[int]) – Number of spheres along each axis.

  • spacing (tuple[float] | float) – Centre-to-centre distance; scalar is broadcast to every axis.

  • radius (float) – Radius shared by all particles when radius_range is not provided.

  • mass (float) – Mass shared by all particles when mass_range is not provided.

  • jitter (float) – Add a uniform random offset in the range [-jitter, +jitter] for non-perfect grids (useful to break symmetry).

  • vel_range (ArrayLike | None) – (min, max) bounds for the velocity components.

  • radius_range (ArrayLike | None) – (min, max) bounds for the particle radii.

  • mass_range (ArrayLike | None) – (min, max) bounds for the particle masses.

  • seed (int) – Integer seed used when key is not supplied.

  • key (PRNG key, optional) – Controls randomness. If None, a key is created from seed.

Return type:

State
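The lattice positions can be sketched as follows. The hypothetical grid_positions helper assumes the lattice starts at the origin, orders particles with the first axis varying slowest, and applies jitter independently to every coordinate; radii, masses and velocities are omitted.

```python
import numpy as np

def grid_positions(n_per_axis, spacing, jitter=0.0, seed=0):
    """Rectangular lattice of centers, optionally jittered in [-jitter, +jitter]."""
    n_per_axis = tuple(n_per_axis)
    dim = len(n_per_axis)
    # A scalar spacing is broadcast to every axis
    spacing = np.broadcast_to(np.asarray(spacing, float), (dim,))
    axes = [np.arange(n) * s for n, s in zip(n_per_axis, spacing)]
    # indexing="ij" makes the first axis vary slowest after reshaping
    pos = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, dim)
    rng = np.random.default_rng(seed)
    return pos + rng.uniform(-jitter, jitter, size=pos.shape)
```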

jaxdem.utils.lidar_2d(state: State, system: System, lidar_range: float, n_bins: int, max_neighbors: int, sense_edges: bool = False) → tuple[State, System, jax.Array, jax.Array, jax.Array]

2-D LIDAR proximity readings and neighbor IDs.

For every particle in state the displacement vectors to all other particles are projected onto the \(xy\)-plane and binned by azimuthal angle into n_bins uniform sectors spanning \([-\pi, \pi)\). Each bin stores the proximity value and the index of the closest neighbor in that sector:

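\[p_k = \max(0,\; r_{\max} - d_{\min,k})\]

where \(r_{\max}\) is lidar_range and \(d_{\min,k}\) is the distance to the closest neighbor in sector \(k\).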

This works identically for 2-D and 3-D position data; in the 3-D case the \(z\)-component of the displacement is simply ignored during binning while the full Euclidean distance is used for proximity.

Parameters:
  • state (State) – Simulation state (positions, radii, etc.).

  • system (System) – System configuration including domain.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_bins (int) – Number of angular bins (rays) spanning \([-\pi, \pi)\).

  • max_neighbors (int) – Unused. Kept for backward compatibility.

  • sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Only meaningful for bounded domains. Default is False.

Returns:

(state, system, proximity, ids, overflow) where state and system are unchanged, proximity and ids have shape (N, n_bins), and overflow is always False. Bins with no detection have ids set to the particle’s own index.

Return type:

Tuple[State, System, jax.Array, jax.Array, jax.Array]

Notes

This function computes all-pairs displacements directly from state.pos and does not invoke the collider. The returned ids are indices into state.pos in whatever order it has at call time, so results are correct regardless of whether a cell-list collider has reordered the state.

Examples

>>> state, system, prox, ids, overflow = lidar_2d(state, system,
...     lidar_range=5.0, n_bins=36, max_neighbors=64)
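The binning and proximity rule above can be sketched with an explicit all-pairs loop in NumPy. The hypothetical lidar_2d_sketch assumes an unbounded domain (no periodic images, no sense_edges walls) and center-to-center distances; it only reproduces the documented bin layout, proximity formula, and own-index convention for empty bins.

```python
import numpy as np

def lidar_2d_sketch(pos, lidar_range, n_bins):
    """All-pairs 2-D LIDAR: per-sector proximity and closest-neighbor index."""
    pos = np.asarray(pos, float)
    n = len(pos)
    disp = pos[None, :, :] - pos[:, None, :]            # disp[i, j] points from i to j
    dist = np.linalg.norm(disp, axis=-1)                # full Euclidean distance
    ang = np.arctan2(disp[..., 1], disp[..., 0])        # azimuth uses x and y only
    # Map [-pi, pi) onto n_bins uniform sectors; the modulo folds +pi into bin 0
    bins = np.floor((ang + np.pi) / (2 * np.pi) * n_bins).astype(int) % n_bins
    prox = np.zeros((n, n_bins))
    ids = np.repeat(np.arange(n)[:, None], n_bins, axis=1)  # empty bins keep own index
    for i in range(n):
        closest = np.full(n_bins, np.inf)
        for j in range(n):
            k = bins[i, j]
            if j != i and dist[i, j] < min(closest[k], lidar_range):
                closest[k] = dist[i, j]
                ids[i, k] = j
        prox[i] = np.maximum(0.0, lidar_range - closest)   # p_k = max(0, r_max - d_min,k)
    return prox, ids
```

A JAX version would replace the Python loops with a segment-min over bins, but the loop form makes the per-sector rule explicit.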
jaxdem.utils.lidar_3d(state: State, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int, sense_edges: bool = False) → tuple[State, System, jax.Array, jax.Array, jax.Array]

3-D LIDAR proximity readings and neighbor IDs.

Similar to lidar_2d() but bins neighbors on a spherical grid defined by n_azimuth azimuthal sectors in \([-\pi, \pi)\) and n_elevation elevation bands in \([-\pi/2, \pi/2]\). The returned proximity and ID arrays have shape (N, n_azimuth * n_elevation) with flat indexing az * n_elevation + el.

Parameters:
  • state (State) – Simulation state.

  • system (System) – System configuration including domain.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_azimuth (int) – Number of azimuthal bins.

  • n_elevation (int) – Number of elevation bins.

  • max_neighbors (int) – Unused. Kept for backward compatibility.

  • sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Default is False.

Returns:

(state, system, proximity, ids, overflow) where state and system are unchanged, proximity and ids have shape (N, n_azimuth * n_elevation), and overflow is always False.

Return type:

Tuple[State, System, jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids index into state.pos in its current order, so results are correct regardless of collider-induced reordering.

Examples

>>> state, system, prox, ids, overflow = lidar_3d(state, system,
...     lidar_range=5.0, n_azimuth=36, n_elevation=18, max_neighbors=64)
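The documented flat index az * n_elevation + el can be computed from a displacement vector as sketched below. The azimuth and elevation binning rules (uniform bins, with the north pole clipped into the last band) are assumptions, and flat_bin_3d is a hypothetical name.

```python
import numpy as np

def flat_bin_3d(disp, n_azimuth, n_elevation):
    """Flat spherical-grid bin index for displacement vectors of shape (..., 3)."""
    disp = np.asarray(disp, float)
    x, y, z = disp[..., 0], disp[..., 1], disp[..., 2]
    az = np.arctan2(y, x)                                   # azimuth in [-pi, pi]
    el = np.arctan2(z, np.hypot(x, y))                      # elevation in [-pi/2, pi/2]
    ia = np.floor((az + np.pi) / (2 * np.pi) * n_azimuth).astype(int) % n_azimuth
    ie = np.clip(np.floor((el + np.pi / 2) / np.pi * n_elevation).astype(int),
                 0, n_elevation - 1)
    return ia * n_elevation + ie                            # az * n_elevation + el
```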
jaxdem.utils.load_legacy_dp(path: str, ref_pos: jax.Array | None = None, dim: int = 3) → DeformableParticleModel

Load an old DeformableParticleContainer h5 file and return a new DeformableParticleModel.

Parameters:
  • path (str) – Path to the .h5 file containing the saved DP container.

  • ref_pos (jax.Array, optional) – Reference vertex positions, shape (N, dim). Required for 3D to compute the new w_b bending normalization. You can obtain this from the legacy state: ref_pos = state.pos.

  • dim (int) – Spatial dimension (2 or 3). Needed to choose the correct w_b computation.

Returns:

A new model instance with fields mapped from the old container.

Return type:

DeformableParticleModel

jaxdem.utils.load_legacy_simulation(state_path: str, system_path: str, dp_path: str | None = None) → tuple[State, System]

Load state, system, and (optionally) a deformable-particle container from old-format h5 files and wire them into a ready-to-use (State, System) pair.

When dp_path is given, the DP model is attached to the system via system.bonded_force_model and its force/energy functions are registered in the force manager.

Parameters:
  • state_path (str) – Path to the legacy State h5 file.

  • system_path (str) – Path to the legacy System h5 file.

  • dp_path (str, optional) – Path to the legacy DeformableParticleContainer h5 file.

Returns:

  • state (State) – The loaded state with current field names.

  • system (System) – The loaded system, with bonded forces wired up if dp_path was given.

Example

from jaxdem.utils.load_legacy import load_legacy_simulation

state, system = load_legacy_simulation(
    "old_data/state.h5",
    "old_data/system.h5",
    dp_path="old_data/dp.h5",
)
jaxdem.utils.load_legacy_state(path: str) → State

Load a State saved with the old field naming convention (angVel, clump_ID, deformable_ID, unique_ID).

Parameters:

path (str) – Path to the .h5 file containing the saved State.

Returns:

A new State constructed with the current field names.

Return type:

State

jaxdem.utils.load_legacy_system(path: str, state_shape: tuple[int, ...] | None = None) → System

Load a System saved with the old schema (no bonded_force_model or interact_same_bond_id fields).

The current System.create factory is used to produce a valid skeleton; scalar fields (dt, time, step_count, key) and nested component dataclasses that still exist (collider, domain, force_model, mat_table, force_manager, integrators) are overwritten from the file where the schemas still match.

Parameters:
  • path (str) – Path to the .h5 file containing the saved System.

  • state_shape (tuple of int, optional) – Shape hint (N, dim) passed to System.create to build default components. If None, inferred from the stored force_manager/external_force or force_manager/external_force_com dataset.

Returns:

A new System instance populated with as much data from the file as possible. bonded_force_model defaults to None and interact_same_bond_id defaults to False.

Return type:

System

jaxdem.utils.make_save_steps_linear(*, num_steps: int, save_freq: int, include_step0: bool = True) → ndarray
jaxdem.utils.make_save_steps_pseudolog(*, num_steps: int, reset_save_decade: int, min_save_decade: int, decade: int = 10, include_step0: bool = True, cap: int | None = None) → ndarray

Pseudo-log schedule compatible with the BaseLogGroup logic.

Parameters are interpreted on the integer timestep grid 0..num_steps (inclusive).

jaxdem.utils.norm(v: Array) → Array

Norm of vectors along the last axis.

v: (…, D) returns: (…), the norm.

jaxdem.utils.norm2(v: Array) → Array

Squared norm of vectors along the last axis.

v: (…, D) returns: (…), the squared norm.

jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) → State

Generate N particles placed uniformly at random in an axis-aligned box, without checking for overlaps.

Parameters:
  • N – Number of particles.

  • dim – Spatial dimension (2 or 3).

  • box_size – Edge lengths of the domain.

  • box_anchor – Coordinate of the lower box corner.

  • radius_range – min and max values that the radius can take.

  • mass_range – min and max values that the mass can take.

  • vel_range – min and max values that the velocity components can take.

  • seed – Integer for reproducibility.

Returns:

A fully-initialised State instance.

Return type:

State
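The placement step can be sketched as below. The hypothetical random_positions helper assumes a unit box anchored at the origin when box_size or box_anchor is omitted, which may differ from jaxdem's defaults; radii, masses and velocities are left out.

```python
import numpy as np

def random_positions(N, dim, box_size=None, box_anchor=None, seed=0):
    """Uniform positions in [anchor, anchor + size) with no overlap check."""
    rng = np.random.default_rng(seed)
    size = (np.ones(dim) if box_size is None
            else np.broadcast_to(np.asarray(box_size, float), (dim,)))
    anchor = (np.zeros(dim) if box_anchor is None
              else np.broadcast_to(np.asarray(box_anchor, float), (dim,)))
    # rng.random draws from [0, 1); scale and shift into the box
    return anchor + rng.random((N, dim)) * size
```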

jaxdem.utils.randomize_orientations(state: State, key: jax.Array) → State

Randomize orientations for clumps (particles with repeated state.clump_id), leaving spheres unchanged.

jaxdem.utils.remove_rattlers_from_state(state: State, rattler_clump_ids: jax.Array) → State

Remove all spheres belonging to rattler clumps and rebuild the state.

Parameters:
  • state (State) – Current simulation state.

  • rattler_clump_ids (jax.Array) – 1-D array of clump IDs to remove.

Returns:

A new state with rattler spheres removed and IDs re-indexed.

Return type:

State
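The re-indexing step can be sketched with np.unique. The hypothetical drop_and_reindex helper assumes IDs are re-indexed to contiguous integers 0..K-1 in the order of the surviving original IDs; it returns the keep mask so other per-sphere fields can be filtered with it.

```python
import numpy as np

def drop_and_reindex(clump_id, rattler_clump_ids):
    """Mask out rattler spheres and map surviving clump IDs onto 0..K-1."""
    clump_id = np.asarray(clump_id)
    keep = ~np.isin(clump_id, rattler_clump_ids)        # spheres to keep
    # return_inverse gives each kept sphere the rank of its old ID among survivors
    _, new_id = np.unique(clump_id[keep], return_inverse=True)
    return keep, new_id
```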

jaxdem.utils.scale_to_packing_fraction(state: State, system: System, new_packing_fraction: float) → tuple[State, System]
jaxdem.utils.scale_to_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) → State

Scale the velocities of a state to a desired temperature.

Parameters:
  • state (State) – Current simulation state.

  • target_temperature (float) – Desired target temperature.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • k_B (float, optional) – Boltzmann constant (default is 1.0).

jaxdem.utils.set_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, seed: int | None = 0, k_B: float = 1.0) → State

Randomize the velocities of a state according to a desired temperature.

Parameters:
  • state (State) – Current simulation state.

  • target_temperature (float) – Desired target temperature.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • seed (int, optional) – RNG seed.

  • k_B (float, optional) – Boltzmann constant (default is 1.0).

jaxdem.utils.signed_angle(v1: Array, v2: Array) → Array

Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \([-\pi, \pi)\).

jaxdem.utils.signed_angle_x(v1: Array) → Array

Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).
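A NumPy sketch of the signed angle, using atan2 of the z-component of the cross product and the dot product, is shown below. It uses only the x and y components. The exact endpoint convention ([-pi, pi) versus (-pi, pi]) depends on signed zeros in np.arctan2 and is not enforced here; the names mirror the documented functions, but the bodies are illustrative.

```python
import numpy as np

def signed_angle(v1, v2):
    """Counter-clockwise angle from v1 to v2 about +z (right-hand rule)."""
    v1, v2 = np.asarray(v1, float), np.asarray(v2, float)
    cross_z = v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]
    dot = v1[..., 0] * v2[..., 0] + v1[..., 1] * v2[..., 1]
    return np.arctan2(cross_z, dot)

def signed_angle_x(v1):
    """Signed angle from v1 to the x unit vector."""
    return signed_angle(v1, [1.0, 0.0])
```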

jaxdem.utils.unit(v: Array) → Array

Normalize vectors along the last axis.

v: (…, D) returns: (…, D), unit vectors; zeros map to zeros.

jaxdem.utils.unit_and_norm(v: Array) → tuple[Array, Array]

Normalize vectors along the last axis and return the norm.

v: (…, D) returns: ((…, D), (…, 1)), unit vectors and their norms; zeros map to zeros.
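Mapping zero vectors to zero can be done with a masked divide, as in the NumPy sketch below; in JAX a common equivalent is jnp.where with a safe denominator so that gradients stay finite at zero. The body is illustrative, not jaxdem's implementation.

```python
import numpy as np

def unit_and_norm(v):
    """Unit vectors (..., D) and norms (..., 1); zero vectors stay zero."""
    v = np.asarray(v, float)
    n = np.linalg.norm(v, axis=-1, keepdims=True)          # (..., 1)
    # Divide only where the norm is positive; elsewhere keep the zeros from out
    u = np.divide(v, n, out=np.zeros_like(v), where=n > 0)
    return u, n
```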

Modules

angles

Utility functions to compute angles between vectors.

clumps

contacts

Utility functions for analyzing particle contacts and identifying rattlers.

dispersity

Utility functions to assign radius dispersity.

dynamicsRoutines

JIT-compiled routines for controlling temperature and density via basic rescaling.

environment

Utility functions to handle environments and LIDAR sensor.

geometricAsperityCreation

Utility functions for creating Geometric Asperity particle states in 2D and 3D.

gridState

Utility functions to initialize states with particles arranged in a grid.

h5

HDF5 save/load utilities (v2).

jamming

Jamming routines.

linalg

Utility functions to help with linear algebra.

load_legacy

Adapter for loading HDF5 data saved with the pre-merge-2-27-26 branch.

packingUtils

Utility functions for calculating and changing the packing fraction.

quaternion

Quaternion math utilities.

randomSphereConfiguration

Generates random, energy-minimized configurations of spheres in 2D or 3D.

randomState

Utility functions to randomly initialize states.

randomizeOrientations

Utility functions to randomize particle orientations.

rollout_schedules

Utilities to generate step indices for trajectory logging.

serialization

thermal

Utility functions to compute thermodynamic quantities.