from typing import Optional, Sequence, Tuple
import numpy as np
import jax
import jax.numpy as jnp
import jaxdem as jd
def _broadcast(arr, is_scalar):
    """Lift ``arr`` to a 2-D array whose leading axis indexes systems.

    Parameters
    ----------
    arr
        Array-like input, converted with ``jnp.asarray``. A 0-D input is only
        accepted when ``is_scalar`` is truthy.
    is_scalar
        If truthy, ``arr`` holds one value per system: a 0-D input becomes
        shape ``(1, 1)`` and a 1-D input of length ``S`` becomes ``(S, 1)``.
        Otherwise ``arr`` holds one vector for a single system: a 1-D input of
        length ``K`` becomes ``(1, K)``.

    Returns
    -------
    The reshaped array. Inputs with two or more dimensions are returned
    unchanged.

    Raises
    ------
    AssertionError
        If ``arr`` is 0-D but ``is_scalar`` is falsy.
    """
    values = jnp.asarray(arr)
    rank = values.ndim
    assert rank != 0 or is_scalar, "This is not a scalar array!"
    if rank == 0:
        return values.reshape(1, 1)
    if rank != 1:
        # Already batched (or higher rank): leave as-is.
        return values
    # 1-D: a per-system scalar column, or a single-system row vector.
    return values[:, None] if is_scalar else values[None, :]
def _pad(arr, size):
    """Replicate a single-row array along its leading axis to ``size`` rows.

    Expects ``arr`` to be 2-D (it reads ``arr.shape[1]``) with a leading axis
    of length 1 or ``size``. The result has shape ``(size, arr.shape[1])``
    and keeps ``arr``'s dtype.

    Raises
    ------
    AssertionError
        If the leading axis is neither 1 nor ``size``.
    """
    rows = arr.shape[0]
    assert rows == size or rows == 1
    template = jnp.ones((size, arr.shape[1]), dtype=arr.dtype)
    # Multiplying by ones broadcasts a (1, k) row up to (size, k).
    return arr * template
[docs]
def random_sphere_configuration(
    particle_radii: Sequence[float] | Sequence[Sequence[float]],
    phi: float | Sequence[float],
    dim: int,
    seed: Optional[int] = None,
    collider_type="naive",
    box_aspect: Optional[Sequence[float] | Sequence[Sequence[float]]] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate one or more random sphere packings at a target packing fraction.

    This builds periodic systems with spherical particles, initializes particle
    positions uniformly at random inside a rectangular periodic box, and then
    minimizes the potential energy to obtain a mechanically stable configuration.

    The function supports **batching over multiple independent "systems"** by
    treating the leading axis as the system index and broadcasting any length-1
    inputs to match the maximum number of systems inferred from the inputs.

    Parameters
    ----------
    particle_radii
        Particle radii for one system or multiple systems.

        - **Single system**: a 1D sequence of length ``N`` (radii for each particle).
        - **Multiple systems**: a 2D sequence with shape ``(S, N)`` (one radii list
          per system).

        Internally, this is converted to a JAX array with shape ``(S, N)``.
    phi
        Target packing fraction(s).

        - **Scalar**: a single float applied to all systems.
        - **Per-system**: a 1D sequence of length ``S``.

        Internally, this is converted to a JAX array with shape ``(S, 1)`` and
        then broadcast/padded to match the inferred number of systems.
    dim
        Spatial dimension (e.g. 2 or 3).
    seed
        RNG seed used to initialize particle positions. If ``None``, a random
        seed is drawn via NumPy.

        Note: a **single** JAX PRNGKey is used to generate the full position
        array of shape ``(S, N, dim)``.
    collider_type
        Collision detection backend. Must be one of ``"naive"`` or ``"celllist"``.
    box_aspect
        Box aspect ratios for the periodic domain.

        - If ``None``, defaults to ``jnp.ones(dim)``.
        - A 1D sequence of length ``dim`` is shared by all systems.
        - A 2D sequence of shape ``(S, dim)`` gives one aspect per system.

        Internally broadcast/padded to shape ``(S, dim)``.

    Returns
    -------
    pos
        Particle positions after minimization.

        - If ``S > 1``: shape ``(S, N, dim)``.
        - If ``S == 1``: shape ``(N, dim)`` due to ``squeeze()``.
    box_size
        Periodic box size vectors.

        - If ``S > 1``: shape ``(S, dim)``.
        - If ``S == 1``: shape ``(dim,)`` due to ``squeeze()``.

    Raises
    ------
    AssertionError
        If ``collider_type`` is not recognized, if ``box_aspect`` is not 1-D or
        2-D with a trailing axis of length ``dim``, if the leading (system)
        sizes of the inputs are incompatible, or if the constructed boxes do
        not reproduce the requested packing fractions. These are ``assert``
        checks and are skipped when Python runs with ``-O``.

    Notes
    -----
    - **Broadcasting rule**: any input provided for a single system (leading
      dimension 1) is replicated to match the maximum ``S`` inferred from
      ``particle_radii``, ``phi``, and ``box_aspect``.
    - For system ``i`` the box size is ``l_i * box_aspect[i]`` with
      ``l_i = (sum(volumes_i) / (phi_i * prod(box_aspect[i]))) ** (1 / dim)``.
    - The final ``squeeze()`` calls can also drop other singleton dimensions
      (e.g. if ``N == 1``). If you need stable rank/shape, remove the squeezes.
    """
    # Draw a seed when none is given so independent calls differ.
    if seed is None:
        seed = np.random.randint(0, 10**9)
    assert collider_type in ["naive", "celllist"], f"Collider type {collider_type} not understood. Must be one of [naive, celllist]"

    if box_aspect is None:
        box_aspect = jnp.ones(dim)
    else:
        box_aspect = jnp.asarray(box_aspect)
    # Check the trailing axis so both a shared (dim,) aspect and a per-system
    # (S, dim) aspect are accepted.
    aspect_len = box_aspect.shape[-1] if box_aspect.ndim > 0 else None
    assert box_aspect.ndim in (1, 2) and aspect_len == dim, (
        f"Box aspect ({aspect_len}) and spatial dimension ({dim}) do not match."
    )

    # Give every input a leading "system" axis.
    particle_radii = _broadcast(particle_radii, is_scalar=False)  # (S_r, N)
    phi = _broadcast(phi, is_scalar=True)  # (S_phi, 1)
    box_aspect = _broadcast(box_aspect, is_scalar=False)  # (S_b, dim)

    # Replicate length-1 inputs along the system axis to a common S.
    N_systems = max(arr.shape[0] for arr in (particle_radii, phi, box_aspect))
    particle_radii = _pad(particle_radii, N_systems)
    phi = _pad(phi, N_systems)
    box_aspect = _pad(box_aspect, N_systems)

    # Fixed simulation parameters: Young's modulus, particle mass, time step.
    e_int = 1.0
    mass = 1.0
    dt = 1e-2
    N = particle_radii.shape[1]
    key = jax.random.PRNGKey(seed)

    mats = [jd.Material.create("elastic", young=e_int, poisson=0.5, density=1.0)]
    matcher = jd.MaterialMatchmaker.create("harmonic")
    mat_table = jd.MaterialTable.from_materials(mats, matcher=matcher)

    # Uniform positions in the un-scaled box [0, aspect); each system is later
    # rescaled by its own length factor l inside _build_state.
    pos = jax.random.uniform(key, (N_systems, N, dim), minval=0, maxval=1) * box_aspect[:, None, :]

    def _build_state(i):
        """Create the particle State and periodic System for system index ``i``."""
        state = jd.State.create(pos=pos[i], rad=particle_radii[i], mass=mass * jnp.ones(N))
        # box size = l * aspect_i, with l chosen so that
        # sum(volumes) / (l**dim * prod(aspect_i)) == phi_i.
        # Only this system's aspect enters the product, not all systems'.
        aspect = box_aspect[i]
        l = (jnp.sum(state.volume) / (phi[i, 0] * jnp.prod(aspect))) ** (1 / dim)
        box_size = l * aspect
        state.pos_c *= l
        # The cell-list collider is constructed from the state.
        collider_kw = dict(state=state) if collider_type == "celllist" else dict()
        system = jd.System.create(
            state_shape=state.shape,
            dt=dt,
            linear_integrator_type="linearfire",
            rotation_integrator_type="",
            domain_type="periodic",
            force_model_type="spring",
            collider_type=collider_type,
            collider_kw=collider_kw,
            mat_table=mat_table,
            domain_kw=dict(
                box_size=box_size,
            ),
        )
        return state, system

    state, system = jax.vmap(_build_state)(jnp.arange(N_systems))
    # Sanity check: every box reproduces its requested packing fraction.
    assert jnp.all(jnp.isclose(jnp.sum(state.volume, axis=-1) / jnp.prod(system.domain.box_size, axis=-1), phi[:, 0]))
    state, system, _, _ = jax.vmap(
        lambda st, sy: jd.minimizers.minimize(st, sy, max_steps=1_000_000, pe_tol=1e-16, pe_diff_tol=1e-16, initialize=True)
    )(state, system)
    return state.pos.squeeze(), system.domain.box_size.squeeze()