jaxdem.utils

Utility functions used to set up simulations and analyze the output.

jaxdem.utils.unit(v: Array) → Array

Normalize vectors along the last axis.

Parameters:
  • v – Array of shape (…, D).

Returns:

Unit vectors of shape (…, D); zero vectors map to zeros.
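The zero-safe normalization described above can be sketched as follows. This is a minimal illustration, not the library's implementation; `unit_sketch` is a hypothetical name introduced here.

```python
import jax.numpy as jnp

def unit_sketch(v):
    """Normalize along the last axis; zero vectors stay zero (hypothetical helper)."""
    norm = jnp.linalg.norm(v, axis=-1, keepdims=True)
    # Replace zero norms by 1 so the division is always safe, then mask the result.
    safe_norm = jnp.where(norm > 0, norm, 1.0)
    return jnp.where(norm > 0, v / safe_norm, 0.0)

v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
u = unit_sketch(v)  # first row is scaled to unit length, second row stays zero
```

Dividing by a substituted norm of 1 rather than by zero keeps the result free of NaNs, which also keeps gradients finite if the function is differentiated.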

jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) → Array

Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to find the tangential velocity: v = w x r.

This function handles two scenarios based on the last dimension of r:

  1. 3D case (r.shape[-1] == 3):
     - w must be a 3D vector (w.shape[-1] == 3).
     - Computes the standard 3D cross product: w x r.

  2. 2D case (r.shape[-1] == 2):
     - w is treated as a scalar (the z-component of angular velocity, w_z).
     - The computation is equivalent to (0, 0, w_z) x (r_x, r_y, 0).
     - The result is the 2D tangential velocity vector (v_x, v_y) in the xy-plane.

Parameters:
  • w – JAX Array. In the 3D case, shape is (..., 3). In the 2D case, shape is (..., 1) or (...).

  • r – JAX Array. Shape is (..., 3) or (..., 2).

Returns:

A JAX Array representing the tangential velocity (w x r). If r is 3D, the output shape is (…, 3); if r is 2D, the output shape is (…, 2).

Raises:

ValueError – If r is not 2D or 3D, or if dimensions are incompatible.
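The two cases above can be sketched in a few lines. This is an illustrative re-implementation under the stated shape conventions, not the library source; `cross_sketch` is a hypothetical name.

```python
import jax.numpy as jnp

def cross_sketch(w, r):
    """Tangential velocity w x r for 2D or 3D position vectors (hypothetical helper)."""
    w = jnp.asarray(w)
    if r.shape[-1] == 3:
        return jnp.cross(w, r)
    if r.shape[-1] == 2:
        # Accept w with shape (..., 1) or (...): drop a trailing length-1 axis.
        if w.ndim == r.ndim:
            w = w[..., 0]
        # (0, 0, w_z) x (r_x, r_y, 0) = (-w_z * r_y, w_z * r_x, 0)
        return jnp.stack([-w * r[..., 1], w * r[..., 0]], axis=-1)
    raise ValueError("r must have a last axis of length 2 or 3")

v3 = cross_sketch(jnp.array([0.0, 0.0, 2.0]), jnp.array([1.0, 0.0, 0.0]))
v2 = cross_sketch(jnp.array(2.0), jnp.array([1.0, 0.0]))
```

Both calls describe the same physical situation (rotation about z at rate 2, point on the x-axis), so the 2D result equals the first two components of the 3D result.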

jaxdem.utils.signed_angle(v1: Array, v2: Array) → Array

Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \([-\pi, \pi)\).

jaxdem.utils.signed_angle_x(v1: Array) → Array

Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).

jaxdem.utils.angle(v1: Array, v2: Array) → Array

Unsigned angle from v1 -> v2, in \([0, \pi]\).

jaxdem.utils.angle_x(v1: Array) → Array

Unsigned angle from v1 -> \(\hat{x}\), in \([0, \pi]\).
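As a rough illustration of the difference between the signed and unsigned angle functions above, the sketch below computes both for 2D vectors. The helper names are hypothetical, and the use of `arctan2` and `arccos` is an assumption about one reasonable way to do this; in particular, how the library treats the interval endpoints is not verified here.

```python
import jax.numpy as jnp

def signed_angle_sketch(v1, v2):
    """Signed angle from v1 to v2 about +z (counter-clockwise positive)."""
    cross_z = v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]
    dot = v1[..., 0] * v2[..., 0] + v1[..., 1] * v2[..., 1]
    return jnp.arctan2(cross_z, dot)

def angle_sketch(v1, v2):
    """Unsigned angle between v1 and v2, in [0, pi]."""
    cos = jnp.sum(v1 * v2, axis=-1) / (
        jnp.linalg.norm(v1, axis=-1) * jnp.linalg.norm(v2, axis=-1)
    )
    # Clip to guard against round-off pushing |cos| slightly above 1.
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))

x_hat = jnp.array([1.0, 0.0])
y_hat = jnp.array([0.0, 1.0])
quarter_turn = signed_angle_sketch(x_hat, y_hat)  # counter-clockwise, so positive
back_turn = signed_angle_sketch(y_hat, x_hat)     # clockwise, so negative
half_turn = angle_sketch(x_hat, -x_hat)           # unsigned, so non-negative
```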

jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) → State

Create a state where particles sit on a rectangular lattice.

Random values can be sampled for particle radii, masses and velocities by specifying *_range arguments, which are interpreted as (min, max) bounds for a uniform distribution. When a range is not provided the corresponding radius or mass argument is used for all particles and the velocity components are sampled in [-1, 1].

Parameters:
  • n_per_axis (tuple[int]) – Number of spheres along each axis.

  • spacing (tuple[float] | float) – Centre-to-centre distance; scalar is broadcast to every axis.

  • radius (float) – Radius shared by all particles when radius_range is not provided.

  • mass (float) – Mass shared by all particles when mass_range is not provided.

  • jitter (float) – Add a uniform random offset in the range [-jitter, +jitter] for non-perfect grids (useful to break symmetry).

  • vel_range (ArrayLike | None) – (min, max) bounds for the velocity components.

  • radius_range (ArrayLike | None) – (min, max) bounds for the particle radii.

  • mass_range (ArrayLike | None) – (min, max) bounds for the particle masses.

  • seed (int) – Integer seed used when key is not supplied.

  • key (PRNG key, optional) – Controls randomness. If None a key will be created from seed.

Return type:

State
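To make the roles of n_per_axis, spacing, jitter, seed and key concrete, the sketch below builds only the lattice positions with plain JAX. It does not construct a State, and `lattice_positions` is a hypothetical helper; the actual layout and origin used by grid_state may differ.

```python
import jax
import jax.numpy as jnp

def lattice_positions(n_per_axis, spacing, jitter=0.0, key=None, seed=0):
    """Positions on a rectangular lattice with optional uniform jitter (hypothetical)."""
    dim = len(n_per_axis)
    # A scalar spacing is broadcast to every axis.
    spacing = jnp.broadcast_to(jnp.asarray(spacing, dtype=jnp.float32), (dim,))
    axes = [jnp.arange(n) * s for n, s in zip(n_per_axis, spacing)]
    grid = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, dim)
    # Fall back to a key derived from the integer seed when no key is given.
    if key is None:
        key = jax.random.PRNGKey(seed)
    offset = jax.random.uniform(key, grid.shape, minval=-jitter, maxval=jitter)
    return grid + offset

pos = lattice_positions((2, 3), 1.5)                   # 6 points on a perfect 2x3 grid
jittered = lattice_positions((2, 3), 1.5, jitter=0.1)  # same grid, slightly perturbed
```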

jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) → State

Generate N particles placed uniformly at random in an axis-aligned box. Particle overlaps are not checked.

Parameters:
  • N – Number of particles.

  • dim – Spatial dimension (2 or 3).

  • box_size – Edge lengths of the domain.

  • box_anchor – Coordinate of the lower box corner.

  • radius_range – min and max values that the radius can take.

  • mass_range – min and max values that the mass can take.

  • vel_range – min and max values that the velocity components can take.

  • seed – Integer for reproducibility.

Returns:

A fully-initialised State instance.

Return type:

State
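The way box_size, box_anchor and seed interact can be illustrated with a position-only sketch. `sample_in_box` is a hypothetical helper written for this example; it returns a plain array rather than a State and performs no overlap checks.

```python
import jax
import jax.numpy as jnp

def sample_in_box(N, dim, box_size, box_anchor, seed=0):
    """Draw N points uniformly in the box [anchor, anchor + size) (hypothetical)."""
    key = jax.random.PRNGKey(seed)
    size = jnp.broadcast_to(jnp.asarray(box_size, dtype=jnp.float32), (dim,))
    anchor = jnp.broadcast_to(jnp.asarray(box_anchor, dtype=jnp.float32), (dim,))
    u = jax.random.uniform(key, (N, dim))  # samples in [0, 1)
    return anchor + u * size

pts = sample_in_box(100, 2, box_size=(4.0, 2.0), box_anchor=(-1.0, 0.0), seed=0)
pts_again = sample_in_box(100, 2, box_size=(4.0, 2.0), box_anchor=(-1.0, 0.0), seed=0)
```

Because the key is derived only from the integer seed, repeating the call with the same arguments reproduces the same points.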

jaxdem.utils.encode_callable(fn: Callable[[...], Any]) → str

Return a dotted path like ‘jax._src.nn.functions.gelu’.

jaxdem.utils.decode_callable(path: str) → Callable[[...], Any]

Import a callable from a dotted path string.
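The round trip between a callable and its dotted path can be sketched with the standard library alone. The two helpers below are hypothetical stand-ins, and this simple version assumes a module-level callable (it would not resolve methods or nested functions).

```python
import importlib
import json

def encode_callable_sketch(fn):
    """Build 'module.qualname' for a module-level callable (hypothetical helper)."""
    return f"{fn.__module__}.{fn.__qualname__}"

def decode_callable_sketch(path):
    """Import the module part of the path and fetch the attribute by name."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

path = encode_callable_sketch(json.dumps)
restored = decode_callable_sketch(path)
```

Storing the path as a string is what makes a callable serializable alongside other configuration; the receiving side only needs the same module to be importable.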

jaxdem.utils.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, **kw: Any) → Environment

Advance the environment n steps using actions from model.

Parameters:
  • env (Environment) – Initial environment pytree (batchable).

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • n (int) – Number of steps to perform.

  • **kw (Any) – Extra keyword arguments forwarded to model.

Returns:

Environment after n steps.

Return type:

Environment

Examples

>>> env = env_step(env, model, key, n=10, objective=goal)

jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, **kw: Any) → Tuple[Environment, Environment]

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

Parameters:
  • env (Environment) – Initial environment pytree.

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • n (int) – Number of chunks to roll out. Total internal steps = n * stride.

  • stride (int) – Steps per chunk between recorded snapshots.

  • **kw (Any) – Extra keyword arguments passed to model on every step.

Returns:

  • Environment – Environment after n * stride steps.

  • Environment – Stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.
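The relationship between n, stride and the two return values can be illustrated with a toy state in place of an Environment. The sketch below is an assumption about the general pattern (an outer `jax.lax.scan` over chunks, an inner `jax.lax.fori_loop` over steps); `toy_step` and `rollout_sketch` are hypothetical and no model or PRNG key is involved.

```python
import jax
import jax.numpy as jnp

def toy_step(x):
    """Stand-in for one environment step: advance a scalar counter."""
    return x + 1.0

def rollout_sketch(x0, n, stride=1):
    """Run n chunks of `stride` steps, recording the state after each chunk."""
    def chunk(x, _):
        x = jax.lax.fori_loop(0, stride, lambda i, s: toy_step(s), x)
        return x, x  # (carry for the next chunk, snapshot to stack)
    return jax.lax.scan(chunk, x0, None, length=n)

final, traj = rollout_sketch(jnp.zeros(()), n=4, stride=5)
# final is the state after 4 * 5 = 20 steps; traj stacks the 4 snapshots.
```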

Examples

>>> env, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)

jaxdem.utils.lidar(env: Environment) → jax.Array

class jaxdem.utils.Quaternion(w: Array, xyz: Array)

Bases: object

Quaternion representing particle orientation (body frame to lab frame).

w: Array
xyz: Array
static create(w: ArrayLike | None = None, xyz: ArrayLike | None = None) → Quaternion
static unit(q: Quaternion) → Quaternion
static conj(q: Quaternion) → Quaternion
static inv(q: Quaternion) → Quaternion
static rotate(q: Quaternion, v: Array) → Array

Rotates a vector v from the body reference frame to the lab reference frame.

static rotate_back(q: Quaternion, v: Array) → Array

Rotates a vector v from the lab reference frame to the body reference frame.
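For readers unfamiliar with quaternion rotation, the sketch below shows one standard way to rotate a vector by a unit quaternion (w, xyz) and to undo that rotation with the conjugate. It works on plain arrays rather than the Quaternion class, the helper names are hypothetical, and it is not claimed to match the library's internal formula.

```python
import jax.numpy as jnp

def rotate_sketch(w, xyz, v):
    """Rotate v by the unit quaternion (w, xyz): v + w*t + xyz x t, with t = 2 xyz x v."""
    t = 2.0 * jnp.cross(xyz, v)
    return v + w * t + jnp.cross(xyz, t)

def rotate_back_sketch(w, xyz, v):
    """Inverse rotation: apply the conjugate quaternion (w, -xyz)."""
    return rotate_sketch(w, -xyz, v)

# Quarter turn about the z-axis: w = cos(pi/4), xyz = sin(pi/4) * z_hat.
half = jnp.pi / 4
w = jnp.cos(half)
xyz = jnp.array([0.0, 0.0, 1.0]) * jnp.sin(half)

v_body = jnp.array([1.0, 0.0, 0.0])
v_lab = rotate_sketch(w, xyz, v_body)            # x-axis maps to the y-axis
v_roundtrip = rotate_back_sketch(w, xyz, v_lab)  # recovers the original vector
```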

jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) → State

jaxdem.utils.compute_particle_volume(state: State) → jax.Array

Return the total particle volume.

jaxdem.utils.compute_packing_fraction(state: State, system: System) → jax.Array

jaxdem.utils.scale_to_packing_fraction(state: State, system: System, new_packing_fraction: float) → Tuple[State, System]

jaxdem.utils.randomize_orientations(state: State, key: jax.Array) → State

Randomize orientations for clumps (particles with repeated state.clump_ID), leaving spheres unchanged.

jaxdem.utils.compute_translational_kinetic_energy_per_particle(state: State) → jax.Array

Compute the translational kinetic energy per particle.

\[E_{trans} = \frac{1}{2} m |v|^2\]

Notes

  • The energy of clump members is divided by the number of spheres in the clump.

Parameters:

state (State) – The current state of the system containing particle masses and velocities.

Returns:

An array containing the translational kinetic energy for each particle.

Return type:

jax.Array
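The per-particle formula above can be evaluated directly on mass and velocity arrays. The sketch below uses plain arrays with made-up values instead of a State, and it ignores the clump-sharing rule mentioned in the notes.

```python
import jax.numpy as jnp

# Illustrative data: two particles in 2D (values are arbitrary).
mass = jnp.array([1.0, 2.0])
vel = jnp.array([[3.0, 4.0], [0.0, 1.0]])

# E_trans = 1/2 m |v|^2, evaluated per particle along the last axis.
e_trans = 0.5 * mass * jnp.sum(vel**2, axis=-1)

# Summing over particles gives the system total.
e_trans_total = jnp.sum(e_trans)
```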

jaxdem.utils.compute_rotational_kinetic_energy_per_particle(state: State) → jax.Array

Compute the rotational kinetic energy per particle.

\[E_{rot} = \frac{1}{2} \vec{\omega}^T I \vec{\omega}\]

Notes

  • The energy of clump members is divided by the number of spheres in the clump.

Parameters:

state (State) – The current state of the system containing inertia, orientation, and angular velocity.

Returns:

An array containing the rotational kinetic energy for each particle.

Return type:

jax.Array
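The quadratic form above can be written as a single contraction. The sketch assumes a full 3x3 inertia tensor expressed in the same frame as the angular velocity; how the library stores inertia and which frame it uses are not shown in this page, so treat the array layout as an assumption.

```python
import jax.numpy as jnp

# Illustrative data: two particles with diagonal inertia tensors (arbitrary values).
inertia = jnp.stack([jnp.diag(jnp.array([1.0, 2.0, 3.0])), jnp.eye(3)])
omega = jnp.array([[1.0, 1.0, 1.0], [0.0, 0.0, 2.0]])

# E_rot = 1/2 omega^T I omega for each particle.
e_rot = 0.5 * jnp.einsum("...i,...ij,...j->...", omega, inertia, omega)
```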

jaxdem.utils.compute_translational_kinetic_energy(state: State) → jax.Array

Compute the total translational kinetic energy of the system.

\[E_{trans, total} = \sum_{i} \frac{1}{2} m_i |v_i|^2\]
Parameters:

state (State) – The current state of the system.

Returns:

The scalar sum of translational kinetic energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_rotational_kinetic_energy(state: State) → jax.Array

Compute the total rotational kinetic energy of the system.

\[E_{rot, total} = \sum_{i} \frac{1}{2} \vec{\omega}_i^T I_i \vec{\omega}_i\]
Parameters:

state (State) – The current state of the system.

Returns:

The scalar sum of rotational kinetic energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_potential_energy_per_particle(state: State, system: System) → jax.Array

Compute the potential energy per particle based on system interactions. Energy is computed from the force models in the collider, and gravity and force functions that have potential energy associated with them in the force manager.

Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing the collider and potential energy functions.

Returns:

An array containing the potential energy for each particle.

Return type:

jax.Array

jaxdem.utils.compute_potential_energy(state: State, system: System) → jax.Array

Compute the total potential energy of the system. Energy is computed from the force models in the collider, and gravity and force functions that have potential energy associated with them in the force manager.

\[E_{pot, total} = \sum_{i} U(r_i)\]
Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing the collider.

Returns:

The scalar sum of potential energy across all particles.

Return type:

jax.Array

jaxdem.utils.compute_energy(state: State, system: System) → jax.Array

Compute the total mechanical energy of the system.

\[E_{total} = E_{pot, total} + E_{trans, total} + E_{rot, total}\]
Parameters:
  • state (State) – The current state of the system.

  • system (System) – The system definition containing physics parameters and colliders.

Returns:

The total energy (scalar) of the system.

Return type:

jax.Array

jaxdem.utils.compute_temperature(state: State, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) → float

Compute the temperature for a state.

Parameters:
  • state (State) – Current simulation state.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • k_B (float, optional) – Boltzmann constant (default is 1.0).

jaxdem.utils.scale_to_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) → State

Scale the velocities of a state to a desired temperature.

Parameters:
  • state (State) – Current simulation state.

  • target_temperature (float) – Desired target temperature.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • k_B (float, optional) – Boltzmann constant (default is 1.0).
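The idea behind computing a temperature and rescaling velocities to a target can be sketched with the equipartition relation T = 2 E_trans / (N_dof k_B). The version below is a simplification under stated assumptions: translational degrees of freedom only (N_dof = N * dim), no drift removal, and plain arrays instead of a State. The helper names are hypothetical, and the library's degree-of-freedom counting may differ.

```python
import jax.numpy as jnp

def temperature_sketch(mass, vel, k_B=1.0):
    """Equipartition temperature from translational kinetic energy (hypothetical)."""
    n, dim = vel.shape
    e_trans = 0.5 * jnp.sum(mass[:, None] * vel**2)
    return 2.0 * e_trans / (n * dim * k_B)

def scale_to_temperature_sketch(mass, vel, target_temperature, k_B=1.0):
    """Rescale velocities uniformly so the temperature matches the target."""
    current = temperature_sketch(mass, vel, k_B)
    # Kinetic energy is quadratic in velocity, so scale by sqrt(T_target / T).
    return vel * jnp.sqrt(target_temperature / current)

mass = jnp.ones(2)
vel = jnp.array([[1.0, 0.0], [0.0, 1.0]])
t_before = temperature_sketch(mass, vel)
vel_scaled = scale_to_temperature_sketch(mass, vel, target_temperature=2.0)
t_after = temperature_sketch(mass, vel_scaled)
```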

jaxdem.utils.set_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, seed: int | None = 0, k_B: float = 1.0) → State

Randomize the velocities of a state according to a desired temperature.

Parameters:
  • state (State) – Current simulation state.

  • target_temperature (float) – Desired target temperature.

  • can_rotate (bool) – Whether to include rigid body rotations.

  • subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).

  • seed (int, optional) – RNG seed.

  • k_B (float, optional) – Boltzmann constant (default is 1.0).

jaxdem.utils.control_nvt_density(state: State, system: System, *, n: int, rescale_every: int, temperature_target: float | None = None, temperature_delta: float | None = None, packing_fraction_target: float | None = None, packing_fraction_delta: float | None = None, can_rotate: bool = True, subtract_drift: bool = True, k_B: float = 1.0, temperature_schedule: ScheduleFn | None = None, density_schedule: ScheduleFn | None = None, pf_min: float = 1e-12, init_temp_seed: int = 0, unroll: int = 2) → Tuple[State, System]

Runs a protocol for n integration steps, applying (optional) NVT rescaling and/or density rescaling whenever system.step_count is divisible by rescale_every.

Notes

  • rescale_every is in integration steps (System.step_count units).

  • Provide either target or delta for each controlled quantity (or neither to disable).

  • temperature_schedule / density_schedule must be JIT-static (passed as static_argnames).

jaxdem.utils.control_nvt_density_rollout(state: State, system: System, *, n: int, stride: int = 1, rescale_every: int = 1, temperature_target: float | None = None, temperature_delta: float | None = None, packing_fraction_target: float | None = None, packing_fraction_delta: float | None = None, can_rotate: bool = True, subtract_drift: bool = True, k_B: float = 1.0, temperature_schedule: ScheduleFn | None = None, density_schedule: ScheduleFn | None = None, pf_min: float = 1e-12, init_temp_seed: int = 0, unroll: int = 2) → Tuple[State, System, Tuple[State, System]]

Rollout variant (like System.trajectory_rollout), with globally-consistent schedules across the whole rollout.

jaxdem.utils.make_save_steps_linear(*, num_steps: int, save_freq: int, include_step0: bool = True) → ndarray

jaxdem.utils.make_save_steps_pseudolog(*, num_steps: int, reset_save_decade: int, min_save_decade: int, decade: int = 10, include_step0: bool = True, cap: int | None = None) → ndarray

Pseudo-log schedule compatible with the BaseLogGroup logic.

Parameters are interpreted on the integer timestep grid 0..num_steps (inclusive).

Modules

angles

Utility functions to compute angles between vectors.

clumps

dispersity

Utility functions to assign radius dispersity.

dynamicsRoutines

Jit-compiled routines for controlling temperature and density via basic rescaling.

environment

Utility functions to handle environments.

geometricAsperityCreation

Utility functions for creating Geometric Asperity particle states in 2D and 3D.

gridState

Utility functions to initialize states with particles arranged in a grid.

h5

HDF5 save/load utilities (v2).

jamming

Jamming routines.

linalg

Utility functions to help with linear algebra.

packingUtils

Utility functions for calculating and changing the packing fraction.

quaternion

Quaternion math utilities.

randomSphereConfiguration

Generates random, energy-minimized configurations of spheres in 2D or 3D.

randomState

Utility functions to randomly initialize states.

randomizeOrientations

Utility functions to randomize particle orientations.

rollout_schedules

Utilities to generate step indices for trajectory logging.

serialization

thermal

Utility functions to compute thermodynamic quantities.