jaxdem.utils
Utility functions used to set up simulations and analyze the output.
- class jaxdem.utils.Quaternion(w: Array, xyz: Array)
Bases: object
Quaternion representing particle orientation (body frame to lab frame).
- w: Array
- xyz: Array
- static create(w: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, xyz: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None) -> Quaternion
- static unit(q: Quaternion) -> Quaternion
- static conj(q: Quaternion) -> Quaternion
- static inv(q: Quaternion) -> Quaternion
- static rotate(q: Quaternion, v: Array) -> Array
Rotates a vector v from the body reference frame to the lab reference frame.
- static rotate_back(q: Quaternion, v: Array) -> Array
Rotates a vector v from the lab reference frame to the body reference frame.
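The two rotations can be sketched with the standard unit-quaternion identity \(v' = q v q^*\), expanded so that no full quaternion product is formed. This is a minimal sketch, not jaxdem's implementation. It assumes a unit quaternion stored as a scalar part w of shape (..., 1) and a vector part xyz of shape (..., 3), mirroring the Quaternion fields. The helper names are hypothetical.

```python
import jax.numpy as jnp


def quat_rotate(w, xyz, v):
    """Body -> lab: v' = q v q* for a unit quaternion q = (w, xyz).

    Assumes w has shape (..., 1) and xyz, v have shape (..., 3).
    Uses the expansion v' = v + w t + xyz x t with t = 2 (xyz x v).
    """
    t = 2.0 * jnp.cross(xyz, v)
    return v + w * t + jnp.cross(xyz, t)


def quat_rotate_back(w, xyz, v):
    """Lab -> body: rotate with the conjugate (w, -xyz), the inverse of a unit quaternion."""
    return quat_rotate(w, -xyz, v)


# 90-degree rotation about z: q = (cos 45deg, (0, 0, sin 45deg))
c = jnp.sqrt(0.5)
w = jnp.array([1.0]) * c
xyz = jnp.array([0.0, 0.0, 1.0]) * c
v_lab = quat_rotate(w, xyz, jnp.array([1.0, 0.0, 0.0]))  # approximately (0, 1, 0)
v_body = quat_rotate_back(w, xyz, v_lab)                  # approximately (1, 0, 0)
```

Applying the conjugate undoes the rotation only because q is assumed to be normalized. A non-unit quaternion would also scale the vector.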
- jaxdem.utils.compute_clump_properties(state: State, mat_table: MaterialTable, n_samples: int = 50000) -> State
- jaxdem.utils.compute_energy(state: State, system: System) -> jax.Array
Compute the total mechanical energy of the system.
\[E_{total} = E_{pot, total} + E_{trans, total} + E_{rot, total}\]
- jaxdem.utils.compute_particle_volume(state: State) -> jax.Array
Return the total particle volume.
- jaxdem.utils.compute_potential_energy(state: State, system: System) -> jax.Array
Compute the total potential energy of the system. The energy is computed from the collider's force models, together with gravity and any other force-manager functions that have an associated potential energy.
\[E_{pot, total} = \sum_{i} U(r_i)\]
- jaxdem.utils.compute_rotational_kinetic_energy(state: State) -> jax.Array
Compute the total rotational kinetic energy of the system.
\[E_{rot, total} = \sum_{i} \frac{1}{2} \vec{\omega}_i^T I_i \vec{\omega}_i\]
- Parameters:
state (State) – The current state of the system.
- Returns:
The scalar sum of rotational kinetic energy across all particles.
- Return type:
jax.Array
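The sum above can be written as a single einsum. This sketch assumes that angular velocities omega, shape (N, 3), and full inertia tensors, shape (N, 3, 3), are expressed in the same frame. How jaxdem actually stores inertia (for example, as principal moments in the body frame) is not stated here, so the argument layout is hypothetical.

```python
import jax.numpy as jnp


def rotational_kinetic_energy(omega, inertia):
    """Total rotational KE, sum_i 0.5 * omega_i^T I_i omega_i.

    omega:   (N, 3) angular velocities.
    inertia: (N, 3, 3) inertia tensors in the same frame as omega.
    """
    return 0.5 * jnp.einsum("ni,nij,nj->", omega, inertia, omega)


omega = jnp.array([[1.0, 2.0, 3.0]])
inertia = jnp.diag(jnp.array([1.0, 2.0, 3.0]))[None]  # shape (1, 3, 3)
e_rot = rotational_kinetic_energy(omega, inertia)     # 0.5 * (1 + 8 + 27) = 18
```

If the inertia is diagonal in the frame of omega (principal axes), the expression reduces to 0.5 * sum(I_k * omega_k**2).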
- jaxdem.utils.compute_rotational_kinetic_energy_per_particle(state: State) -> jax.Array
Compute the rotational kinetic energy per particle.
\[E_{rot} = \frac{1}{2} \vec{\omega}^T I \vec{\omega}\]
Notes
The energy of clump members is divided by the number of spheres in the clump.
- Parameters:
state (State) – The current state of the system containing inertia, orientation, and angular velocity.
- Returns:
An array containing the rotational kinetic energy for each particle.
- Return type:
jax.Array
- jaxdem.utils.compute_temperature(state: State, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) -> float
Compute the temperature for a state.
- Parameters:
state (State) – Current simulation state.
can_rotate (bool) – Whether to include rigid body rotations.
subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).
k_B (float, optional) – Boltzmann constant (default is 1.0).
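Here is a translational-only sketch of the equipartition estimate \(k_B T = 2 E_{trans} / n_{dof}\). The degree-of-freedom convention is an assumption: N·dim translational degrees of freedom, reduced by dim when the center-of-mass drift is subtracted. The rotational contribution enabled by can_rotate, and the way jaxdem counts clump degrees of freedom, are not shown. temperature_sketch is a hypothetical helper.

```python
import jax.numpy as jnp


def temperature_sketch(mass, vel, k_B=1.0, subtract_drift=True):
    """Translational equipartition temperature, k_B T = 2 KE / n_dof."""
    n, dim = vel.shape
    if subtract_drift:
        # Remove the center-of-mass velocity; this also removes dim degrees of freedom
        v_com = jnp.sum(mass[:, None] * vel, axis=0) / jnp.sum(mass)
        vel = vel - v_com
        n_dof = dim * (n - 1)
    else:
        n_dof = dim * n
    ke = 0.5 * jnp.sum(mass[:, None] * vel**2)
    return 2.0 * ke / (n_dof * k_B)


mass = jnp.array([1.0, 1.0])
vel = jnp.array([[3.0, 0.0], [1.0, 0.0]])
t_raw = temperature_sketch(mass, vel, subtract_drift=False)   # KE = 5, 4 dof -> 2.5
t_drift = temperature_sketch(mass, vel, subtract_drift=True)  # drift (2, 0) removed -> 1.0
```

The example shows why drift subtraction matters for small systems. A uniform translation of the whole system raises the raw estimate even though it carries no thermal motion.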
- jaxdem.utils.compute_translational_kinetic_energy(state: State) -> jax.Array
Compute the total translational kinetic energy of the system.
\[E_{trans, total} = \sum_{i} \frac{1}{2} m_i |v_i|^2\]
- Parameters:
state (State) – The current state of the system.
- Returns:
The scalar sum of translational kinetic energy across all particles.
- Return type:
jax.Array
- jaxdem.utils.compute_translational_kinetic_energy_per_particle(state: State) -> jax.Array
Compute the translational kinetic energy per particle.
\[E_{trans} = \frac{1}{2} m |v|^2\]
Notes
The energy of clump members is divided by the number of spheres in the clump.
- Parameters:
state (State) – The current state of the system containing particle masses and velocities.
- Returns:
An array containing the translational kinetic energy for each particle.
- Return type:
jax.Array
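The clump note above can be sketched as follows. The sketch assumes, which the source does not state, that every clump member carries the clump's total mass and center-of-mass velocity, so each member's raw energy equals the whole clump's energy. Dividing by the member count then makes the per-particle values add up to each clump's energy exactly once. clump_id is assumed to hold integers in [0, N), and the helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def translational_ke_per_particle(mass, vel, clump_id):
    """0.5 * m * |v|^2 per particle, divided by the size of the particle's clump."""
    ke = 0.5 * mass * jnp.sum(vel**2, axis=-1)
    # Number of spheres sharing each clump id (ids assumed to lie in [0, N))
    counts = jax.ops.segment_sum(jnp.ones_like(ke), clump_id, num_segments=clump_id.shape[0])
    return ke / counts[clump_id]


mass = jnp.array([2.0, 2.0, 1.0])  # spheres 0 and 1 form one clump of mass 2
vel = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
clump_id = jnp.array([0, 0, 1])
ke = translational_ke_per_particle(mass, vel, clump_id)  # [0.5, 0.5, 2.0]
total = ke.sum()                                         # 3.0
```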
- jaxdem.utils.control_nvt_density(state: State, system: System, *, n: int, rescale_every: int, temperature_target: float | None = None, temperature_delta: float | None = None, packing_fraction_target: float | None = None, packing_fraction_delta: float | None = None, can_rotate: bool = True, subtract_drift: bool = True, k_B: float = 1.0, temperature_schedule: ScheduleFn | None = None, density_schedule: ScheduleFn | None = None, pf_min: float = 1e-12, init_temp_seed: int = 0, unroll: int = 2) -> tuple[State, System]
Runs a protocol for n integration steps, applying (optional) NVT rescaling and/or density rescaling whenever system.step_count is divisible by rescale_every.
Notes
rescale_every is in integration steps (System.step_count units).
Provide either target or delta for each controlled quantity (or neither to disable).
temperature_schedule / density_schedule must be JIT-static (passed as static_argnames).
- jaxdem.utils.control_nvt_density_rollout(state: State, system: System, *, n: int, stride: int = 1, rescale_every: int = 1, temperature_target: float | None = None, temperature_delta: float | None = None, packing_fraction_target: float | None = None, packing_fraction_delta: float | None = None, can_rotate: bool = True, subtract_drift: bool = True, k_B: float = 1.0, temperature_schedule: ScheduleFn | None = None, density_schedule: ScheduleFn | None = None, pf_min: float = 1e-12, init_temp_seed: int = 0, unroll: int = 2) -> tuple[State, System, tuple[State, System]]
Rollout variant (like System.trajectory_rollout) with globally consistent schedules across the whole rollout.
- jaxdem.utils.count_clump_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) -> tuple[State, System, jax.Array]
Count unique clump-level contacts per clump.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
contacts (jax.Array) – (N_clumps,) array of unique clump contact counts per clump.
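Clump-level counting can be illustrated from a sphere-level contact list. In this sketch, pair_ids (M sphere pairs, shape (M, 2)) and clump_id are assumed inputs, similar in spirit to the pair_ids returned by get_pair_forces_and_ids. Several sphere contacts between the same two clumps collapse to one clump contact, and contacts inside a clump are ignored. The dense N_clumps × N_clumps matrix keeps the sketch short but is only practical for small systems. The helper is hypothetical, not jaxdem's implementation.

```python
import jax.numpy as jnp


def count_unique_clump_contacts(pair_ids, clump_id, n_clumps):
    """Number of distinct neighbouring clumps per clump, from sphere-level contact pairs."""
    ci = clump_id[pair_ids[:, 0]]
    cj = clump_id[pair_ids[:, 1]]
    # Dense clump adjacency: repeated sphere contacts between two clumps collapse to one entry
    adj = jnp.zeros((n_clumps, n_clumps), dtype=bool)
    adj = adj.at[ci, cj].set(True).at[cj, ci].set(True)
    adj = adj & ~jnp.eye(n_clumps, dtype=bool)  # contacts inside a clump do not count
    return adj.sum(axis=1)


pair_ids = jnp.array([[0, 2], [1, 2], [2, 3], [0, 1]])
clump_id = jnp.array([0, 0, 1, 2])  # spheres 0 and 1 belong to clump 0
contacts = count_unique_clump_contacts(pair_ids, clump_id, n_clumps=3)  # [1, 2, 1]
```

Clump 0 touches clump 1 through two different spheres but is counted once here. A count that adds one per inter-clump sphere pair would give 2 for clump 0 instead.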
- jaxdem.utils.count_vertex_contacts(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) -> tuple[State, System, jax.Array]
Count vertex-level contacts per clump.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
contacts (jax.Array) – (N_clumps,) array of vertex contact counts per clump.
- jaxdem.utils.cross(a: Array, b: Array) -> Array
Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.
For 3D vectors (D=3), the result is a vector orthogonal to both ‘a’ and ‘b’. For 2D vectors (D=2), the result is the z-component of the 3D cross product obtained by appending a zero third component to each vector. This signed scalar equals the signed area of the parallelogram spanned by the vectors.
- Parameters:
a (jax.Array) – Array of shape (..., D), where D is the dimension (2 or 3).
b (jax.Array) – Array of shape (..., D), where D must match the dimension of a.
- Returns:
A JAX Array representing the cross product.
- If D=3, the shape is (…, 3).
- If D=2, the shape is (…, 1) (a scalar wrapped in an array).
- Raises:
ValueError – If the last dimension (D) is not 2 or 3, or if the last dimensions of ‘a’ and ‘b’ do not match.
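The dimension handling described above can be sketched with jnp.cross for D=3 and the explicit z-component formula for D=2. This is an illustrative re-implementation with a hypothetical name (cross_sketch), not the library source.

```python
import jax.numpy as jnp


def cross_sketch(a, b):
    """Cross product along the last axis; D=2 returns the z-component with shape (..., 1)."""
    if a.shape[-1] != b.shape[-1] or a.shape[-1] not in (2, 3):
        raise ValueError("last dimensions must match and be 2 or 3")
    if a.shape[-1] == 3:
        return jnp.cross(a, b)
    # z-component of (a_x, a_y, 0) x (b_x, b_y, 0): the signed parallelogram area
    return (a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0])[..., None]


x2, y2 = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])
area = cross_sketch(x2, y2)      # [1.]  (counter-clockwise from x to y)
area_rev = cross_sketch(y2, x2)  # [-1.]
```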
- jaxdem.utils.cross_3X3D_1X2D(w: Array, r: Array) -> Array
Computes the cross product of an angular velocity vector w and a position vector r, often used to find the tangential velocity v = w x r.
This function handles two cases based on the dimension of r:
- 3D case (r.shape[-1] == 3): w must be a 3D vector (w.shape[-1] == 3), and the standard 3D cross product w x r is computed.
- 2D case (r.shape[-1] == 2): w is treated as a scalar (the z-component of angular velocity, w_z). The computation is equivalent to (0, 0, w_z) x (r_x, r_y, 0) = (-w_z r_y, w_z r_x, 0), so the result is the 2D tangential velocity vector (v_x, v_y) = (-w_z r_y, w_z r_x) in the xy-plane.
- Parameters:
w (jax.Array) – In the 3D case, shape (..., 3). In the 2D case, shape (..., 1) or (...).
r (jax.Array) – Shape (..., 3) or (..., 2).
- Returns:
A JAX Array representing the tangential velocity (w x r).
- If r is 3D, the output shape is (…, 3).
- If r is 2D, the output shape is (…, 2).
- Raises:
ValueError – If r is not 2D or 3D, or if the dimensions are incompatible.
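Both branches can be sketched briefly. The sketch assumes that a w with the same number of dimensions as r carries a trailing axis of size 1, which is one way to accept both the (..., 1) and (...) layouts listed above. The function name is hypothetical.

```python
import jax.numpy as jnp


def tangential_velocity(w, r):
    """v = w x r; in 2D, w is the scalar z-component of the angular velocity."""
    if r.shape[-1] == 3:
        return jnp.cross(w, r)
    if r.shape[-1] == 2:
        # Accept w with shape (..., 1) or (...): drop a trailing size-1 axis if present
        wz = w[..., 0] if w.ndim == r.ndim else w
        # (0, 0, w_z) x (r_x, r_y, 0) = (-w_z * r_y, w_z * r_x, 0)
        return jnp.stack([-wz * r[..., 1], wz * r[..., 0]], axis=-1)
    raise ValueError("r must be 2D or 3D")


v2 = tangential_velocity(jnp.array([2.0]), jnp.array([1.0, 0.0]))                 # [0., 2.]
v3 = tangential_velocity(jnp.array([0.0, 0.0, 2.0]), jnp.array([1.0, 0.0, 0.0]))  # [0., 2., 0.]
```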
- jaxdem.utils.cross_lidar_2d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_bins: int, max_neighbors: int) -> tuple[jax.Array, jax.Array, jax.Array]
2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.
Computes all-pairs displacements from pos_a to pos_b, bins them by azimuthal angle, and returns the per-bin proximity and closest target IDs.
- Parameters:
pos_a (jax.Array) – Sensor positions, shape (N_A, dim).
pos_b (jax.Array) – Target positions, shape (N_B, dim).
system (System) – System configuration.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_bins (int) – Number of angular bins spanning \([-\pi, \pi)\).
max_neighbors (int) – Unused. Kept for backward compatibility.
- Returns:
(proximity, ids, overflow), where proximity and ids have shape (N_A, n_bins) and overflow is always False. Empty bins get ids = -1.
- Return type:
Tuple[jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.
Examples
>>> prox, ids, overflow = cross_lidar_2d(agents, obstacles, system,
...                                      lidar_range=5.0, n_bins=36,
...                                      max_neighbors=64)
- jaxdem.utils.cross_lidar_3d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int) -> tuple[jax.Array, jax.Array, jax.Array]
3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.
Computes all-pairs displacements from pos_a to pos_b, bins them on a spherical grid, and returns the per-bin proximity and closest target IDs.
- Parameters:
pos_a (jax.Array) – Sensor positions, shape (N_A, 3).
pos_b (jax.Array) – Target positions, shape (N_B, 3).
system (System) – System configuration.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_azimuth (int) – Number of azimuthal bins.
n_elevation (int) – Number of elevation bins.
max_neighbors (int) – Unused. Kept for backward compatibility.
- Returns:
(proximity, ids, overflow), where proximity and ids have shape (N_A, n_azimuth * n_elevation) and overflow is always False. Empty bins get ids = -1.
- Return type:
Tuple[jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.
Examples
>>> prox, ids, overflow = cross_lidar_3d(agents, obstacles, system,
...                                      lidar_range=5.0, n_azimuth=36,
...                                      n_elevation=18, max_neighbors=64)
- jaxdem.utils.decode_callable(path: str) -> Callable[[...], Any]
Import a callable from a dotted path string.
- jaxdem.utils.dot(a: Array, b: Array) -> Array
Dot product of vectors along the last axis.
a, b: (…, D) returns: (…), the dot product.
- jaxdem.utils.encode_callable(fn: Callable[[...], Any]) -> str
Return a dotted path like ‘jax._src.nn.functions.gelu’.
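A round trip between a callable and its dotted path can be sketched with importlib from the standard library. This is an assumption about the approach, not the library code. It joins __module__ and __qualname__, which produces the same dotted-path form as the example above. It only decodes top-level functions: a qualified name such as 'Class.method' would need one getattr per path component. Both helper names are hypothetical.

```python
import importlib


def encode_callable_sketch(fn):
    """Dotted path of a top-level function: '<module>.<qualified name>'."""
    return f"{fn.__module__}.{fn.__qualname__}"


def decode_callable_sketch(path):
    """Import the module part of the path and fetch the final attribute."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


sqrt = decode_callable_sketch("math.sqrt")
path = encode_callable_sketch(importlib.import_module)  # 'importlib.import_module'
```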
- jaxdem.utils.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, **kw: Any) -> tuple[Environment, jax.Array]
Advance the environment n steps using actions from model.
- Parameters:
env (Environment) – Initial environment pytree (batchable).
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.
n (int) – Number of steps to perform.
**kw (Any) – Extra keyword arguments forwarded to model.
- Returns:
Updated environment and the advanced random key.
- Return type:
Tuple[Environment, jax.Array]
Examples
>>> env, key = env_step(env, model, key, n=10, objective=goal)
- jaxdem.utils.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, **kw: Any) -> tuple[Environment, jax.Array, Environment]
Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.
- Parameters:
env (Environment) – Initial environment pytree.
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.
n (int) – Number of chunks to roll out. Total internal steps = n * stride.
stride (int) – Steps per chunk between recorded snapshots.
**kw (Any) – Extra keyword arguments passed to model on every step.
- Returns:
Final environment, advanced random key, and a stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.
- Return type:
Tuple[Environment, jax.Array, Environment]
Examples
>>> env, key, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
- jaxdem.utils.get_clump_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None) -> tuple[State, System, jax.Array, jax.Array]
Identify rattler clumps by iteratively removing under-coordinated clumps.
A clump is a rattler if its total vertex-contact count is below the coordination threshold zc.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
rattler_ids (jax.Array) – 1-D array of rattler clump IDs.
non_rattler_ids (jax.Array) – 1-D array of non-rattler clump IDs.
- jaxdem.utils.get_pair_forces_and_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None) -> tuple[State, System, jax.Array, jax.Array]
Compute pairwise contact forces and their associated particle IDs.
- Returns:
state (State) – Potentially updated state (after neighbor-list rebuild).
system (System) – Potentially updated system.
pair_ids (jax.Array) – (M, 2) array of (i, j) sphere index pairs.
forces (jax.Array) – (M, dim) array of pairwise force vectors, one per pair.
- jaxdem.utils.get_sphere_rattler_ids(state: State, system: System, cutoff: float | None = None, max_neighbors: int | None = None, zc: int | None = None) -> tuple[State, System, jax.Array, jax.Array]
Identify rattler spheres by iteratively removing under-coordinated particles.
- Returns:
state (State) – Potentially updated state.
system (System) – Potentially updated system.
rattler_ids (jax.Array) – 1-D array of rattler sphere indices.
non_rattler_ids (jax.Array) – 1-D array of non-rattler sphere indices.
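The iterative removal can be sketched with jax.lax.while_loop on a dense boolean contact matrix. Each pass counts every particle's contacts with the current non-rattlers and flags those below zc. The loop stops once a pass flags no new particle. Flagging one rattler can cascade to its neighbours, which is why a single pass is not enough. The adjacency-matrix input and the helper name are assumptions. How jaxdem builds contacts (cutoff, max_neighbors) and chooses a default zc is not shown.

```python
import jax
import jax.numpy as jnp


def rattler_mask(adj, zc):
    """adj: (N, N) symmetric boolean contact matrix with a False diagonal."""

    def body(carry):
        rattler, _ = carry
        # Contacts with particles that are not (yet) flagged as rattlers
        z = jnp.sum(adj & ~rattler[None, :], axis=1)
        new = rattler | (z < zc)
        return new, jnp.any(new != rattler)

    init = (jnp.zeros(adj.shape[0], dtype=bool), jnp.array(True))
    rattler, _ = jax.lax.while_loop(lambda carry: carry[1], body, init)
    return rattler


# Triangle 0-1-2, plus a dangling chain 0-3-4
edges = jnp.array([[0, 1], [0, 2], [1, 2], [0, 3], [3, 4]])
adj = jnp.zeros((5, 5), dtype=bool)
adj = adj.at[edges[:, 0], edges[:, 1]].set(True).at[edges[:, 1], edges[:, 0]].set(True)
mask = rattler_mask(adj, zc=2)
rattler_ids = jnp.nonzero(mask)[0]       # [3, 4]: 4 is flagged first, then 3 drops below zc
non_rattler_ids = jnp.nonzero(~mask)[0]  # [0, 1, 2]
```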
- jaxdem.utils.grid_state(*, n_per_axis: Sequence[int], spacing: ArrayLike | float, radius: float = 1.0, mass: float = 1.0, jitter: float = 0.0, vel_range: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, seed: int = 0, key: jax.Array | None = None) -> State
Create a state where particles sit on a rectangular lattice.
Random values can be sampled for particle radii, masses and velocities by specifying the *_range arguments, which are interpreted as (min, max) bounds for a uniform distribution. When a range is not provided, the corresponding radius or mass argument is used for all particles, and the velocity components are sampled in [-1, 1].
- Parameters:
n_per_axis (tuple[int]) – Number of spheres along each axis.
spacing (tuple[float] | float) – Centre-to-centre distance; a scalar is broadcast to every axis.
radius (float) – Shared radius for all particles when radius_range is not provided.
mass (float) – Shared mass for all particles when mass_range is not provided.
jitter (float) – Adds a uniform random offset in the range [-jitter, +jitter] for non-perfect grids (useful to break symmetry).
vel_range (ArrayLike | None) – (min, max) values for the velocity components.
radius_range (ArrayLike | None) – (min, max) values for the radii.
mass_range (ArrayLike | None) – (min, max) values for the masses.
seed (int) – Integer seed used when key is not supplied.
key (PRNG key, optional) – Controls randomness. If None, a key is created from seed.
- Return type:
State
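The lattice itself can be sketched with jnp.meshgrid. This sketch only builds positions. It assumes the lattice starts at the origin and that jitter is applied independently to every coordinate, and it leaves out the radius, mass and velocity sampling. lattice_positions is a hypothetical helper, not part of jaxdem.

```python
import jax
import jax.numpy as jnp


def lattice_positions(n_per_axis, spacing, jitter=0.0, key=None):
    """Positions on a rectangular lattice, shape (prod(n_per_axis), len(n_per_axis))."""
    dim = len(n_per_axis)
    # A scalar spacing is broadcast to every axis
    spacing = jnp.broadcast_to(jnp.asarray(spacing, dtype=jnp.float32), (dim,))
    axes = [jnp.arange(n) for n in n_per_axis]
    idx = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, dim)
    pos = idx * spacing
    if jitter > 0.0:
        key = jax.random.PRNGKey(0) if key is None else key
        pos = pos + jax.random.uniform(key, pos.shape, minval=-jitter, maxval=jitter)
    return pos


grid = lattice_positions((2, 3), spacing=2.0)  # 6 points; the last one is (2, 4)
jittered = lattice_positions((2, 3), spacing=2.0, jitter=0.1, key=jax.random.PRNGKey(1))
```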
- jaxdem.utils.lidar_2d(state: State, system: System, lidar_range: float, n_bins: int, max_neighbors: int, sense_edges: bool = False) -> tuple[State, System, jax.Array, jax.Array, jax.Array]
2-D LIDAR proximity readings and neighbor IDs.
For every particle in state, the displacement vectors to all other particles are projected onto the \(xy\)-plane and binned by azimuthal angle into n_bins uniform sectors spanning \([-\pi, \pi)\). Each bin stores the proximity value and the index of the closest neighbor in that sector:
\[p_k = \max(0,\; r_{\max} - d_{\min,k})\]
where \(r_{\max}\) is lidar_range and \(d_{\min,k}\) is the distance to the closest neighbor in bin \(k\).
This works identically for 2-D and 3-D position data. In the 3-D case the \(z\)-component of the displacement is ignored during binning, while the full Euclidean distance is used for proximity.
- Parameters:
state (State) – Simulation state (positions, radii, etc.).
system (System) – System configuration including domain.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_bins (int) – Number of angular bins (rays) spanning \([-\pi, \pi)\).
max_neighbors (int) – Unused. Kept for backward compatibility.
sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Only meaningful for bounded domains. Default is False.
- Returns:
(state, system, proximity, ids, overflow), where state and system are unchanged, proximity and ids have shape (N, n_bins), and overflow is always False. Bins with no detection have ids set to the particle’s own index.
- Return type:
tuple[State, System, jax.Array, jax.Array, jax.Array]
Notes
This function computes all-pairs displacements directly from state.pos and does not invoke the collider. The returned ids are indices into state.pos in whatever order it has at call time, so results are correct regardless of whether a cell-list collider has reordered the state.
Examples
>>> state, system, prox, ids, overflow = lidar_2d(state, system,
...     lidar_range=5.0, n_bins=36, max_neighbors=64)
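The binning rule above can be sketched with dense all-pairs arrays, mirroring the all-pairs approach described in the Notes. The sketch makes three assumptions. Displacements ignore periodic images and walls, so sense_edges is not modelled. A target at distance exactly lidar_range still counts. An angle of exactly \(\pi\) wraps into the first bin so that the bins cover \([-\pi, \pi)\). Memory grows as N² · n_bins, so this is for illustration only. lidar_2d_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def lidar_2d_sketch(pos, lidar_range, n_bins):
    """All-pairs 2-D LIDAR: per-bin proximity max(0, r_max - d_min) and closest-neighbor ids."""
    n = pos.shape[0]
    disp = pos[None, :, :] - pos[:, None, :]         # disp[i, j] = pos[j] - pos[i]
    dist = jnp.linalg.norm(disp, axis=-1)            # full Euclidean distance (2-D or 3-D)
    angle = jnp.arctan2(disp[..., 1], disp[..., 0])  # azimuth from the xy-components only
    bins = jnp.floor((angle + jnp.pi) / (2 * jnp.pi) * n_bins).astype(jnp.int32) % n_bins
    # Exclude self-pairs and targets beyond the detection range
    valid = ~jnp.eye(n, dtype=bool) & (dist <= lidar_range)
    dist = jnp.where(valid, dist, jnp.inf)
    # masked[i, j, k]: distance from i to j if j falls in bin k of sensor i, else inf
    in_bin = bins[..., None] == jnp.arange(n_bins)
    masked = jnp.where(in_bin, dist[..., None], jnp.inf)
    d_min = masked.min(axis=1)    # (n, n_bins)
    closest = masked.argmin(axis=1)  # (n, n_bins)
    proximity = jnp.maximum(0.0, lidar_range - d_min)
    # Empty bins report the sensor's own index
    ids = jnp.where(jnp.isfinite(d_min), closest, jnp.arange(n)[:, None])
    return proximity, ids


pos = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]])
prox, ids = lidar_2d_sketch(pos, lidar_range=5.0, n_bins=4)
# Sensor 0 sees particle 1 in bin 2 (azimuth pi/4) and particle 2 in bin 3 (azimuth 3*pi/4)
```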
- jaxdem.utils.lidar_3d(state: State, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int, sense_edges: bool = False) -> tuple[State, System, jax.Array, jax.Array, jax.Array]
3-D LIDAR proximity readings and neighbor IDs.
Similar to lidar_2d(), but bins neighbors on a spherical grid defined by n_azimuth azimuthal sectors in \([-\pi, \pi)\) and n_elevation elevation bands in \([-\pi/2, \pi/2]\). The returned proximity and ID arrays have shape (N, n_azimuth * n_elevation) with flat indexing az * n_elevation + el.
- Parameters:
state (State) – Simulation state.
system (System) – System configuration including domain.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_azimuth (int) – Number of azimuthal bins.
n_elevation (int) – Number of elevation bins.
max_neighbors (int) – Unused. Kept for backward compatibility.
sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Default is False.
- Returns:
(state, system, proximity, ids, overflow), where state and system are unchanged, proximity and ids have shape (N, n_azimuth * n_elevation), and overflow is always False.
- Return type:
tuple[State, System, jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. Returned ids index into state.pos in its current order, so results are correct regardless of collider-induced reordering.
Examples
>>> state, system, prox, ids, overflow = lidar_3d(state, system,
...     lidar_range=5.0, n_azimuth=36, n_elevation=18, max_neighbors=64)
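The flat index az * n_elevation + el can be sketched for a batch of displacement vectors. Two choices here are assumptions: the binning is uniform in both angles, and endpoints are handled by wrapping an azimuth of \(\pi\) into the first sector and clipping an elevation of \(+\pi/2\) into the top band. spherical_bin_index is a hypothetical helper.

```python
import jax.numpy as jnp


def spherical_bin_index(disp, n_azimuth, n_elevation):
    """Flat bin index az * n_elevation + el for displacement vectors of shape (..., 3)."""
    x, y, z = disp[..., 0], disp[..., 1], disp[..., 2]
    azimuth = jnp.arctan2(y, x)                  # [-pi, pi]
    elevation = jnp.arctan2(z, jnp.hypot(x, y))  # [-pi/2, pi/2]
    az = jnp.floor((azimuth + jnp.pi) / (2 * jnp.pi) * n_azimuth).astype(jnp.int32) % n_azimuth
    el = jnp.floor((elevation + jnp.pi / 2) / jnp.pi * n_elevation).astype(jnp.int32)
    el = jnp.clip(el, 0, n_elevation - 1)        # keep elevation == +pi/2 in the top band
    return az * n_elevation + el


idx = spherical_bin_index(jnp.array([[1.0, 1.0, 1.0], [-1.0, 1.0, -1.0]]), n_azimuth=4, n_elevation=2)
# first vector: az 2, el 1 -> 5; second vector: az 3, el 0 -> 6
```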
- jaxdem.utils.load_legacy_dp(path: str, ref_pos: jax.Array | None = None, dim: int = 3) -> DeformableParticleModel
Load an old DeformableParticleContainer h5 file and return a new DeformableParticleModel.
- Parameters:
path (str) – Path to the .h5 file containing the saved DP container.
ref_pos (jax.Array, optional) – Reference vertex positions, shape (N, dim). Required for 3D to compute the new w_b bending normalization. You can obtain this from the legacy state: ref_pos = state.pos.
dim (int) – Spatial dimension (2 or 3). Needed to choose the correct w_b computation.
- Returns:
A new model instance with fields mapped from the old container.
- Return type:
DeformableParticleModel
- jaxdem.utils.load_legacy_simulation(state_path: str, system_path: str, dp_path: str | None = None) -> tuple[State, System]
Load state, system, and (optionally) a deformable-particle container from old-format h5 files and wire them into a ready-to-use (State, System) pair.
When dp_path is given, the DP model is attached to the system via system.bonded_force_model, and its force/energy functions are registered in the force manager.
- Parameters:
state_path (str) – Path to the legacy State h5 file.
system_path (str) – Path to the legacy System h5 file.
dp_path (str, optional) – Path to the legacy DeformableParticleContainer h5 file.
- Returns:
state (State) – The loaded state with current field names.
system (System) – The loaded system, with bonded forces wired up if dp_path was given.
Example
    from jaxdem.utils.load_legacy import load_legacy_simulation

    state, system = load_legacy_simulation(
        "old_data/state.h5",
        "old_data/system.h5",
        dp_path="old_data/dp.h5",
    )
- jaxdem.utils.load_legacy_state(path: str) -> State
Load a State saved with the old field naming convention (angVel, clump_ID, deformable_ID, unique_ID).
- Parameters:
path (str) – Path to the .h5 file containing the saved State.
- Returns:
A new State constructed with the current field names.
- Return type:
State
- jaxdem.utils.load_legacy_system(path: str, state_shape: tuple[int, ...] | None = None) -> System
Load a System saved with the old schema (no bonded_force_model or interact_same_bond_id fields).
The current System.create factory is used to produce a valid skeleton. Scalar fields (dt, time, step_count, key) and nested component dataclasses that still exist (collider, domain, force_model, mat_table, force_manager, integrators) are overwritten from the file where the schemas still match.
- Parameters:
path (str) – Path to the .h5 file containing the saved System.
state_shape (tuple of int, optional) – Shape hint (N, dim) passed to System.create to build default components. If None, it is inferred from the stored force_manager/external_force or force_manager/external_force_com dataset.
- Returns:
A new System instance populated with as much data from the file as possible. bonded_force_model defaults to None and interact_same_bond_id defaults to False.
- Return type:
System
- jaxdem.utils.make_save_steps_linear(*, num_steps: int, save_freq: int, include_step0: bool = True) -> ndarray
- jaxdem.utils.make_save_steps_pseudolog(*, num_steps: int, reset_save_decade: int, min_save_decade: int, decade: int = 10, include_step0: bool = True, cap: int | None = None) -> ndarray
Pseudo-log schedule compatible with the BaseLogGroup logic.
Parameters are interpreted on the integer timestep grid 0..num_steps (inclusive).
- jaxdem.utils.norm(v: Array) -> Array
Norm of vectors along the last axis.
v: (…, D) returns: (…), the norm.
- jaxdem.utils.norm2(v: Array) -> Array
Squared norm of vectors along the last axis.
v: (…, D) returns: (…), the squared norm.
- jaxdem.utils.random_state(*, N: int, dim: int, box_size: ArrayLike | None = None, box_anchor: ArrayLike | None = None, radius_range: ArrayLike | None = None, mass_range: ArrayLike | None = None, vel_range: ArrayLike | None = None, seed: int = 0) -> State
Generate N particles placed uniformly at random in an axis-aligned box, without checking for overlaps.
- Parameters:
N – Number of particles.
dim – Spatial dimension (2 or 3).
box_size – Edge lengths of the domain.
box_anchor – Coordinate of the lower box corner.
radius_range – Minimum and maximum values the radius can take.
mass_range – Minimum and maximum values the mass can take.
vel_range – Minimum and maximum values the velocity components can take.
seed – Integer seed for reproducibility.
- Returns:
A fully initialised State instance.
- Return type:
State
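Placing particles uniformly in the box amounts to anchor + U(0, 1) · box_size per coordinate. Below is a minimal sketch of the position part only, with hypothetical names. As documented above, no overlap check is made.

```python
import jax
import jax.numpy as jnp


def random_positions(key, n, box_size, box_anchor):
    """n points drawn uniformly in [box_anchor, box_anchor + box_size) per axis."""
    box_size = jnp.asarray(box_size, dtype=jnp.float32)
    box_anchor = jnp.asarray(box_anchor, dtype=jnp.float32)
    u = jax.random.uniform(key, (n, box_size.shape[-1]))  # U[0, 1) samples
    return box_anchor + u * box_size                       # no overlap check


pos = random_positions(jax.random.PRNGKey(0), 100, box_size=[4.0, 2.0], box_anchor=[-1.0, 0.0])
```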
- jaxdem.utils.randomize_orientations(state: State, key: jax.Array) -> State
Randomize orientations for clumps (particles with repeated state.clump_id), leaving spheres unchanged.
- jaxdem.utils.remove_rattlers_from_state(state: State, rattler_clump_ids: jax.Array) -> State
Remove all spheres belonging to rattler clumps and rebuild the state.
- jaxdem.utils.scale_to_packing_fraction(state: State, system: System, new_packing_fraction: float) -> tuple[State, System]
- jaxdem.utils.scale_to_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0) -> State
Scale the velocities of a state to a desired temperature.
- Parameters:
state (State) – Current simulation state.
target_temperature (float) – Desired target temperature.
can_rotate (bool) – Whether to include rigid body rotations.
subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).
k_B (float, optional) – Boltzmann constant (default is 1.0).
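Kinetic energy is quadratic in velocity, so reaching a target temperature amounts to multiplying the velocities by \(\sqrt{T_{target} / T_{current}}\). The sketch below shows this for translation only and always removes the drift. It uses an assumed degree-of-freedom count of dim · (N − 1). Rotational scaling (can_rotate) and jaxdem's exact conventions are not shown, and the helper is hypothetical.

```python
import jax.numpy as jnp


def scale_to_temperature_sketch(mass, vel, target_temperature, k_B=1.0):
    """Remove drift, then rescale so the translational temperature equals the target."""
    n, dim = vel.shape
    v_com = jnp.sum(mass[:, None] * vel, axis=0) / jnp.sum(mass)
    vel = vel - v_com
    ke = 0.5 * jnp.sum(mass[:, None] * vel**2)
    current = 2.0 * ke / (dim * (n - 1) * k_B)  # assumed dof count after drift removal
    # KE scales with |v|^2, so a factor sqrt(T_target / T) on v gives T_target
    return vel * jnp.sqrt(target_temperature / current)


mass = jnp.array([1.0, 1.0])
vel = jnp.array([[3.0, 0.0], [1.0, 0.0]])
new_vel = scale_to_temperature_sketch(mass, vel, target_temperature=4.0)  # [[2, 0], [-2, 0]]
```

Unlike set_temperature, which randomizes velocities, rescaling keeps the direction of every velocity and only changes its magnitude.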
- jaxdem.utils.set_temperature(state: State, target_temperature: float, can_rotate: bool, subtract_drift: bool, seed: int | None = 0, k_B: float = 1.0) -> State
Randomize the velocities of a state according to a desired temperature.
- Parameters:
state (State) – Current simulation state.
target_temperature (float) – Desired target temperature.
can_rotate (bool) – Whether to include rigid body rotations.
subtract_drift (bool) – Whether to remove center-of-mass drift (usually only relevant for small systems).
seed (int, optional) – RNG seed.
k_B (float, optional) – Boltzmann constant (default is 1.0).
- jaxdem.utils.signed_angle(v1: Array, v2: Array) -> Array
Directional angle from v1 -> v2 around normal \(\hat{z}\) (right-hand rule), in \([-\pi, \pi)\).
- jaxdem.utils.signed_angle_x(v1: Array) -> Array
Directional angle from v1 -> \(\hat{x}\) around normal \(\hat{z}\), in \((-\pi, \pi]\).
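Both angles follow from jnp.arctan2 applied to the z-component of the 2D cross product and the dot product. This is a minimal sketch for 2D inputs with hypothetical names. Note that jnp.arctan2 returns values in \([-\pi, \pi]\), so for antiparallel vectors the endpoint returned may differ from the intervals documented above.

```python
import jax.numpy as jnp


def signed_angle_sketch(v1, v2):
    """Angle from v1 to v2 about +z (counter-clockwise positive), 2D vectors along the last axis."""
    cross_z = v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]
    dot = jnp.sum(v1 * v2, axis=-1)
    return jnp.arctan2(cross_z, dot)


def signed_angle_x_sketch(v1):
    """Angle from v1 to the x-axis: arctan2(-v1_y, v1_x)."""
    return jnp.arctan2(-v1[..., 1], v1[..., 0])


x, y = jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])
a_xy = signed_angle_sketch(x, y)        # pi / 2
a_yx = signed_angle_sketch(y, x)        # -pi / 2
a_y_to_x = signed_angle_x_sketch(y)     # -pi / 2
```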
- jaxdem.utils.unit(v: Array) -> Array
Normalize vectors along the last axis.
v: (…, D) returns: (…, D), unit vectors; zeros map to zeros.
- jaxdem.utils.unit_and_norm(v: Array) -> tuple[Array, Array]
Normalize vectors along the last axis and return the norm.
v: (…, D) returns: ((…, D), (…, 1)), unit vectors and their norms; zeros map to zeros.
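Zero-safe normalization can be sketched by dividing by a substitute denominator wherever the norm vanishes. This only keeps the forward pass free of 0/0. Making gradients well defined at exactly zero needs the same substitution inside the norm itself, which this sketch (with a hypothetical name) does not attempt.

```python
import jax.numpy as jnp


def unit_and_norm_sketch(v):
    """Unit vectors and norms along the last axis; zero vectors map to zero."""
    norm = jnp.linalg.norm(v, axis=-1, keepdims=True)  # shape (..., 1)
    safe = jnp.where(norm > 0, norm, 1.0)              # avoid 0 / 0 in the forward pass
    return jnp.where(norm > 0, v / safe, 0.0), norm


u, n = unit_and_norm_sketch(jnp.array([[3.0, 4.0], [0.0, 0.0]]))
# u = [[0.6, 0.8], [0.0, 0.0]], n = [[5.0], [0.0]]
```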
Modules
- Utility functions to compute angles between vectors.
- Utility functions for analyzing particle contacts and identifying rattlers.
- Utility functions to assign radius dispersity.
- JIT-compiled routines for controlling temperature and density via basic rescaling.
- Utility functions to handle environments and the LIDAR sensor.
- Utility functions for creating Geometric Asperity particle states in 2D and 3D.
- Utility functions to initialize states with particles arranged in a grid.
- HDF5 save/load utilities (v2).
- Jamming routines.
- Utility functions to help with linear algebra.
- Adapter for loading HDF5 data saved with the pre-merge-2-27-26 branch.
- Utility functions for calculating and changing the packing fraction.
- Quaternion math utilities.
- Generates random, energy-minimized configurations of spheres in 2D or 3D.
- Utility functions to randomly initialize states.
- Utility functions to randomize particle orientations.
- Utilities to generate step indices for trajectory logging.
- Utility functions to compute thermodynamic quantities.