jaxdem.colliders.sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementation.

Functions

compute_hash(state, proj_perp, aabb, shift)

Computes cell hashes for particles based on their perpendicular projection.

compute_virtual_shift(m, M, HASH)

Applies a virtual shift to the particle bounds along the sweep axis based on cell hashes.

force(i, j, state, system)

Compute the linear spring-like interaction force acting on particle \(i\) due to particle \(j\).

pad_state(state)

Pads the state to a power of two to accommodate Pallas kernel requirements.

pad_to_power2(x)

Pad odd-dimensional vectors to an even size (Pallas kernel limitation).

sap_kernel_full(state_ref, system_ref, ...)

Pallas kernel for the Sweep and Prune algorithm.

sort(state, iota, m, M)

Sorts the state and particle bounds by the lower bound m.

Classes

SweepAndPrune

Sweep and prune collider (subclass of Collider).

jaxdem.colliders.sweep_and_prune.pad_to_power2(x: Array) → Array

Pad odd-dimensional vectors to an even size (Pallas kernel limitation).
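The entry above says odd-dimensional vectors are padded to an even size. A minimal sketch of that idea, assuming zero padding on the trailing (spatial) axis; `pad_to_even_dim` is a hypothetical name, not the library function:

```python
import jax.numpy as jnp


def pad_to_even_dim(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: zero-pad the last axis of x to an even length."""
    extra = x.shape[-1] % 2  # 1 for odd dimensions (e.g. 3D), 0 otherwise
    # Pad only the trailing axis; every leading axis gets (0, 0).
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, extra)]
    return jnp.pad(x, pad_width)  # default mode="constant" fills with zeros


pos3d = jnp.ones((5, 3))
print(pad_to_even_dim(pos3d).shape)  # (5, 4)
```

A 2D array passes through unchanged, so the same call can be used regardless of dimensionality.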

jaxdem.colliders.sweep_and_prune.sap_kernel_full(state_ref: Any, system_ref: Any, aabb_ref: Array, m_ref: Array, M_ref: Array, HASH_ref: Array, forces_ref: Array) → None

Pallas kernel for the Sweep and Prune algorithm.

Parameters:
  • state_ref (Any) – Reference to the simulation state.

  • system_ref (Any) – Reference to the simulation system.

  • aabb_ref (jax.Array) – Axis-aligned bounding box half-extents.

  • m_ref (jax.Array) – Lower bounds of the bounding boxes along the sweep axis.

  • M_ref (jax.Array) – Upper bounds of the bounding boxes along the sweep axis.

  • HASH_ref (jax.Array) – Cell hashes for the particles.

  • forces_ref (jax.Array) – Output array for the accumulated forces.
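The kernel works on lower bounds m and upper bounds M along the sweep axis. The plain one-dimensional sweep underneath sweep and prune can be sketched in NumPy as follows. This is not the Pallas kernel and ignores the hashes; `sweep_candidate_pairs` is a hypothetical name. Once intervals are sorted by m, the forward scan from each interval can stop at the first lower bound that lies beyond its upper bound.

```python
import numpy as np


def sweep_candidate_pairs(m: np.ndarray, M: np.ndarray) -> list[tuple[int, int]]:
    """Return index pairs (i, j), i < j, whose 1D intervals [m, M] overlap."""
    order = np.argsort(m, kind="stable")  # sort intervals by lower bound
    m_s, M_s = m[order], M[order]
    n = len(m_s)
    pairs = []
    for a in range(n):
        b = a + 1
        # Lower bounds only grow from here on, so the first b with
        # m_s[b] > M_s[a] ends the scan for interval a.
        while b < n and m_s[b] <= M_s[a]:
            i, j = int(order[a]), int(order[b])
            pairs.append((min(i, j), max(i, j)))
            b += 1
    return sorted(pairs)


m = np.array([0.0, 2.5, 0.5, 4.0])
M = np.array([1.0, 3.5, 2.6, 5.0])
print(sweep_candidate_pairs(m, M))  # [(0, 2), (1, 2)]
```

Sorting costs \(O(N \log N)\). The scan costs \(O(N + K)\) for \(K\) overlapping pairs, which is where the overall complexity of the method comes from when overlaps are sparse.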

jaxdem.colliders.sweep_and_prune.compute_hash(state: Any, proj_perp: Array, aabb: Array, shift: Array) → Array

Computes cell hashes for particles based on their perpendicular projection.

Parameters:
  • state (Any) – The current simulation state.

  • proj_perp (jax.Array) – Projections of particle positions onto the plane perpendicular to the sweep axis.

  • aabb (jax.Array) – Axis-aligned bounding box half-extents.

  • shift (jax.Array) – Virtual shift to apply to the grid.

Returns:

Computed cell hashes for each particle.

Return type:

jax.Array
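The exact hashing scheme is not documented here. The sketch below makes several assumptions: a uniform grid over the perpendicular coordinates, cell indices obtained by flooring the shifted coordinates, and a row-major linearization into one integer. The cell size is passed in directly rather than derived from `aabb`, and `cell_hash` and `cells_per_axis` are hypothetical names.

```python
import jax.numpy as jnp


def cell_hash(proj_perp: jnp.ndarray, cell_size: float, shift: jnp.ndarray,
              cells_per_axis: int) -> jnp.ndarray:
    """Hypothetical grid hash of coordinates perpendicular to the sweep axis.

    proj_perp: (N, D-1) perpendicular coordinates of each particle.
    shift:     (D-1,) offset applied to the grid before binning.
    Returns an (N,) integer array, equal for particles in the same cell.
    """
    cell = jnp.floor((proj_perp + shift) / cell_size).astype(jnp.int32)
    # Wrap cell indices into [0, cells_per_axis) so the linear index is bounded.
    cell = jnp.mod(cell, cells_per_axis)
    # Row-major strides, e.g. [K, 1] for two perpendicular axes.
    d = cell.shape[1]
    strides = jnp.array([cells_per_axis ** p for p in range(d - 1, -1, -1)],
                        dtype=jnp.int32)
    return (cell * strides).sum(axis=1)
```

For example, with `cell_size=1.0`, zero shift and `cells_per_axis=4`, the points (0.1, 0.2) and (0.4, 0.3) share hash 0, while (1.2, 0.1) lands in cell (1, 0) and gets hash 4.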

jaxdem.colliders.sweep_and_prune.compute_virtual_shift(m: Array, M: Array, HASH: Array) → tuple[Array, Array]

Applies a virtual shift to the particle bounds along the sweep axis based on cell hashes.

Parameters:
  • m (jax.Array) – Lower bounds along the sweep axis.

  • M (jax.Array) – Upper bounds along the sweep axis.

  • HASH (jax.Array) – Cell hashes.

Returns:

Shifted (m, M) bounds.

Return type:

Tuple[jax.Array, jax.Array]
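One common way to realize such a shift, offered here as an assumption rather than as this module's actual scheme, is to offset every interval by its hash times a length L. L is chosen larger than the total span of all bounds, so intervals from different cells can never overlap along the sweep axis. `virtual_shift` is a hypothetical name.

```python
import jax.numpy as jnp


def virtual_shift(m: jnp.ndarray, M: jnp.ndarray, cell_hash: jnp.ndarray):
    """Offset [m, M] by cell_hash * L so intervals in different cells never overlap."""
    # L exceeds the full span of the bounds; the +1 margin keeps intervals
    # with consecutive hashes strictly apart.
    L = (M.max() - m.min()) + 1.0
    offset = cell_hash.astype(m.dtype) * L
    return m + offset, M + offset
```

Because the offset is a multiple of the hash, a single sort by the shifted lower bound groups particles cell by cell. A forward sweep then stops at cell boundaries automatically. With large hashes the offset can become big enough to cost float32 precision in the bounds, which is worth keeping in mind when choosing L.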

jaxdem.colliders.sweep_and_prune.sort(state: State, iota: jax.Array, m: jax.Array, M: jax.Array) → tuple[State, jax.Array, jax.Array, jax.Array]

Sorts the state and particle bounds by the lower bound m.

Parameters:
  • state (State) – The current simulation state.

  • iota (jax.Array) – Indices [0, 1, …, N-1].

  • m (jax.Array) – Lower bounds along the sweep axis.

  • M (jax.Array) – Upper bounds along the sweep axis.

Returns:

A tuple containing:

  • Sorted State.

  • Sorted lower bounds.

  • Sorted upper bounds.

  • Sorting permutation.

Return type:

Tuple[State, jax.Array, jax.Array, jax.Array]
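A minimal sketch of this kind of sort, assuming the state is a pytree whose array leaves all have the particle index as their leading axis. A dict stands in for `State` here, and `sort_by_lower_bound` is a hypothetical name. Passing `iota` to `jax.lax.sort` together with the key yields the permutation in the same pass, which may be why `iota` is an input.

```python
import jax
import jax.numpy as jnp


def sort_by_lower_bound(state, iota, m, M):
    """Sort m, M and every particle-indexed leaf of `state` by ascending m."""
    # Sort all three operands using only m (the first operand) as the key.
    m_s, M_s, perm = jax.lax.sort((m, M, iota), num_keys=1)
    # Apply the same permutation to every array leaf of the state pytree.
    sorted_state = jax.tree_util.tree_map(lambda leaf: leaf[perm], state)
    return sorted_state, m_s, M_s, perm
```

The returned `perm[k]` is the original index of the particle now in slot `k`, so forces computed in sorted order can be scattered back with `forces.at[perm].set(sorted_forces)`.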

jaxdem.colliders.sweep_and_prune.pad_state(state: State) → State

Pads the state to a power of two to accommodate Pallas kernel requirements.

Parameters:

state (State) – The current simulation state.

Returns:

The padded simulation state.

Return type:

State

jaxdem.colliders.sweep_and_prune.force(i: int, j: int, state: State, system: System) → tuple[jax.Array, jax.Array]

Compute the linear spring-like interaction force acting on particle \(i\) due to particle \(j\).

Returns zero when \(i = j\).

Parameters:
  • i (int) – Index of the first particle.

  • j (int) – Index of the second particle.

  • state (State) – Current state of the simulation.

  • system (System) – Simulation system configuration.

Returns:

Force vector acting on particle \(i\) due to particle \(j\).

Return type:

jax.Array
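A linear spring-like contact is commonly written as \(\mathbf{F}_{ij} = k\,\delta_{ij}\,\hat{\mathbf{n}}_{ij}\), with overlap \(\delta_{ij} = \max(R_i + R_j - |\mathbf{r}_{ij}|, 0)\) and \(\hat{\mathbf{n}}_{ij} = \mathbf{r}_{ij} / |\mathbf{r}_{ij}|\), \(\mathbf{r}_{ij} = \mathbf{x}_i - \mathbf{x}_j\). The sketch below implements that textbook form with plain arrays. The stiffness `k`, the argument layout, and the name `linear_spring_force` are assumptions rather than this module's API, and the sketch returns only the force vector.

```python
import jax.numpy as jnp


def linear_spring_force(i, j, pos, rad, k):
    """Hypothetical Hookean contact force on particle i due to particle j."""
    r_ij = pos[i] - pos[j]
    dist = jnp.linalg.norm(r_ij)
    # Avoid 0/0 when i == j (or particles coincide); that case returns zero below.
    safe_dist = jnp.where(dist > 0, dist, 1.0)
    n_ij = r_ij / safe_dist
    overlap = jnp.maximum(rad[i] + rad[j] - dist, 0.0)  # zero if not touching
    force = k * overlap * n_ij
    return jnp.where(jnp.logical_and(i != j, dist > 0), force, jnp.zeros_like(force))


pos = jnp.array([[0.0, 0.0], [1.5, 0.0]])
rad = jnp.array([1.0, 1.0])
print(linear_spring_force(0, 1, pos, rad, 10.0))  # [-5.  0.]
```

The force pushes particle 0 away from particle 1, and swapping `i` and `j` flips its sign, as Newton's third law requires.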

class jaxdem.colliders.sweep_and_prune.SweepAndPrune

Bases: Collider

static create_neighbor_list(state: State, system: System, cutoff: float, max_neighbors: int) → tuple[State, System, jax.Array, jax.Array]

static compute_force(state: State, system: System) → tuple[State, System]

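Only the signature of `create_neighbor_list` is documented here: a distance `cutoff` and a fixed capacity `max_neighbors`. On small systems, a brute-force \(O(N^2)\) reference is a convenient cross-check for a sweep-and-prune neighbor list. The output layout in this sketch is a hypothetical illustration and is not claimed to match the two arrays the method returns. It uses an index array of width `max_neighbors` padded with -1, plus an overflow flag. `brute_force_neighbor_list` is also a hypothetical name.

```python
import jax.numpy as jnp


def brute_force_neighbor_list(pos, cutoff, max_neighbors):
    """O(N^2) reference neighbor list.

    Returns (nbrs, overflow):
      nbrs:     (N, max_neighbors) int array of neighbor indices, padded with -1.
      overflow: True if any particle has more than max_neighbors neighbors.
    """
    n = pos.shape[0]
    d = jnp.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    is_nbr = (d <= cutoff) & ~jnp.eye(n, dtype=bool)  # exclude self-pairs
    # Neighbors keep their own index as a sort key; non-neighbors get n, so
    # after sorting each row, neighbors come first in ascending index order.
    key = jnp.where(is_nbr, jnp.arange(n)[None, :], n)
    key = jnp.sort(key, axis=1)
    # Pad with n if max_neighbors exceeds n, then truncate to the fixed width.
    key = jnp.pad(key, ((0, 0), (0, max(0, max_neighbors - n))), constant_values=n)
    key = key[:, :max_neighbors]
    nbrs = jnp.where(key < n, key, -1)
    overflow = jnp.any(is_nbr.sum(axis=1) > max_neighbors)
    return nbrs, overflow
```

For three particles at x = 0, 0.5 and 3 with `cutoff=1.0` and `max_neighbors=2`, this returns `[[1, -1], [0, -1], [-1, -1]]` with no overflow.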
jaxdem.colliders.sweep_and_prune.SweeAPrune

alias of SweepAndPrune