Source code for jaxdem.utils.dispersity
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Utility functions to assign radius dispersity.
"""
from __future__ import annotations
import warnings
from typing import Sequence
import jax
import numpy as np
import jax.numpy as jnp
[docs]
def allocate_counts(
    N: int,
    count_ratios: Sequence[float],
    *,
    ensure_each_size_nonzero: bool = True,
) -> np.ndarray:
    """
    Convert population fractions into integer counts that sum exactly to N.

    Uses the "largest remainder" method (Hamilton apportionment): every species
    first receives the floor of its exact share, then the leftover particles are
    handed out one each to the species with the largest fractional remainders.
    Ties between equal remainders are broken by species index (lower index
    first), so the result is reproducible across NumPy versions and platforms.

    Parameters
    ----------
    N : int
        Total number of particles; must be positive.
    count_ratios : sequence of float
        Non-negative relative population of each species. The values are
        normalized internally, so fractions and percentages are both accepted.
    ensure_each_size_nonzero : bool, keyword-only
        If True, every species receives at least one particle (requires
        ``N >= len(count_ratios)``). Species whose ratio is exactly 0 still get
        one particle, and a ``RuntimeWarning`` is emitted.

    Returns
    -------
    np.ndarray
        1D integer array of length ``len(count_ratios)`` whose entries sum to N.

    Raises
    ------
    ValueError
        If N is not positive; if count_ratios is not a non-empty 1D sequence of
        finite, non-negative values with a positive sum; or if the nonzero
        constraint cannot be met because N < number of species.
    RuntimeError
        If the apportionment invariants fail (internal error).
    """
    if N <= 0:
        raise ValueError(f"N must be positive; got {N}.")
    ratios = np.asarray(count_ratios, dtype=float)
    if ratios.ndim != 1 or ratios.size == 0:
        raise ValueError("count_ratios must be a 1D non-empty array-like.")
    if not np.isfinite(ratios).all():
        raise ValueError("count_ratios contains non-finite values.")
    if np.any(ratios < 0):
        raise ValueError("count_ratios must be non-negative.")
    k = int(ratios.size)

    # Normalize (keeping behavior stable if user passes percentages that don't sum to 1).
    total = float(ratios.sum())
    if total <= 0:
        raise ValueError("count_ratios must sum to a positive value.")
    ratios = ratios / total

    if ensure_each_size_nonzero:
        if N < k:
            raise ValueError(f"Cannot give each of {k} sizes at least 1 particle with N={N}.")
        if np.any(ratios == 0):
            warnings.warn(
                "ensure_each_size_nonzero=True but some count_ratios are 0; "
                "those species will still be forced to have count=1.",
                RuntimeWarning,
                stacklevel=2,  # attribute the warning to the caller, not this module
            )
        # Start with one particle per species, then apportion the remaining N - k.
        base = np.ones(k, dtype=int)
        raw = (N - k) * ratios
    else:
        base = np.zeros(k, dtype=int)
        raw = N * ratios

    floors = np.floor(raw).astype(int)
    frac = raw - floors
    counts = base + floors
    remaining = int(N - counts.sum())

    if remaining > 0:
        # Give +1 to the species with the largest fractional parts. A stable
        # sort makes ties deterministic: the lower species index wins.
        order = np.argsort(-frac, kind="stable")
        counts[order[:remaining]] += 1
    elif remaining < 0:
        # Floating-point round-off can push the floors above N. In that case,
        # take particles back from the smallest fractional parts, but never
        # drop a species below its base count.
        order = np.argsort(frac, kind="stable")
        to_remove = -remaining
        for idx in order:
            if to_remove == 0:
                break
            if counts[idx] > base[idx]:
                counts[idx] -= 1
                to_remove -= 1
        if to_remove != 0:
            raise RuntimeError("Failed to apportion counts without violating nonzero constraint.")

    # Final safety: exact sum and integer non-negative.
    if counts.sum() != N:
        raise RuntimeError(f"Internal error: counts sum {counts.sum()} != N {N}.")
    if np.any(counts < 0):
        raise RuntimeError("Internal error: negative counts produced.")
    if ensure_each_size_nonzero and np.any(counts < 1):
        raise RuntimeError("Internal error: nonzero constraint violated.")
    return counts
[docs]
def get_polydisperse_radii(
    N: int,
    count_ratios: Sequence[float] = (0.5, 0.5),
    size_ratios: Sequence[float] = (1.0, 1.4),
    small_radius: float = 0.5,
    ensure_size_nonzero: bool = False,
) -> jax.Array:
    """
    Construct a polydisperse set of particle radii from population and size ratios.

    Parameters
    ----------
    N : int
        Total number of particles.
    count_ratios : array-like of float
        Population fractions for each size class (normalized to sum to 1).
    size_ratios : array-like of float
        Radius multipliers for each size class. They are divided by their
        minimum, so only relative values matter. For example,
        ``size_ratios=[1.0, 1.4]`` makes the large particles 1.4x the radius of
        the small ones.
    small_radius : float
        Absolute radius of the smallest size class (must be finite and > 0).
    ensure_size_nonzero : bool
        If True, every size class gets at least one particle
        (requires N >= number of sizes).

    Returns
    -------
    jax.Array
        1D array of length N with one radius per particle. The radii are
        grouped by size class in input order (not shuffled). The dtype follows
        JAX's default float precision (float32 unless 64-bit mode is enabled).

    Raises
    ------
    ValueError
        If count_ratios and size_ratios differ in length; if size_ratios is
        empty, not 1D, non-finite or non-positive; if small_radius is not a
        finite positive number; or if N / count_ratios are rejected by
        ``allocate_counts``.

    Warns
    -----
    RuntimeWarning
        If finite-N rounding makes the achieved fractions differ from the
        requested ones.
    """
    count_ratios = np.asarray(count_ratios, dtype=float)
    size_ratios = np.asarray(size_ratios, dtype=float)
    # Explicit raises (not assert) so validation survives `python -O`.
    if len(count_ratios) != len(size_ratios):
        raise ValueError(
            f"Got inconsistent sizes for count_ratios ({len(count_ratios)}) "
            f"and size_ratios ({len(size_ratios)})"
        )
    if size_ratios.ndim != 1 or size_ratios.size == 0:
        raise ValueError("size_ratios must be a 1D non-empty array-like.")
    if not np.all(np.isfinite(size_ratios)):
        raise ValueError("size_ratios must be finite.")
    if not np.all(size_ratios > 0):
        raise ValueError("size_ratios must be positive (multiples of small_radius).")
    if not (np.isfinite(small_radius) and small_radius > 0):
        raise ValueError("small_radius must be positive.")

    # allocate_counts validates count_ratios (finite, non-negative, positive
    # sum) before anything here divides by the sum. Bad input therefore gets a
    # meaningful error instead of silently becoming NaN.
    counts = allocate_counts(N, count_ratios, ensure_each_size_nonzero=ensure_size_nonzero)
    requested = count_ratios / count_ratios.sum()
    sizes = small_radius * (size_ratios / size_ratios.min())

    # Warn if finite-size rounding makes the achieved fractions differ from requested.
    achieved = counts / N
    if not np.allclose(achieved, requested):
        warnings.warn(
            f"cannot achieve exact count ratio ({requested}) - got ({achieved})",
            RuntimeWarning,
            stacklevel=2,
        )
    # Repeat each class radius `count` times, keeping class order.
    return jnp.array(np.repeat(sizes, counts))