# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions to initialize states with particles arranged in a grid.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from typing import Sequence, Optional
from .. import State
[docs]
def grid_state(
*,
n_per_axis: Sequence[int], # e.g. (nx, ny, nz) or (nx, ny)
spacing: ArrayLike | float, # lattice spacing (sx, sy, …)
radius: float = 1.0, # same radius for every sphere
mass: float = 1.0,
jitter: float = 0.0, # optional small random offset
vel_range: Optional[ArrayLike] = None,
radius_range: Optional[ArrayLike] = None,
mass_range: Optional[ArrayLike] = None,
seed: int = 0,
key: Optional[jax.Array] = None,
) -> State:
"""
Create a state where particles sit on a rectangular lattice.
Random values can be sampled for particle radii, masses and velocities
by specifying ``*_range`` arguments, which are interpreted as
``(min, max)`` bounds for a uniform distribution. When a range is not
provided the corresponding ``radius`` or ``mass`` argument is used for
all particles and the velocity components are sampled in ``[-1, 1]``.
Parameters
----------
n_per_axis : tuple[int]
Number of spheres along each axis.
spacing : tuple[float] | float
Centre-to-centre distance; scalar is broadcast to every axis.
radius, mass : float
Shared radius / mass for all particles when the corresponding range
is not provided.
jitter : float
Add a uniform random offset in the range [-jitter, +jitter] for
non-perfect grids (useful to break symmetry).
vel_range, radius_range, mass_range : ArrayLike | None
``(min, max)`` values for the velocity components, radii and masses.
seed : int
Integer seed used when ``key`` is not supplied.
key : PRNG key, optional
Controls randomness. If ``None`` a key will be created from ``seed``.
Returns
-------
State
"""
n_per_axis = tuple(n_per_axis)
dim = len(n_per_axis)
spacing = dim * (spacing,) if isinstance(spacing, (int, float)) else tuple(spacing)
assert len(spacing) == dim
# build grid
axes = [jnp.arange(n) * s for n, s in zip(n_per_axis, spacing)]
coords = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)
coords = coords.reshape((-1, dim)) # (N, dim)
N, dim = coords.shape
if key is None:
key = jax.random.PRNGKey(seed)
if jitter > 0.0:
key, key_jitter = jax.random.split(key)
coords += jax.random.uniform(
key_jitter, coords.shape, minval=-jitter, maxval=jitter
)
if radius_range is not None:
radius_range = jnp.asarray(radius_range, dtype=float)
assert radius_range.size == 2, "radius_range should be size == 2"
key, key_rad = jax.random.split(key)
rad = jax.random.uniform(
key_rad, (N,), minval=radius_range[0], maxval=radius_range[1], dtype=float
)
else:
rad = radius * jnp.ones(N)
if mass_range is not None:
mass_range = jnp.asarray(mass_range, dtype=float)
assert mass_range.size == 2, "mass_range should be size == 2"
key, key_mass = jax.random.split(key)
mass_arr = jax.random.uniform(
key_mass, (N,), minval=mass_range[0], maxval=mass_range[1], dtype=float
)
else:
mass_arr = mass * jnp.ones(N)
if vel_range is None:
vel_range = jnp.array([-1.0, 1.0], dtype=float)
else:
vel_range = jnp.asarray(vel_range, dtype=float)
assert vel_range.size == 2, "vel_range should be size == 2"
key, key_vel = jax.random.split(key)
vel = jax.random.uniform(
key_vel, shape=coords.shape, minval=vel_range[0], maxval=vel_range[1]
)
return State.create(
pos=coords,
vel=vel,
rad=rad,
mass=mass_arr,
)
__all__ = ["grid_state"]