jaxdem.utils
Utility functions used to set up simulations and analyze the output.
- jaxdem.utils.unit(v: Array) -> Array
Normalize vectors along the last axis. The input v has shape (…, D); the result has the same shape and holds unit vectors, with zero vectors mapped to zeros.
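The zero-vector rule above can be illustrated with a minimal sketch. It is written in plain Python for a single vector rather than a batched JAX array, and the name `unit_sketch` is hypothetical; it is not the library's implementation.

```python
import math


def unit_sketch(v):
    """Return v scaled to unit length; a zero vector maps to zeros."""
    norm = math.sqrt(sum(c * c for c in v))
    if norm == 0.0:
        # Avoid dividing by zero: zeros map to zeros, as documented.
        return [0.0 for _ in v]
    return [c / norm for c in v]


print(unit_sketch([3.0, 4.0]))  # [0.6, 0.8]
print(unit_sketch([0.0, 0.0]))  # [0.0, 0.0]
```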
- jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) -> Array
Compute the cross product of an angular velocity vector (w) and a position vector (r), often used to find the tangential velocity: v = w x r.
This function handles two cases based on the last dimension of r:
- 3D case (r.shape[-1] == 3): w must be a 3D vector (w.shape[-1] == 3). Computes the standard 3D cross product w x r.
- 2D case (r.shape[-1] == 2): w is treated as a scalar (the z-component of angular velocity, w_z). The computation is equivalent to (0, 0, w_z) x (r_x, r_y, 0), and the result is the 2D tangential velocity vector (v_x, v_y) in the xy-plane.
- Parameters:
w (Array) – Angular velocity. Shape is (…, 3) in the 3D case and (…, 1) or (…) in the 2D case.
r (Array) – Position vector. Shape is (…, 3) or (…, 2).
- Returns:
A JAX Array representing the tangential velocity (w x r).
- If r is 3D, the output shape is (…, 3).
- If r is 2D, the output shape is (…, 2).
- Raises:
ValueError – If r is not 2D or 3D, or if dimensions are incompatible.
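The two cases can be sketched for a single (unbatched) pair of inputs. This is an illustration in plain Python under the definitions given above, not the library's code; the name `cross_w_r` is hypothetical.

```python
def cross_w_r(w, r):
    """Tangential velocity v = w x r for a 3D vector w, or a scalar w_z in 2D."""
    if len(r) == 3:
        if not isinstance(w, (list, tuple)) or len(w) != 3:
            raise ValueError("3D case requires a 3-component w")
        wx, wy, wz = w
        rx, ry, rz = r
        return [wy * rz - wz * ry, wz * rx - wx * rz, wx * ry - wy * rx]
    if len(r) == 2:
        # Accept either a bare scalar or a 1-element sequence for w_z.
        wz = w[0] if isinstance(w, (list, tuple)) else w
        rx, ry = r
        # (0, 0, w_z) x (r_x, r_y, 0) = (-w_z * r_y, w_z * r_x, 0)
        return [-wz * ry, wz * rx]
    raise ValueError("r must have 2 or 3 components")


print(cross_w_r([0.0, 0.0, 2.0], [1.0, 1.0, 0.0]))  # [-2.0, 2.0, 0.0]
print(cross_w_r(2.0, [1.0, 1.0]))                   # [-2.0, 2.0]
```

The 2D result matches the x and y components of the 3D result when w points along z and r lies in the xy-plane.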
- jaxdem.utils.signed_angle(v1: Array, v2: Array) -> Array
Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \([-\pi, \pi)\).
- jaxdem.utils.signed_angle_x(v1: Array) -> Array
Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).
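A common way to obtain such a signed angle for 2D vectors is `atan2` of the z-component of the cross product and the dot product. The sketch below assumes that formulation (the source does not state which one the library uses) and uses hypothetical names. Note that `math.atan2` returns values in the closed interval [-π, π], so the half-open endpoint conventions documented above are not reproduced here.

```python
import math


def signed_angle_sketch(v1, v2):
    """Angle from v1 to v2 about +z (right-hand rule), for 2D vectors."""
    cross_z = v1[0] * v2[1] - v1[1] * v2[0]
    dot = v1[0] * v2[0] + v1[1] * v2[1]
    return math.atan2(cross_z, dot)


def signed_angle_x_sketch(v1):
    """Angle from v1 to the x unit vector about +z."""
    return signed_angle_sketch(v1, (1.0, 0.0))


# x -> y is a counter-clockwise quarter turn; y -> x is a clockwise one.
print(signed_angle_sketch((1.0, 0.0), (0.0, 1.0)))
print(signed_angle_x_sketch((0.0, 1.0)))
```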
- jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) -> State
Create a state where particles sit on a rectangular lattice.
Random values can be sampled for particle radii, masses and velocities by specifying *_range arguments, which are interpreted as (min, max) bounds for a uniform distribution. When a range is not provided, the corresponding radius or mass argument is used for all particles and the velocity components are sampled in [-1, 1].
- Parameters:
n_per_axis (tuple[int]) – Number of spheres along each axis.
spacing (tuple[float] | float) – Centre-to-centre distance; a scalar is broadcast to every axis.
radius (float) – Radius shared by all particles when radius_range is not provided.
mass (float) – Mass shared by all particles when mass_range is not provided.
jitter (float) – Uniform random offset in the range [-jitter, +jitter], giving a non-perfect grid (useful to break symmetry).
vel_range (ArrayLike | None) – (min, max) values for the velocity components.
radius_range (ArrayLike | None) – (min, max) values for the radii.
mass_range (ArrayLike | None) – (min, max) values for the masses.
seed (int) – Integer seed used when key is not supplied.
key (PRNG key, optional) – Controls randomness. If None, a key will be created from seed.
- Return type:
State
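The lattice placement described for grid_state can be sketched as follows. This is a plain-Python illustration only: the function name, the choice of origin at zero, and the ordering of lattice sites are assumptions, and it returns a list of positions rather than a State.

```python
import itertools
import random


def grid_positions(n_per_axis, spacing, jitter=0.0, seed=0):
    """Positions on a rectangular lattice, optionally perturbed by jitter."""
    dim = len(n_per_axis)
    if isinstance(spacing, (int, float)):
        # A scalar spacing is broadcast to every axis.
        spacing = [float(spacing)] * dim
    rng = random.Random(seed)
    positions = []
    for idx in itertools.product(*(range(n) for n in n_per_axis)):
        point = []
        for i, s in zip(idx, spacing):
            offset = rng.uniform(-jitter, jitter) if jitter else 0.0
            point.append(i * s + offset)
        positions.append(point)
    return positions


print(grid_positions((2, 2), 1.5))
# [[0.0, 0.0], [0.0, 1.5], [1.5, 0.0], [1.5, 1.5]]
```

With a non-zero jitter, each coordinate moves by at most jitter from its lattice site, which is enough to break the symmetry of a perfect grid.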
- jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) -> State
Generate N particles placed uniformly at random in an axis-aligned box. Overlaps between particles are not checked.
- Parameters:
N – Number of particles.
dim – Spatial dimension (2 or 3).
box_size – Edge lengths of the domain.
box_anchor – Coordinate of the lower box corner.
radius_range – Minimum and maximum values that the radius can take.
mass_range – Minimum and maximum values that the mass can take.
vel_range – Minimum and maximum values that the velocity components can take.
seed – Integer for reproducibility.
- Returns:
A fully-initialised State instance.
- Return type:
State
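Because placement is uniform and overlaps are not checked, particles can intersect. The sketch below illustrates this with plain Python; `random_particles` and `count_overlaps` are hypothetical helpers, not part of jaxdem, and they return lists rather than a State.

```python
import math
import random


def random_particles(N, dim, box_size, box_anchor, radius_range=(1.0, 1.0), seed=0):
    """Sample N positions uniformly in an axis-aligned box, plus radii."""
    if len(box_size) != dim or len(box_anchor) != dim:
        raise ValueError("box_size and box_anchor must have dim entries")
    rng = random.Random(seed)
    pos = [[a + rng.random() * L for a, L in zip(box_anchor, box_size)]
           for _ in range(N)]
    rad = [rng.uniform(*radius_range) for _ in range(N)]
    return pos, rad


def count_overlaps(pos, rad):
    """Number of particle pairs whose spheres intersect."""
    n = len(pos)
    return sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if math.dist(pos[i], pos[j]) < rad[i] + rad[j]
    )


pos, rad = random_particles(N=50, dim=2, box_size=[10.0, 5.0], box_anchor=[-1.0, 2.0])
print("overlapping pairs:", count_overlaps(pos, rad))
```

If an overlap-free start is required, a lattice initializer or a subsequent energy minimization is the usual remedy.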
- jaxdem.utils.encode_callable(fn: Callable) -> str
Return a dotted path like ‘jax._src.nn.functions.gelu’.
- jaxdem.utils.decode_callable(path: str) -> Callable
Import a callable from a dotted path string.
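A round trip between a callable and its dotted path can be sketched with the standard library. This is one plausible approach, assuming the path is built from `__module__` and `__qualname__`; the `_sketch` names are hypothetical and the library's own implementation may differ.

```python
import importlib
import json


def encode_callable_sketch(fn):
    """Build a dotted path such as 'json.dumps' from a callable."""
    return f"{fn.__module__}.{fn.__qualname__}"


def decode_callable_sketch(path):
    """Import the module part of the path and fetch the final attribute."""
    module_name, _, attr = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


path = encode_callable_sketch(json.dumps)
print(path)  # json.dumps
assert decode_callable_sketch(path) is json.dumps
```

This simple version only handles module-level callables; methods and nested functions have multi-part qualified names that `rpartition` alone does not resolve.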
- jaxdem.utils.env_step(env: Environment, model: Callable, key: jax.Array, *, n: int = 1, **kw: Any) -> Environment
Advance the environment n steps using actions from model.
- Parameters:
env (Environment) – Initial environment pytree (batchable).
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
n (int) – Number of steps to perform.
**kw (Any) – Extra keyword arguments forwarded to model.
- Returns:
Environment after n steps.
- Return type:
Environment
Examples
>>> env = env_step(env, model, key, n=10, objective=goal)
- jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable, key: jax.Array, *, n: int, stride: int = 1, **kw: Any) -> Tuple[Environment, Environment]
Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.
- Parameters:
env (Environment) – Initial environment pytree.
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
n (int) – Number of chunks to roll out. Total internal steps = n * stride.
stride (int) – Steps per chunk between recorded snapshots.
**kw (Any) – Extra keyword arguments passed to model on every step.
- Returns:
Environment – Environment after n * stride steps.
Environment – Stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.
Examples
>>> env, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
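The relationship between n, stride, and the recorded snapshots can be shown with a toy rollout. Everything here is hypothetical scaffolding: the environment is a plain integer counter, `step` is a stand-in transition function, and the key is passed through unused. In JAX code this pattern is commonly expressed with `jax.lax.scan`, which is not shown.

```python
def rollout_sketch(env, model, step, key, *, n, stride=1, **kw):
    """Run n chunks of stride steps; record the environment after each chunk."""
    snapshots = []
    for _ in range(n):
        for _ in range(stride):
            action = model(env, key, **kw)  # model(obs, key, **kw) -> action
            env = step(env, action)
        snapshots.append(env)
    return env, snapshots


def toy_model(obs, key, gain=1):
    # Hypothetical policy: always return the same action.
    return gain


def toy_step(env, action):
    # Hypothetical transition: the environment is just a running sum.
    return env + action


final, traj = rollout_sketch(0, toy_model, toy_step, None, n=3, stride=5, gain=1)
print(final, traj)  # 15 [5, 10, 15]
```

The final environment reflects n * stride = 15 steps, while the trajectory holds n = 3 snapshots, one per chunk.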
- jaxdem.utils.lidar(env: Environment) -> jax.Array
- class jaxdem.utils.Quaternion(w: Array, xyz: Array)
Bases: object
Quaternion representing the orientation of a particle. Stores the body-to-lab rotation.
- static conj(q: Quaternion) -> Quaternion
- static create(w: ArrayLike | None = None, xyz: ArrayLike | None = None) -> Quaternion
- static inv(q: Quaternion) -> Quaternion
- static rotate(q: Quaternion, v: Array) -> Array
Rotates a vector v from the body reference frame to the lab reference frame.
- static rotate_back(q: Quaternion, v: Array) -> Array
Rotates a vector v from the lab reference frame to the body reference frame.
- static unit(q: Quaternion) -> Quaternion
- w: Array
- xyz: Array
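The body-to-lab and lab-to-body rotations can be sketched with the standard unit-quaternion formula v' = v + w t + u x t, where u is the vector part and t = 2 (u x v). The functions below are hypothetical stand-ins operating on plain lists and assume a unit quaternion; they are not the class's own methods.

```python
import math


def _cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def rotate_sketch(w, xyz, v):
    """Rotate v by the unit quaternion (w, xyz), e.g. body frame to lab frame."""
    t = [2.0 * c for c in _cross(xyz, v)]
    ut = _cross(xyz, t)
    return [v[i] + w * t[i] + ut[i] for i in range(3)]


def rotate_back_sketch(w, xyz, v):
    """Inverse rotation: apply the conjugate quaternion (w, -xyz)."""
    return rotate_sketch(w, [-c for c in xyz], v)


# Quarter turn about z: the body x-axis points along lab y (up to rounding).
half = math.pi / 4
w, xyz = math.cos(half), [0.0, 0.0, math.sin(half)]
lab = rotate_sketch(w, xyz, [1.0, 0.0, 0.0])
print(lab)
print(rotate_back_sketch(w, xyz, lab))
```

For a unit quaternion the conjugate equals the inverse, which is why rotating back only requires flipping the sign of the vector part.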
- jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) -> State
- jaxdem.utils.random_sphere_configuration(particle_radii: Sequence[float] | Sequence[Sequence[float]], phi: float | Sequence[float], dim: int, seed: int | None = None, collider_type='naive', box_aspect: Sequence[float] | Sequence[Sequence[float]] | None = None) -> Tuple[Array, Array]
Generate one or more random sphere packings at a target packing fraction.
This builds periodic systems with spherical particles, initializes particle positions uniformly at random inside a rectangular periodic box, and then minimizes the potential energy to obtain a mechanically stable configuration.
The function supports batching over multiple independent “systems” by treating the leading axis as the system index and broadcasting any length-1 inputs to match the maximum number of systems inferred from the inputs.
- Parameters:
particle_radii – Particle radii for one system or multiple systems.
- Single system: a 1D sequence of length N (radii for each particle).
- Multiple systems: a 2D sequence with shape (S, N) (one radii list per system).
Internally, this is converted to a JAX array with shape (S, N).
phi – Target packing fraction(s).
- Scalar: a single float applied to all systems.
- Per-system: a 1D sequence of length S.
Internally, this is converted to a JAX array with shape (S, 1) and then broadcast/padded to match the inferred number of systems.
dim – Spatial dimension (e.g. 2 or 3).
seed – RNG seed used to initialize particle positions. If None, a random seed is drawn via NumPy. Note: a single JAX PRNGKey is used to generate the full position array of shape (S, N, dim).
collider_type – Collision detection backend. Must be one of "naive" or "celllist".
box_aspect – Box aspect ratios for the periodic domain.
- If None, defaults to jnp.ones(dim).
- Otherwise must be a 1D sequence of length dim.
Internally broadcast/padded to shape (S, dim). (Even though the type annotation allows a sequence of sequences, the current implementation asserts len(box_aspect) == dim before broadcasting, so per-system (S, dim) input is not accepted here.)
- Returns:
pos – Particle positions after minimization.
- If S > 1: shape (S, N, dim).
- If S == 1: shape (N, dim) due to squeeze().
box_size – Periodic box size vectors.
- If S > 1: shape (S, dim).
- If S == 1: shape (dim,) due to squeeze().
Notes
- Broadcasting rule: any input provided for a single system (leading dimension 1) is replicated to match the maximum S inferred from particle_radii, phi, and box_aspect.
- The final squeeze() calls can also drop other singleton dimensions (e.g. if N == 1). If you need a stable rank/shape, remove the squeezes.
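To see how a target packing fraction and a box aspect ratio together fix the box size, the sketch below uses the usual definition phi = (total particle volume) / (box volume). This section does not state how the library sizes the box, so the function and its name are assumptions for illustration, for a single system in 2D or 3D.

```python
import math


def box_size_for_phi(radii, phi, dim, box_aspect=None):
    """Box edge lengths giving packing fraction phi with the requested aspect."""
    if box_aspect is None:
        box_aspect = [1.0] * dim
    if len(box_aspect) != dim:
        raise ValueError("box_aspect must have dim entries")
    if dim == 2:
        particle_volume = sum(math.pi * r ** 2 for r in radii)
    elif dim == 3:
        particle_volume = sum(4.0 / 3.0 * math.pi * r ** 3 for r in radii)
    else:
        raise ValueError("dim must be 2 or 3")
    box_volume = particle_volume / phi
    # Scale the aspect vector so the product of the edges equals box_volume.
    scale = (box_volume / math.prod(box_aspect)) ** (1.0 / dim)
    return [a * scale for a in box_aspect]


edges = box_size_for_phi([0.5] * 10, phi=0.8, dim=2, box_aspect=[1.0, 2.0])
print(edges)
```

Here the second edge is twice the first, and the ten discs of radius 0.5 cover 80% of the box area.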
Modules
- Utility functions to compute angles between vectors.
- Utility functions to assign radius dispersity.
- Utility functions to handle environments.
- Utility functions to initialize states with particles arranged in a grid.
- Jamming routines.
- Utility functions to help with linear algebra.
- Utility functions to handle environments.
- Utility functions to randomly initialize states.