jaxdem.colliders.sweep_and_prune
Sweep and prune \(O(N \log N)\) collider implementation.
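To show where the \(O(N \log N)\) bound comes from, here is a minimal one-axis sweep and prune written in plain JAX. It is not the Pallas kernel this module uses, and the function name `sweep_and_prune_pairs` is hypothetical. Sorting the lower bounds costs \(O(N \log N)\). The sweep then stops scanning as soon as a lower bound exceeds the current upper bound, so its extra cost scales with the number of reported overlaps.

```python
import jax.numpy as jnp


def sweep_and_prune_pairs(m, M):
    """Return candidate pairs (i, j), i < j, whose 1D intervals [m, M] overlap.

    Hypothetical helper for illustration, not part of jaxdem.
    """
    order = jnp.argsort(m)      # O(N log N): permutation sorting particles by lower bound
    m_sorted = m[order]
    M_sorted = M[order]
    n = m.shape[0]
    pairs = []
    for a in range(n):
        b = a + 1
        # Given m_sorted[a] <= m_sorted[b], the intervals overlap iff m_sorted[b] <= M_sorted[a].
        # Once that fails, every later interval starts even further right, so stop early.
        while b < n and m_sorted[b] <= M_sorted[a]:
            i, j = int(order[a]), int(order[b])
            pairs.append((min(i, j), max(i, j)))
            b += 1
    return sorted(pairs)
```

Run this outside `jit`. The Python loop converts array comparisons to booleans, which only works on concrete arrays.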
Functions
| Function | Description |
| --- | --- |
| compute_hash | Computes cell hashes for particles based on their perpendicular projection. |
| compute_virtual_shift | Applies a virtual shift to the particle bounds along the sweep axis based on cell hashes. |
| force | Compute linear spring-like interaction force acting on particle \(i\) due to particle \(j\). |
| pad_state | Pads the state to a power of two to accommodate Pallas kernel requirements. |
| pad_to_power2 | Pad odd-dimensional vectors to an even size (Pallas kernel limitation). |
| sap_kernel_full | Pallas kernel for the Sweep and Prune algorithm. |
| sort | Sorts the state and particle bounds by the lower bound m. |
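Classes
| Class | Description |
| --- | --- |
| SweepAndPrune | Sweep and prune collider; subclass of Collider. |
| SweeAPrune | Alias of SweepAndPrune. |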
- jaxdem.colliders.sweep_and_prune.pad_to_power2(x: Array) → Array
Pad odd-dimensional vectors to an even size (Pallas kernel limitation).
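A minimal sketch of the padding idea, assuming the function zero-pads the trailing (spatial) axis. The helper name `pad_last_axis_to_power2`, the zero fill, and the minimum size of 2 are assumptions, not the library's code. With that minimum, 1D and 3D vectors grow to 2 and 4 components, which are both even and powers of two, while 2D vectors are left unchanged.

```python
import jax.numpy as jnp


def pad_last_axis_to_power2(x):
    """Zero-pad the trailing axis of x to a power of two (hypothetical helper)."""
    d = x.shape[-1]
    # Smallest power of two that is >= d and at least 2, so odd sizes become even.
    target = 2
    while target < d:
        target *= 2
    # Pad only the trailing axis; all leading axes keep their size.
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, target - d)]
    return jnp.pad(x, pad_width)
```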
- jaxdem.colliders.sweep_and_prune.sap_kernel_full(state_ref: Any, system_ref: Any, aabb_ref: Array, m_ref: Array, M_ref: Array, HASH_ref: Array, forces_ref: Array) → None
Pallas kernel for the Sweep and Prune algorithm.
- Parameters:
state_ref (Any) – Reference to the simulation state.
system_ref (Any) – Reference to the simulation system.
aabb_ref (jax.Array) – Axis-aligned bounding box half-extents.
m_ref (jax.Array) – Lower bounds of the bounding boxes along the sweep axis.
M_ref (jax.Array) – Upper bounds of the bounding boxes along the sweep axis.
HASH_ref (jax.Array) – Cell hashes for the particles.
forces_ref (jax.Array) – Output array for the accumulated forces.
- jaxdem.colliders.sweep_and_prune.compute_hash(state: Any, proj_perp: Array, aabb: Array, shift: Array) → Array
Computes cell hashes for particles based on their perpendicular projection.
- Parameters:
state (Any) – The current simulation state.
proj_perp (jax.Array) – Projections of particle positions onto the plane perpendicular to the sweep axis.
aabb (jax.Array) – Axis-aligned bounding box half-extents.
shift (jax.Array) – Virtual shift to apply to the grid.
- Returns:
Computed cell hashes for each particle.
- Return type:
jax.Array
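The sketch below shows one common way to hash particles into grid cells on the plane perpendicular to the sweep axis. It assumes a 3D system (two perpendicular coordinates), a user-chosen `cell_size`, periodic wrapping with `n_cells` cells per direction, and row-major flattening. The name `cell_hash` and every parameter other than `proj_perp` and `shift` are hypothetical; how `compute_hash` actually derives its cell size from `aabb` is not documented here.

```python
import jax.numpy as jnp


def cell_hash(proj_perp, cell_size, shift, n_cells):
    """Hypothetical cell hash on the plane perpendicular to the sweep axis.

    proj_perp: (N, 2) perpendicular coordinates.
    shift: (2,) offset applied to the grid before binning.
    n_cells: cells per perpendicular direction (periodic wrap assumed).
    """
    # Integer cell coordinates of each particle in the shifted grid.
    cell = jnp.floor((proj_perp + shift) / cell_size).astype(jnp.int32)
    cell = jnp.mod(cell, n_cells)              # wrap indices into [0, n_cells)
    return cell[:, 0] * n_cells + cell[:, 1]   # row-major flattening to one integer
```

Shifting the grid moves cell boundaries. In this example, the first two particles share a cell without a shift and land in different cells with `shift = (0.5, 0)`.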
- jaxdem.colliders.sweep_and_prune.compute_virtual_shift(m: Array, M: Array, HASH: Array) → tuple[Array, Array]
Applies a virtual shift to the particle bounds along the sweep axis based on cell hashes.
- Parameters:
m (jax.Array) – Lower bounds along the sweep axis.
M (jax.Array) – Upper bounds along the sweep axis.
HASH (jax.Array) – Cell hashes.
- Returns:
Shifted (m, M) bounds.
- Return type:
Tuple[jax.Array, jax.Array]
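One plausible reading of the virtual shift is sketched below. This is an assumption, not the library's implementation. Each interval is offset along the sweep axis by `hash * span`, where `span` exceeds the full extent of all intervals. Particles from different cells then occupy disjoint ranges, so a single sort and sweep only reports pairs that share a cell. `virtual_shift` is a hypothetical name.

```python
import jax.numpy as jnp


def virtual_shift(m, M, hashes):
    """Hypothetical: offset each interval by hash * span so different cells never overlap."""
    # span is strictly larger than max(M) - min(m), so for hashes h1 < h2 every
    # shifted lower bound in cell h2 lies beyond every shifted upper bound in cell h1.
    span = (jnp.max(M) - jnp.min(m)) + 1.0
    offset = hashes.astype(m.dtype) * span
    return m + offset, M + offset
```

In float32, large values of `hash * span` can lose the precision needed to resolve small overlaps, so a real variant should check that trade-off.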
- jaxdem.colliders.sweep_and_prune.sort(state: State, iota: jax.Array, m: jax.Array, M: jax.Array) → tuple[State, jax.Array, jax.Array, jax.Array]
Sorts the state and particle bounds by the lower bound m.
- Parameters:
state (State) – The current simulation state.
iota (jax.Array) – Indices [0, 1, …, N-1].
m (jax.Array) – Lower bounds along the sweep axis.
M (jax.Array) – Upper bounds along the sweep axis.
- Returns:
A tuple containing:
  - Sorted State.
  - Sorted lower bounds.
  - Sorted upper bounds.
  - Sorting permutation.
- Return type:
Tuple[State, jax.Array, jax.Array, jax.Array]
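A sketch of sorting a particle state by lower bound. It assumes the state is a pytree whose leaves all carry the particle index on axis 0, and a plain `dict` stands in for `State` here. The sketch calls `jax.lax.sort` with `num_keys=1`, so `m` is the sort key while `M` and `iota` are carried along. The permuted `iota` is then the sorting permutation. `sort_by_lower_bound` is a hypothetical name.

```python
import jax


def sort_by_lower_bound(state, iota, m, M):
    """Hypothetical: sort a pytree state and its bounds by the lower bound m."""
    # m is the only key; M and iota are permuted alongside it.
    m_sorted, M_sorted, perm = jax.lax.sort((m, M, iota), num_keys=1)
    # Apply the same permutation to every per-particle leaf of the state.
    sorted_state = jax.tree_util.tree_map(lambda leaf: leaf[perm], state)
    return sorted_state, m_sorted, M_sorted, perm
```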
- jaxdem.colliders.sweep_and_prune.pad_state(state: State) → State
Pads the state to a power of two to accommodate Pallas kernel requirements.
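A minimal sketch of padding the particle axis of a state to the next power of two. Like the sorting sketch, it assumes a pytree state with particles on axis 0. The zero fill and the returned `valid` mask are assumptions. Zero-filled particles are not automatically inert, so a real implementation needs some way, such as a mask, to keep them out of the force computation. `pad_particles_to_power2` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def pad_particles_to_power2(state):
    """Hypothetical: zero-pad the particle axis of every leaf to the next power of two.

    Returns the padded pytree and a boolean mask marking the real particles.
    """
    n = jax.tree_util.tree_leaves(state)[0].shape[0]
    target = 1
    while target < n:
        target *= 2

    def pad_leaf(leaf):
        # Pad only axis 0 (the particle axis); keep all other axes unchanged.
        pad_width = [(0, target - n)] + [(0, 0)] * (leaf.ndim - 1)
        return jnp.pad(leaf, pad_width)

    valid = jnp.arange(target) < n
    return jax.tree_util.tree_map(pad_leaf, state), valid
```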
- jaxdem.colliders.sweep_and_prune.force(i: int, j: int, state: State, system: System) → tuple[jax.Array, jax.Array]
Compute linear spring-like interaction force acting on particle \(i\) due to particle \(j\).
Returns zero when \(i = j\).
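A sketch of a linear (Hookean) spring contact between two spheres. With overlap \(\delta = R_i + R_j - \lVert \mathbf{r}_i - \mathbf{r}_j \rVert\), the force on \(i\) is \(k\,\delta\,\hat{\mathbf{n}}_{ij}\) when \(\delta > 0\), where \(\hat{\mathbf{n}}_{ij}\) is the unit vector pointing from \(j\) to \(i\). The stiffness `k`, the array arguments, and the name `linear_spring_force` are assumptions. The documented `force` takes `state` and `system` and returns a pair of arrays, whereas this sketch returns only the force vector.

```python
import jax.numpy as jnp


def linear_spring_force(i, j, pos, rad, k):
    """Hypothetical linear spring contact force on particle i due to particle j.

    pos: (N, dim) positions; rad: (N,) radii; k: scalar spring stiffness (assumed).
    """
    r_ij = pos[i] - pos[j]                      # vector pointing from j to i
    dist = jnp.linalg.norm(r_ij)
    safe_dist = jnp.where(dist > 0, dist, 1.0)  # avoid 0/0 when i == j
    overlap = rad[i] + rad[j] - dist
    # Repulsive force along the unit vector from j to i, proportional to the overlap.
    f = k * overlap * r_ij / safe_dist
    # Zero for self-interaction and for non-touching pairs.
    active = jnp.logical_and(i != j, overlap > 0)
    return jnp.where(active, f, jnp.zeros_like(f))
```

Because the force depends only on the pair geometry, swapping `i` and `j` flips its sign, which is consistent with Newton's third law.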
- class jaxdem.colliders.sweep_and_prune.SweepAndPrune
Bases: Collider
- jaxdem.colliders.sweep_and_prune.SweeAPrune
alias of SweepAndPrune