# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions used to set up simulations and analyze the output.
"""
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from typing import Sequence, Tuple, Union, Optional
from .state import State
# ------------------------------------------------------------------ #
# 1. Grid initialiser #
# ------------------------------------------------------------------ #
def grid_state(
    *,
    n_per_axis: Sequence[int],  # e.g. (nx, ny, nz) or (nx, ny)
    spacing: ArrayLike | float,  # lattice spacing (sx, sy, …)
    radius: float = 0.5,  # same radius for every sphere
    mass: float = 1.0,
    jitter: float = 0.0,  # optional small random offset
    key: Optional[jax.Array] = None,
) -> State:
    """
    Create a state where particles sit on a rectangular lattice.

    Lattice sites are placed at ``i * spacing`` along each axis, starting at
    the origin. They are enumerated with ``indexing="ij"`` (last axis varies
    fastest). Every particle also receives a random velocity whose components
    are drawn uniformly from [-1, 1].

    Parameters
    ----------
    n_per_axis : tuple[int]
        Number of spheres along each axis. Its length sets the spatial
        dimension.
    spacing : tuple[float] | float
        Centre-to-centre distance. A scalar (Python or 0-d array) is
        broadcast to every axis.
    radius, mass : float
        Shared radius / mass for all particles.
    jitter : float
        Add a uniform random offset in the range [-jitter, +jitter] to every
        coordinate, for non-perfect grids (useful to break symmetry). Values
        <= 0 disable the jitter.
    key : PRNG key, optional
        Source of randomness. Required when ``jitter > 0``. When given, it is
        split so that the jitter and the velocities use independent
        sub-keys. When omitted, velocities are drawn from
        ``jax.random.key(0)``.

    Returns
    -------
    State
        State holding ``prod(n_per_axis)`` particles.

    Raises
    ------
    ValueError
        If ``n_per_axis`` is empty, if ``spacing`` is neither a scalar nor one
        value per axis, or if ``jitter > 0`` and no ``key`` is given.
    """
    n_per_axis = tuple(n_per_axis)
    dim = len(n_per_axis)
    if dim == 0:
        raise ValueError("`n_per_axis` must contain at least one axis")

    # Accept Python scalars, 0-d arrays and per-axis sequences alike. A float
    # dtype keeps positions floating point even for integer spacings.
    spacing = jnp.asarray(spacing, dtype=float)
    if spacing.ndim == 0:
        spacing = jnp.full((dim,), spacing)
    if spacing.shape != (dim,):
        raise ValueError(
            f"`spacing` must be a scalar or have one entry per axis ({dim}), "
            f"got shape {spacing.shape}"
        )

    # Build the lattice: one coordinate axis per dimension, then every
    # combination of them flattened to (N, dim).
    axes = [jnp.arange(n) * spacing[i] for i, n in enumerate(n_per_axis)]
    coords = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)
    coords = coords.reshape((-1, dim))
    N = coords.shape[0]

    if key is None:
        if jitter > 0.0:
            raise ValueError("`key` must be provided when jitter > 0")
        # No user key: velocities keep the fixed default seed.
        vel_key = jax.random.key(0)
    else:
        # Split so the caller's key drives both draws without reusing one
        # key twice (previously the user key was silently discarded for the
        # velocities).
        key, vel_key = jax.random.split(key)

    if jitter > 0.0:
        coords = coords + jax.random.uniform(
            key, coords.shape, minval=-jitter, maxval=jitter
        )

    vel = jax.random.uniform(vel_key, shape=coords.shape, minval=-1.0, maxval=1.0)
    return State.create(
        pos=coords,
        vel=vel,
        rad=radius * jnp.ones(N),
        mass=mass * jnp.ones(N),
    )
def _as_min_max_pair(
    value: Optional[ArrayLike], default: ArrayLike, label: str
) -> jax.Array:
    """
    Convert an optional ``(min, max)`` range into a flat float array of size 2.

    ``default`` is used when ``value`` is None. ``label`` prefixes the error
    message raised when the range does not hold exactly two values.
    """
    pair = jnp.asarray(default if value is None else value, dtype=float)
    if pair.size != 2:
        raise ValueError(f"{label} range should be size == 2")
    return pair.reshape(2)


def random_state(
    *,
    N: int,
    dim: int,
    box_size: Optional[ArrayLike] = None,
    box_anchor: Optional[ArrayLike] = None,
    radius_range: Optional[ArrayLike] = None,
    mass_range: Optional[ArrayLike] = None,
    vel_range: Optional[ArrayLike] = None,
    seed: int = 0,
) -> State:
    """
    Generate `N` particles placed uniformly at random in an axis-aligned box.

    Overlaps between particles are neither checked nor resolved.

    Parameters
    ----------
    N
        Number of particles.
    dim
        Spatial dimension (2 or 3).
    box_size
        Edge lengths of the domain. Defaults to 10 along every axis.
    box_anchor
        Coordinate of the lower box corner. Defaults to the origin.
    radius_range
        ``(min, max)`` bounds for the uniformly drawn radii.
        Defaults to (10, 10).
    mass_range
        ``(min, max)`` bounds for the uniformly drawn masses.
        Defaults to (1, 1).
    vel_range
        ``(min, max)`` bounds for each uniformly drawn velocity component.
        Defaults to (1, 1).
    seed
        Integer seed for ``jax.random.PRNGKey``. The same seed reproduces the
        same state.

    Returns
    -------
    State
        A fully-initialised `State` instance with IDs ``0 .. N-1``.

    Raises
    ------
    ValueError
        If ``radius_range``, ``mass_range`` or ``vel_range`` does not contain
        exactly two values.
    """
    box_size = jnp.asarray(
        10 * jnp.ones(dim, dtype=float) if box_size is None else box_size,
        dtype=float,
    )
    box_anchor = jnp.asarray(
        jnp.zeros(dim, dtype=float) if box_anchor is None else box_anchor,
        dtype=float,
    )
    # NOTE(review): the default radius (10) equals the default box edge
    # length; kept for backward compatibility, confirm it is intended.
    radius_range = _as_min_max_pair(radius_range, 10 * jnp.ones(2), "Rad")
    mass_range = _as_min_max_pair(mass_range, jnp.ones(2), "Mass")
    # NOTE(review): the default (1, 1) range is degenerate (min == max), so
    # every velocity component defaults to 1.0; confirm it is intended.
    vel_range = _as_min_max_pair(vel_range, jnp.ones(2), "Vel")

    box_min = box_anchor
    box_max = box_anchor + box_size

    # One independent sub-key per sampled quantity.
    key = jax.random.PRNGKey(seed)
    key_pos, key_rad, key_mass, key_vel = jax.random.split(key, 4)

    pos = jax.random.uniform(
        key_pos, (N, dim), minval=box_min, maxval=box_max, dtype=float
    )
    rad = jax.random.uniform(
        key_rad, (N,), minval=radius_range[0], maxval=radius_range[1], dtype=float
    )
    mass = jax.random.uniform(
        key_mass, (N,), minval=mass_range[0], maxval=mass_range[1], dtype=float
    )
    vel = jax.random.uniform(
        key_vel, (N, dim), minval=vel_range[0], maxval=vel_range[1], dtype=float
    )
    return State.create(
        pos=pos,
        vel=vel,
        rad=rad,
        mass=mass,
        ID=jnp.arange(N, dtype=int),
    )