jaxdem.utils

Utility functions used to set up simulations and analyze the output.

jaxdem.utils.unit(v: Array) → Array

Normalize vectors along the last axis. v has shape (…, D); the result has the same shape (…, D) and contains unit vectors. Zero vectors map to zero vectors.
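A minimal sketch of the documented behavior using plain jax.numpy. `unit_sketch` is a hypothetical name and this is not necessarily how jaxdem implements unit; the zero case is handled by dividing by 1 wherever the norm is 0.

```python
import jax.numpy as jnp


def unit_sketch(v):
    """Normalize v along its last axis; zero vectors stay zero."""
    norm = jnp.linalg.norm(v, axis=-1, keepdims=True)
    # Replace zero norms by 1 so that 0 / 1 = 0 instead of 0 / 0 = NaN.
    safe_norm = jnp.where(norm == 0, 1.0, norm)
    return v / safe_norm


u = unit_sketch(jnp.array([[3.0, 4.0], [0.0, 0.0]]))
```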

jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) → Array

Compute the cross product of an angular velocity vector w and a position vector r, typically to obtain the tangential velocity v = w × r.

This function handles two cases, depending on the last dimension of r:

  1. 3D case (r.shape[-1] == 3):

    • w must be a 3D vector (w.shape[-1] == 3).

    • Computes the standard 3D cross product w × r.

  2. 2D case (r.shape[-1] == 2):

    • w is treated as a scalar, the z-component w_z of the angular velocity.

    • The computation is equivalent to (0, 0, w_z) × (r_x, r_y, 0).

    • The result is the 2D tangential velocity (v_x, v_y) = (-w_z r_y, w_z r_x) in the xy-plane.

Parameters:
  • w (jax.Array) – Angular velocity. Shape (..., 3) in the 3D case; (..., 1) or (...) in the 2D case.

  • r (jax.Array) – Position vector with shape (..., 3) or (..., 2).

Returns:

  • A JAX Array with the tangential velocity w × r:

    • shape (…, 3) if r is 3D;

    • shape (…, 2) if r is 2D.

Raises:

ValueError – If r is not 2D or 3D, or if dimensions are incompatible.
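Expanding (0, 0, w_z) × (r_x, r_y, 0) gives (-w_z r_y, w_z r_x, 0), so the 2D case needs no 3D padding. The sketch below implements both cases with jax.numpy; `cross_sketch` is a hypothetical stand-in, and the handling of w with shape (...) versus (..., 1) is an assumption based on the parameter description, not jaxdem's code.

```python
import jax.numpy as jnp


def cross_sketch(w, r):
    """v = w x r for 3D inputs, or the in-plane velocity for a scalar w_z in 2D."""
    w = jnp.asarray(w)
    r = jnp.asarray(r)
    if r.shape[-1] == 3:
        if w.ndim == 0 or w.shape[-1] != 3:
            raise ValueError("3D case requires w with shape (..., 3)")
        return jnp.cross(w, r)
    if r.shape[-1] == 2:
        # Accept w_z with shape (...) or (..., 1): add a trailing axis if missing.
        if w.ndim == r.ndim - 1:
            w = w[..., None]
        # (0, 0, w_z) x (r_x, r_y, 0) = (-w_z * r_y, w_z * r_x, 0)
        return w * jnp.stack([-r[..., 1], r[..., 0]], axis=-1)
    raise ValueError("r must have last dimension 2 or 3")
```

With a batch of N particles in 2D, w may be passed either as shape (N,) or (N, 1).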

jaxdem.utils.signed_angle(v1: Array, v2: Array) → Array

Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \([-\pi, \pi)\).

jaxdem.utils.signed_angle_x(v1: Array) → Array

Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).
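One common way to compute such a signed angle for 2D vectors is arctan2 of the cross and dot products. The sketch assumes 2D inputs of shape (..., 2), and the function names are hypothetical. The antiparallel boundary case depends on atan2's handling of signed zeros, so this sketch may not match the [-π, π) interval stated for signed_angle.

```python
import jax.numpy as jnp


def signed_angle_sketch(v1, v2):
    """Counterclockwise angle (about +z) that rotates v1 onto v2; 2D inputs assumed."""
    # z-component of the cross product v1 x v2, and the dot product v1 . v2
    cross_z = v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]
    dot = jnp.sum(v1 * v2, axis=-1)
    return jnp.arctan2(cross_z, dot)


def signed_angle_x_sketch(v1):
    """Signed angle from v1 to the unit vector x-hat."""
    x_hat = jnp.zeros_like(v1).at[..., 0].set(1.0)
    return signed_angle_sketch(v1, x_hat)
```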

jaxdem.utils.angle(v1: Array, v2: Array) → Array

Unsigned angle between v1 and v2, in \([0, \pi]\).

jaxdem.utils.angle_x(v1: Array) → Array

Unsigned angle between v1 and \(\hat{x}\), in \([0, \pi]\).
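For vectors of any dimension, the unsigned angle can be computed from the normalized dot product. This is a sketch with hypothetical names, not jaxdem's implementation; zero-length inputs are not handled here.

```python
import jax.numpy as jnp


def angle_sketch(v1, v2):
    """Unsigned angle in [0, pi] between vectors of any dimension D."""
    cos = jnp.sum(v1 * v2, axis=-1) / (
        jnp.linalg.norm(v1, axis=-1) * jnp.linalg.norm(v2, axis=-1)
    )
    # Clip to [-1, 1]: rounding can push |cos| slightly above 1, where arccos is NaN.
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))


def angle_x_sketch(v1):
    """Unsigned angle between v1 and the unit vector x-hat."""
    cos = v1[..., 0] / jnp.linalg.norm(v1, axis=-1)
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))
```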

jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) → State

Create a state where particles sit on a rectangular lattice.

Particle radii, masses and velocities can be sampled at random by passing the *_range arguments, which are interpreted as (min, max) bounds of a uniform distribution. When a range is not provided, the radius or mass argument is used for every particle, and velocity components are sampled in [-1, 1].

Parameters:
  • n_per_axis (tuple[int]) – Number of spheres along each axis.

  • spacing (tuple[float] | float) – Centre-to-centre distance; scalar is broadcast to every axis.

  • radius (float) – Radius assigned to every particle when radius_range is not provided.

  • mass (float) – Mass assigned to every particle when mass_range is not provided.

  • jitter (float) – Add a uniform random offset in the range [-jitter, +jitter] for non-perfect grids (useful to break symmetry).

  • vel_range (ArrayLike | None) – (min, max) bounds for the velocity components.

  • radius_range (ArrayLike | None) – (min, max) bounds for the radii.

  • mass_range (ArrayLike | None) – (min, max) bounds for the masses.

  • seed (int) – Integer seed used when key is not supplied.

  • key (PRNG key, optional) – Controls randomness. If None, a key is created from seed.

Return type:

State
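The lattice part of this construction can be sketched with jnp.meshgrid. `grid_positions_sketch` is hypothetical and returns only positions; placing the first particle at the origin is an assumption, and grid_state itself also fills in radii, masses and velocities.

```python
import jax
import jax.numpy as jnp


def grid_positions_sketch(n_per_axis, spacing, jitter=0.0, key=None):
    """Lattice positions with n_per_axis[i] points along axis i, starting at the origin."""
    dim = len(n_per_axis)
    # A scalar spacing is broadcast to every axis.
    spacing = jnp.broadcast_to(jnp.asarray(spacing, dtype=jnp.float32), (dim,))
    axes = [jnp.arange(n) for n in n_per_axis]
    # Every combination of integer indices, flattened to shape (prod(n_per_axis), dim).
    idx = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, dim)
    pos = idx * spacing
    if jitter > 0.0:
        key = jax.random.PRNGKey(0) if key is None else key
        # Uniform offset in [-jitter, +jitter] per coordinate to break symmetry.
        pos = pos + jax.random.uniform(key, pos.shape, minval=-jitter, maxval=jitter)
    return pos


pos = grid_positions_sketch((2, 3), 1.5)  # 6 points in 2D
```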

jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) → State

Generate N particles placed uniformly at random in an axis-aligned box. Overlaps between particles are not checked.

Parameters:
  • N – Number of particles.

  • dim – Spatial dimension (2 or 3).

  • box_size – Edge lengths of the domain.

  • box_anchor – Coordinate of the lower box corner.

  • radius_range – min and max values that the radius can take.

  • mass_range – min and max values that the mass can take.

  • vel_range – min and max values that the velocity components can take.

  • seed – Integer for reproducibility.

Returns:

A fully initialised State instance.

Return type:

State
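A sketch of the sampling described above, assuming positions uniform in [box_anchor, box_anchor + box_size) and radii uniform in radius_range. `random_positions_sketch` is hypothetical and does not build a State; as in random_state, nothing prevents overlaps.

```python
import jax
import jax.numpy as jnp


def random_positions_sketch(N, box_size, box_anchor, radius_range, seed=0):
    """Uniform positions inside the box and uniform radii; overlaps are not checked."""
    box_size = jnp.asarray(box_size, dtype=jnp.float32)
    box_anchor = jnp.asarray(box_anchor, dtype=jnp.float32)
    key_pos, key_rad = jax.random.split(jax.random.PRNGKey(seed))
    # One uniform sample per coordinate, scaled to the box and shifted to its lower corner.
    pos = box_anchor + box_size * jax.random.uniform(key_pos, (N, box_size.shape[0]))
    r_min, r_max = radius_range
    radius = jax.random.uniform(key_rad, (N,), minval=r_min, maxval=r_max)
    return pos, radius


pos, rad = random_positions_sketch(100, [2.0, 3.0], [-1.0, 0.5], (0.1, 0.2), seed=3)
```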

jaxdem.utils.encode_callable(fn: Callable) → str

Return the dotted import path of fn, e.g. ‘jax._src.nn.functions.gelu’.

jaxdem.utils.decode_callable(path: str) → Callable

Import a callable from a dotted path string.
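A common way to implement such an encode/decode pair uses the function's __module__ and __qualname__ together with importlib. The sketch below is a hypothetical illustration, not jaxdem's code. It only round-trips module-level functions: a nested qualified name such as 'Class.method' would need one getattr per component, and lambdas or locally defined functions cannot be imported back.

```python
import importlib
import math


def encode_callable_sketch(fn):
    """Dotted path built from the defining module and the qualified name."""
    return f"{fn.__module__}.{fn.__qualname__}"


def decode_callable_sketch(path):
    """Import the module part of the path and fetch the final attribute."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


path = encode_callable_sketch(math.sqrt)  # "math.sqrt"
assert decode_callable_sketch(path) is math.sqrt
```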

jaxdem.utils.env_step(env: Environment, model: Callable, key: jax.Array, *, n: int = 1, **kw: Any) → Environment

Advance the environment n steps using actions from model.

Parameters:
  • env (Environment) – Initial environment pytree (batchable).

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – PRNG key supplying the randomness passed to model.

  • n (int) – Number of steps to perform.

  • **kw (Any) – Extra keyword arguments forwarded to model.

Returns:

Environment after n steps.

Return type:

Environment

Examples

>>> env = env_step(env, model, key, n=10, objective=goal)

jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable, key: jax.Array, *, n: int, stride: int = 1, **kw: Any) → Tuple[Environment, Environment]

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

Parameters:
  • env (Environment) – Initial environment pytree.

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – PRNG key supplying the randomness passed to model.

  • n (int) – Number of chunks to roll out. Total internal steps = n * stride.

  • stride (int) – Steps per chunk between recorded snapshots.

  • **kw (Any) – Extra keyword arguments passed to model on every step.

Returns:

  • Environment – Environment after n * stride steps.

  • Environment – Stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.

Examples

>>> env, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
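The stepping pattern of env_step and the chunked recording of env_trajectory_rollout can be illustrated with jax.lax.scan on a toy pytree. Everything here (toy_observation, toy_apply, the dict environment and the key-splitting scheme) is a hypothetical stand-in, not jaxdem's Environment API or implementation.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins: a dict pytree "environment", an observation function
# and a transition rule. Not jaxdem's Environment API.
def toy_observation(env):
    return env["pos"]


def toy_apply(env, action):
    return {"pos": env["pos"] + action}


def toy_env_step(env, model, key, n=1, **kw):
    """Advance env by n steps, giving every model call a fresh subkey."""

    def body(carry, _):
        env, key = carry
        key, subkey = jax.random.split(key)
        action = model(toy_observation(env), subkey, **kw)
        return (toy_apply(env, action), key), None

    (env, _), _ = jax.lax.scan(body, (env, key), None, length=n)
    return env


def toy_trajectory_rollout(env, model, key, n, stride=1, **kw):
    """Run n chunks of `stride` steps and stack the environment after each chunk."""

    def chunk(carry, _):
        env, key = carry
        key, subkey = jax.random.split(key)
        env = toy_env_step(env, model, subkey, n=stride, **kw)
        # The second return value is stacked along a new leading axis of length n.
        return (env, key), env

    (env, _), traj = jax.lax.scan(chunk, (env, key), None, length=n)
    return env, traj


# Toy model: move one unit per step toward `objective` (the key is unused).
def model(obs, key, objective):
    return jnp.sign(objective - obs)


env0 = {"pos": jnp.zeros(2)}
goal = jnp.array([100.0, 100.0])
key = jax.random.PRNGKey(0)
env, traj = toy_trajectory_rollout(env0, model, key, n=4, stride=5, objective=goal)
```

Here traj["pos"] has shape (4, 2), one snapshot per chunk of 5 steps, and the final environment has taken 4 × 5 = 20 steps.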

jaxdem.utils.lidar(env: Environment) → jax.Array

class jaxdem.utils.Quaternion(w: Array, xyz: Array)

Bases: object

Quaternion representing the orientation of a particle. It stores the rotation from the body frame to the lab frame.

static conj(q: Quaternion) → Quaternion
static create(w: ArrayLike | None = None, xyz: ArrayLike | None = None) → Quaternion
static inv(q: Quaternion) → Quaternion
static rotate(q: Quaternion, v: Array) → Array

Rotates a vector v from the body reference frame to the lab reference frame.

static rotate_back(q: Quaternion, v: Array) → Array

Rotates a vector v from the lab reference frame to the body reference frame.

static unit(q: Quaternion) → Quaternion
w: Array
xyz: Array
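Assuming the usual Hamilton convention, with scalar part w and vector part xyz of a unit quaternion, rotate and rotate_back can be sketched with the identity v' = v + w t + xyz × t, where t = 2 xyz × v. The function names are hypothetical and the convention is an assumption, not taken from jaxdem's source.

```python
import jax.numpy as jnp


def rotate_sketch(w, xyz, v):
    """Rotate v (body -> lab) by the unit quaternion (w, xyz); w has shape (..., 1)."""
    # Identity for unit quaternions: v' = v + w t + xyz x t, with t = 2 xyz x v.
    t = 2.0 * jnp.cross(xyz, v)
    return v + w * t + jnp.cross(xyz, t)


def rotate_back_sketch(w, xyz, v):
    """Inverse rotation (lab -> body): for a unit quaternion the inverse is the conjugate."""
    return rotate_sketch(w, -xyz, v)


# 90 degree rotation about z: w = cos(45 deg), xyz = sin(45 deg) * z-hat.
s = jnp.sqrt(0.5)
w = jnp.array([s])
xyz = jnp.array([0.0, 0.0, s])
v_lab = rotate_sketch(w, xyz, jnp.array([1.0, 0.0, 0.0]))  # approximately [0, 1, 0]
v_body = rotate_back_sketch(w, xyz, v_lab)  # approximately [1, 0, 0]
```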
jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) → State
jaxdem.utils.random_sphere_configuration(particle_radii: Sequence[float] | Sequence[Sequence[float]], phi: float | Sequence[float], dim: int, seed: int | None = None, collider_type='naive', box_aspect: Sequence[float] | Sequence[Sequence[float]] | None = None) → Tuple[Array, Array]

Generate one or more random sphere packings at a target packing fraction.

This builds periodic systems with spherical particles, initializes particle positions uniformly at random inside a rectangular periodic box, and then minimizes the potential energy to obtain a mechanically stable configuration.

The function supports batching over multiple independent “systems” by treating the leading axis as the system index and broadcasting any length-1 inputs to match the maximum number of systems inferred from the inputs.

Parameters:
  • particle_radii

    Particle radii for one system or multiple systems.

    • Single system: a 1D sequence of length N (radii for each particle).

    • Multiple systems: a 2D sequence with shape (S, N) (one radii list per system).

    Internally, this is converted to a JAX array with shape (S, N).

  • phi

    Target packing fraction(s).

    • Scalar: a single float applied to all systems.

    • Per-system: a 1D sequence of length S.

    Internally, this is converted to a JAX array with shape (S, 1) and then broadcast/padded to match the inferred number of systems.

  • dim – Spatial dimension (e.g. 2 or 3).

  • seed

    RNG seed used to initialize particle positions. If None, a random seed is drawn via NumPy.

    Note: a single JAX PRNGKey is used to generate the full position array of shape (S, N, dim).

  • collider_type – Collision detection backend. Must be one of "naive" or "celllist".

  • box_aspect

    Box aspect ratios for the periodic domain.

    • If None, defaults to jnp.ones(dim).

    • Otherwise must be a 1D sequence of length dim.

    Internally broadcast/padded to shape (S, dim).

    (Even though the type annotation allows a sequence-of-sequences, the current implementation asserts len(box_aspect) == dim before broadcasting, so per-system (S, dim) input is not accepted here.)

Returns:

  • pos – Particle positions after minimization.

    • If S > 1: shape (S, N, dim).

    • If S == 1: shape (N, dim) due to squeeze().

  • box_size – Periodic box size vectors.

    • If S > 1: shape (S, dim).

    • If S == 1: shape (dim,) due to squeeze().

Notes

  • Broadcasting rule: any input provided for a single system (leading dimension 1) is replicated to match the maximum S inferred from particle_radii, phi, and box_aspect.

  • The final squeeze() calls can also drop other singleton dimensions (e.g. if N == 1). If you need a stable rank, reshape the outputs back to (S, N, dim) and (S, dim).
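The target packing fraction fixes the box volume before any particles are placed. The sketch below assumes phi is the total particle volume divided by the box volume and that box_aspect scales the box edges; it also applies the broadcasting rule from the Notes. `box_size_for_phi` and `ball_volume` are hypothetical helpers that return unsqueezed (S, dim) arrays; they do not reproduce the random placement or the energy minimization.

```python
import math

import jax.numpy as jnp


def ball_volume(r, dim):
    """Volume of a dim-dimensional ball of radius r: pi^(d/2) / Gamma(d/2 + 1) * r^d."""
    return math.pi ** (dim / 2) / math.gamma(dim / 2 + 1) * r**dim


def box_size_for_phi(particle_radii, phi, dim, box_aspect=None):
    """Box edge vectors of shape (S, dim) giving packing fraction phi for each system."""
    radii = jnp.atleast_2d(jnp.asarray(particle_radii, dtype=jnp.float32))  # (S_r, N)
    phi = jnp.asarray(phi, dtype=jnp.float32).reshape(-1, 1)  # (S_phi, 1)
    if box_aspect is None:
        aspect = jnp.ones(dim)
    else:
        aspect = jnp.asarray(box_aspect, dtype=jnp.float32)
    # Broadcasting rule: length-1 inputs are replicated to the largest S.
    S = max(radii.shape[0], phi.shape[0])
    radii = jnp.broadcast_to(radii, (S, radii.shape[1]))
    phi = jnp.broadcast_to(phi, (S, 1))
    # phi = (total particle volume) / (box volume), so box volume = V_particles / phi.
    box_volume = jnp.sum(ball_volume(radii, dim), axis=1, keepdims=True) / phi
    # Edges are L * aspect with L^dim * prod(aspect) = box volume.
    L = (box_volume / jnp.prod(aspect)) ** (1.0 / dim)
    return L * aspect


boxes = box_size_for_phi([0.5] * 4, [0.5, 0.25], dim=2)  # two systems, shape (2, 2)
```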

Modules

angles

Utility functions to compute angles between vectors.

clumps

dispersity

Utility functions to assign radius dispersity.

environment

Utility functions to handle environments.

gridState

Utility functions to initialize states with particles arranged in a grid.

h5

jamming

Jamming routines.

linalg

Utility functions to help with linear algebra.

quaternion

Utility functions to handle quaternions.

randomSphereConfiguration

randomState

Utility functions to randomly initialize states.

serialization