jaxdem.colliders.sweep_and_prune

Sweep and prune \(O(N \log N)\) collider implementation.
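For orientation, here is a minimal single-axis sketch of the sweep-and-prune idea in plain Python. It is not this module's Pallas kernel. The function name `sweep_and_prune` is hypothetical, and the assumption is that each body is described by an interval `[lo, hi]` (for example, its AABB extent) along one axis. Sorting by the lower bound is the \(O(N \log N)\) step. The sweep then stops scanning as soon as a later box starts past the current box's upper bound.

```python
from typing import List, Tuple


def sweep_and_prune(lo: List[float], hi: List[float]) -> List[Tuple[int, int]]:
    """Hypothetical helper: return index pairs (i, j), i < j, whose
    intervals [lo, hi] overlap along a single axis."""
    # Sort body indices by interval lower bound: the O(N log N) step.
    order = sorted(range(len(lo)), key=lambda k: lo[k])
    pairs = []
    for a, i in enumerate(order):
        # Sweep forward over bodies that start at or after body i.
        for j in order[a + 1:]:
            if lo[j] > hi[i]:
                # Every later body starts even further along the axis: prune.
                break
            pairs.append((min(i, j), max(i, j)))
    return sorted(pairs)


# Four spheres of radius 0.3 centred at x = 0.0, 0.5, 3.0, 3.2.
print(sweep_and_prune([-0.3, 0.2, 2.7, 2.9], [0.3, 0.8, 3.3, 3.5]))
```

In 2D or 3D, pairs found along one axis are only candidates. They still need an overlap test on the remaining axes before any contact force is computed.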

Functions

compute_hash(state, proj_perp, aabb, shift)

compute_virtual_shift(m, M, HASH)

force(i, j, state, system)

pad_to_power2(x)

Pad 3D simulations to 4D to work around Pallas kernel limitations.

padd(state)

sap_kernel_full(state_ref, system_ref, ...)

sort(state, iota, m, M)

Classes

RefProxy(ref)

Wraps a Pallas Ref to behave like a JAX Array with auto-loading.

SweeAPrune()

jaxdem.colliders.sweep_and_prune.pad_to_power2(x)

Pad 3D simulations to 4D to work around Pallas kernel limitations.
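The sketch below illustrates what padding a spatial dimension to the next power of two means. It is written in plain Python and is not this function's implementation. The helper names `next_power_of_two` and `pad_rows_to_power2` are hypothetical. Two details are assumptions rather than documented behavior: that padding is applied along the last (coordinate) axis, and that the fill value is zero.

```python
from typing import List


def next_power_of_two(n: int) -> int:
    # Smallest power of two >= n (for n >= 1): 3 -> 4, 4 -> 4, 5 -> 8.
    return 1 << (n - 1).bit_length()


def pad_rows_to_power2(x: List[List[float]], fill: float = 0.0) -> List[List[float]]:
    # Hypothetical helper: pad each row (e.g. an (x, y, z) position) with
    # `fill` so its length becomes the next power of two.
    width = next_power_of_two(len(x[0]))
    return [row + [fill] * (width - len(row)) for row in x]


# A 3D position becomes a 4-component vector; 2D data would be left unchanged.
print(pad_rows_to_power2([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
```

On JAX arrays, the same effect can be obtained with `jax.numpy.pad` and a `pad_width` that only extends the last axis.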

class jaxdem.colliders.sweep_and_prune.RefProxy(ref)

Bases: object

Wraps a Pallas Ref to behave like a JAX Array with auto-loading.
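A minimal sketch of the proxy pattern described above follows, written under stated assumptions so that it runs without JAX. `FakeRef` is a hypothetical stand-in for a Pallas Ref, where reading `ref[...]` or `ref[idx]` performs a load. `RefProxySketch` is likewise hypothetical and is not this class's implementation. Metadata (`shape`, `dtype`, `ndim`) is forwarded without loading, while indexing and arithmetic trigger a load automatically.

```python
class FakeRef:
    """Hypothetical stand-in for a Pallas Ref: indexing performs a 'load'."""

    def __init__(self, data):
        self._data = list(data)
        self.loads = 0  # counts how many times data was read

    @property
    def shape(self):
        return (len(self._data),)

    @property
    def dtype(self):
        return "float32"

    def __getitem__(self, idx):
        self.loads += 1
        return list(self._data) if idx is Ellipsis else self._data[idx]


class RefProxySketch:
    """Hypothetical proxy that makes a ref look like an array."""

    def __init__(self, ref):
        self._ref = ref

    @property
    def shape(self):
        return self._ref.shape  # metadata only, no load

    @property
    def dtype(self):
        return self._ref.dtype

    @property
    def ndim(self):
        return len(self._ref.shape)

    def __getitem__(self, idx):
        return self._ref[idx]  # indexing loads the requested element(s)

    def __add__(self, other):
        # Arithmetic auto-loads the full value, then adds a scalar elementwise.
        return [v + other for v in self._ref[...]]


p = RefProxySketch(FakeRef([1.0, 2.0, 3.0]))
print(p.shape, p.ndim, p[1], p + 1.0)
```

Keeping metadata access separate from loads lets code inspect shapes cheaply and read memory only when values are actually used.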

property shape
property dtype
property ndim
jaxdem.colliders.sweep_and_prune.sap_kernel_full(state_ref, system_ref, aabb_ref, m_ref, M_ref, HASH_ref, forces_ref)
jaxdem.colliders.sweep_and_prune.compute_hash(state, proj_perp, aabb, shift)
jaxdem.colliders.sweep_and_prune.compute_virtual_shift(m, M, HASH)
jaxdem.colliders.sweep_and_prune.sort(state, iota, m, M)
jaxdem.colliders.sweep_and_prune.padd(state)
jaxdem.colliders.sweep_and_prune.force(i: int, j: int, state: State, system: System) → Tuple[jax.Array, jax.Array]
class jaxdem.colliders.sweep_and_prune.SweeAPrune

Bases: Collider

static compute_potential_energy(state: State, system: System)
static compute_force(state: State, system: System) → Tuple['State', 'System']