jaxdem.analysis
Post-processing / analysis utilities.
This subpackage contains a minimal, JAX-friendly binned-accumulation engine:
- Bin specifications (time bins, lag bins) in jaxdem.analysis.bins
- Flattening bins to index pairs in jaxdem.analysis.pairs
- A JAX engine (vmap + segment_sum) in jaxdem.analysis.engine
- Example kernels in jaxdem.analysis.kernels
- class jaxdem.analysis.BinSpec(T: int, timestep: ndarray | None = None)
Bases: object
Abstract bin specification.
- Parameters:
T (int) – Number of frames (time steps).
timestep (np.ndarray or None, optional) – Physical timestep labels of shape (T,). If None, defaults to np.arange(T).
- class jaxdem.analysis.TimeBins(T: int, t_min: int | None = None, t_max: int | None = None, *, timestep: ndarray | None = None)
Bases: BinSpec
One bin per time index (optionally restricted to a timestep range).
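As an illustration of the "one bin per time index" layout, the NumPy sketch below builds the flat index arrays such a specification could reduce to. It assumes that t_min and t_max are inclusive frame indices and that bin b holds the single pair (t, t) with t = t_min + b; the helper time_bin_pairs is hypothetical and not part of jaxdem.

```python
import numpy as np


def time_bin_pairs(T, t_min=None, t_max=None):
    """Hypothetical sketch: one (t, t) pair per bin for frames t_min..t_max (inclusive)."""
    lo = 0 if t_min is None else t_min
    hi = T - 1 if t_max is None else t_max
    frames = np.arange(lo, hi + 1)
    pair_i = frames                    # passed to the kernel as t0
    pair_j = frames.copy()             # passed as t1: the same frame for time bins
    bin_id = np.arange(frames.size)    # bin b <-> frame lo + b
    return pair_i, pair_j, bin_id


pair_i, pair_j, bin_id = time_bin_pairs(T=5, t_min=1, t_max=3)
```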
- class jaxdem.analysis.LagBinsExact(T: int, taus: Sequence[int] | ndarray[Any, Any], *, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: BinSpec
Bins for a provided set of exact physical time lags (taus).
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsExact
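Exact lags are defined on the physical timestep labels rather than on frame indices. The sketch below shows one way to find every frame pair whose labels differ by exactly tau and then cap the number of origins. It assumes a sorted timestep array, evenly spaced subsampling for sample='stride', and seeded random subsampling for any other sample value; lag_pairs is a hypothetical helper, not jaxdem's implementation.

```python
import numpy as np


def lag_pairs(timestep, tau, cap=None, sample="stride", seed=0):
    """Hypothetical sketch: frame pairs (t0, t1) with timestep[t1] - timestep[t0] == tau."""
    timestep = np.asarray(timestep)           # assumed sorted ascending
    target = timestep + tau                   # label each origin would need to reach
    j = np.searchsorted(timestep, target)     # candidate index of that label
    j_safe = np.minimum(j, timestep.size - 1)
    ok = (j < timestep.size) & (timestep[j_safe] == target)
    t0, t1 = np.nonzero(ok)[0], j_safe[ok]
    if cap is not None and t0.size > cap:
        if sample == "stride":
            # evenly spaced subset of the available origins
            idx = np.linspace(0, t0.size - 1, cap).round().astype(int)
        else:
            # assumption: any other mode means a seeded random subset
            rng = np.random.default_rng(seed)
            idx = np.sort(rng.choice(t0.size, size=cap, replace=False))
        t0, t1 = t0[idx], t1[idx]
    return t0, t1


# Uniform grid (the np.arange(T) default) with at most 4 origins per lag
t0, t1 = lag_pairs(np.arange(10), tau=3, cap=4)
# Frames saved every 2 steps: a lag of 4 means 2 frames apart
s0, s1 = lag_pairs(np.array([0, 2, 4, 6, 8, 10]), tau=4)
```

Matching on labels rather than indices means a non-uniform or strided save schedule still yields only pairs that realize the requested physical lag.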
- class jaxdem.analysis.LagBinsLinear(T: int, dt_min: int | None = None, dt_max: int | None = None, *, step: int = 1, num_points: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: LagBinsExact
Linearly spaced lag bins in [dt_min, dt_max] on the timestep grid.
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsLinear
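A minimal sketch of how linearly spaced lags might be generated, assuming integer lags on the default timestep = np.arange(T), defaults of dt_min = 1 and dt_max = T - 1, and num_points taking precedence over step when given. These defaults and the helper linear_lags are assumptions for illustration only.

```python
import numpy as np


def linear_lags(T, dt_min=None, dt_max=None, step=1, num_points=None):
    """Hypothetical sketch: linearly spaced integer lags in [dt_min, dt_max]."""
    lo = 1 if dt_min is None else dt_min      # assumption: lag 0 excluded by default
    hi = T - 1 if dt_max is None else dt_max  # largest lag on np.arange(T)
    if num_points is not None:
        # evenly spaced values snapped to integers; duplicates are dropped
        return np.unique(np.linspace(lo, hi, num_points).round().astype(int))
    return np.arange(lo, hi + 1, step)


a = linear_lags(T=20, dt_min=2, dt_max=10, step=4)
b = linear_lags(T=20, num_points=4)
```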
- class jaxdem.analysis.LagBinsLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, num_bins: int | None = None, num_per_decade: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: LagBinsExact
Log-spaced lag bins on the realizable timestep grid.
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsLog
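Log-spaced values rarely land on integer lags, so they have to be snapped to the grid, and neighbouring values can collapse onto the same lag at short times. The sketch below illustrates this with np.geomspace followed by rounding and np.unique. It assumes integer frame lags, dt_min ≥ 1, and a default density of 10 points per decade when neither num_bins nor num_per_decade is given; the helper log_lags and these defaults are hypothetical.

```python
import numpy as np


def log_lags(T, dt_min=None, dt_max=None, num_bins=None, num_per_decade=None):
    """Hypothetical sketch: log-spaced lags snapped to integers (requires dt_min >= 1)."""
    lo = 1 if dt_min is None else dt_min
    hi = T - 1 if dt_max is None else dt_max
    if num_bins is None:
        per_decade = 10 if num_per_decade is None else num_per_decade  # assumed default
        decades = np.log10(hi / lo)
        num_bins = max(2, int(np.ceil(decades * per_decade)) + 1)
    raw = np.geomspace(lo, hi, num_bins)
    # rounding can map neighbouring short lags onto the same integer; np.unique dedups
    return np.unique(np.clip(np.rint(raw).astype(int), lo, hi))


lags = log_lags(T=1001, dt_min=1, dt_max=1000, num_per_decade=3)
```

Because of that deduplication, the number of realized bins can be smaller than num_bins when many points fall below a lag of about 10.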
- class jaxdem.analysis.LagBinsPseudoLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, digits: Sequence[int] = (1, 2, 3, 4, 5, 6, 7, 8, 9), cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: LagBinsExact
Pseudo-log lag bins built from digits × powers of ten on the timestep grid.
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsPseudoLog
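The pseudo-log scheme can be read as the set {d · 10^k} for every digit d and power k, clipped to [dt_min, dt_max]. A minimal sketch under that reading, assuming integer lags and dt_min defaulting to 1 (the helper pseudo_log_lags is hypothetical):

```python
import numpy as np


def pseudo_log_lags(T, dt_min=None, dt_max=None, digits=(1, 2, 3, 4, 5, 6, 7, 8, 9)):
    """Hypothetical sketch: lags of the form digit * 10**k inside [dt_min, dt_max]."""
    lo = 1 if dt_min is None else dt_min
    hi = T - 1 if dt_max is None else dt_max
    lags = []
    power = 1
    while power <= hi:          # visit 1, 10, 100, ... up to the largest allowed lag
        for d in digits:
            lag = d * power
            if lo <= lag <= hi:
                lags.append(lag)
        power *= 10
    return np.unique(np.array(lags, dtype=int))


lags_125 = pseudo_log_lags(T=300, digits=(1, 2, 5))
```

With digits=(1, 2, 5) and T = 300 this yields the 1-2-5 sequence 1, 2, 5, 10, 20, 50, 100, 200.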
- class jaxdem.analysis.Pairs(pair_i: ndarray, pair_j: ndarray, bin_id: ndarray, counts_per_bin: ndarray)
Bases: object
Flat representation of bin tuples, suitable for JAX execution.
- pair_i
shape (P,) int array
- Type:
numpy.ndarray
- pair_j
shape (P,) int array
- Type:
numpy.ndarray
- bin_id
shape (P,) int array with values in [0, B)
- Type:
numpy.ndarray
- counts_per_bin
shape (B,) int array (number of tuples per bin)
- Type:
numpy.ndarray
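The invariants implied by the attribute shapes above can be checked mechanically. The sketch below uses a hypothetical dataclass FlatPairs that mirrors the documented fields (it is not jaxdem's Pairs) and verifies that bin_id stays in [0, B) and that counts_per_bin matches np.bincount(bin_id, minlength=B).

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class FlatPairs:
    """Hypothetical stand-in mirroring the documented Pairs fields (not jaxdem's class)."""

    pair_i: np.ndarray          # (P,) int
    pair_j: np.ndarray          # (P,) int
    bin_id: np.ndarray          # (P,) int in [0, B)
    counts_per_bin: np.ndarray  # (B,) int: number of tuples per bin

    def validate(self) -> None:
        P = self.pair_i.shape[0]
        B = self.counts_per_bin.shape[0]
        if self.pair_j.shape != (P,) or self.bin_id.shape != (P,):
            raise ValueError("pair_i, pair_j and bin_id must all have shape (P,)")
        if np.any((self.bin_id < 0) | (self.bin_id >= B)):
            raise ValueError("bin_id must lie in [0, B)")
        # counts_per_bin has to agree with how often each bin id occurs
        if not np.array_equal(np.bincount(self.bin_id, minlength=B), self.counts_per_bin):
            raise ValueError("counts_per_bin does not match bin_id")


# Three lag bins (lags 1, 2, 3) holding 2, 1 and 3 tuples respectively
bin_id = np.array([0, 0, 1, 2, 2, 2])
pairs = FlatPairs(
    pair_i=np.array([0, 1, 0, 0, 1, 2]),
    pair_j=np.array([1, 2, 2, 3, 4, 5]),
    bin_id=bin_id,
    counts_per_bin=np.bincount(bin_id, minlength=3),
)
pairs.validate()
```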
- jaxdem.analysis.build_pairs(binspec: BinSpec) → Pairs
Build the flat (pair_i, pair_j, bin_id) arrays from a BinSpec.
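One plausible way to flatten per-bin index lists into the three parallel arrays is to concatenate them and tag each entry with its bin via np.repeat. The sketch below assumes each bin is given as a pair of (t0, t1) index arrays; flatten_bins is a hypothetical helper, not the body of build_pairs.

```python
import numpy as np


def flatten_bins(per_bin):
    """Hypothetical sketch: flatten one (t0, t1) index-array pair per bin into flat arrays."""
    pair_i = np.concatenate([np.asarray(t0, dtype=np.int64) for t0, _ in per_bin])
    pair_j = np.concatenate([np.asarray(t1, dtype=np.int64) for _, t1 in per_bin])
    counts_per_bin = np.array([len(t0) for t0, _ in per_bin], dtype=np.int64)
    # bin b contributes counts_per_bin[b] consecutive entries tagged with id b
    bin_id = np.repeat(np.arange(len(per_bin)), counts_per_bin)
    return pair_i, pair_j, bin_id, counts_per_bin


# Two exact lags (1 and 3) on a 6-frame trajectory, using every origin
T, taus = 6, [1, 3]
per_bin = [(np.arange(T - tau), np.arange(T - tau) + tau) for tau in taus]
pair_i, pair_j, bin_id, counts_per_bin = flatten_bins(per_bin)
```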
- class jaxdem.analysis.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)
Bases: object
Binned accumulation output.
- sums
pytree with each leaf shaped (B, …)
- Type:
Any
- counts
array of shape (B,), float32
- Type:
jax.Array
- mean
pytree with each leaf shaped (B, …)
- Type:
Any
- pairs
flattened pair representation used for the run (host arrays)
- Type:
Pairs
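The mean field is, leaf by leaf, the per-bin sum divided by the per-bin count, with counts broadcast over any trailing axes. The JAX sketch below shows that division over a pytree with jax.tree_util.tree_map. How jaxdem treats empty bins is not stated here; this sketch returns NaN for them, which is an assumption, and binned_mean is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def binned_mean(sums, counts):
    """Hypothetical sketch: divide every (B, ...) leaf of `sums` by the (B,) `counts`."""

    def leaf_mean(s):
        # reshape counts to (B, 1, ..., 1) so it broadcasts over the leaf's trailing axes
        c = counts.reshape(counts.shape + (1,) * (s.ndim - 1))
        # assumption: empty bins become NaN (jaxdem's own handling is not documented here)
        return jnp.where(c > 0, s / jnp.maximum(c, 1.0), jnp.nan)

    return jax.tree_util.tree_map(leaf_mean, sums)


sums = {
    "scalar": jnp.array([4.0, 9.0, 0.0]),                          # leaf shaped (B,)
    "per_sample": jnp.array([[2.0, 4.0], [3.0, 6.0], [0.0, 0.0]]),  # leaf shaped (B, S)
}
counts = jnp.array([2.0, 3.0, 0.0], dtype=jnp.float32)
mean = binned_mean(sums, counts)
```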
- jaxdem.analysis.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: Dict[str, Any] | None = None, jit: bool = True) → Binned
Run a kernel over every index pair defined by a BinSpec and average the results per bin in JAX.
- Parameters:
kernel (callable) – Pure function called as kernel(arrays, t0, t1, **kernel_kwargs).
arrays (Mapping[str, Any]) – Mapping of field name to array with a leading time axis, e.g. pos: (T, N, d) or pos: (T, S, N, d).
binspec (BinSpec) – Bin specification (host-side); defines which indices to use.
kernel_kwargs (dict or None, optional) – Extra keyword arguments passed to kernel.
jit (bool, optional) – Whether to JIT-compile the core computation.
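The vmap + segment_sum pattern named in the subpackage overview can be sketched in a few lines of JAX: evaluate the kernel for every (t0, t1) pair with jax.vmap, then reduce the per-pair results into bins with jax.ops.segment_sum. The function evaluate_binned_sketch below is hypothetical; it takes the flat pair arrays directly instead of a BinSpec and assumes every bin is non-empty.

```python
import jax
import jax.numpy as jnp
import numpy as np


def evaluate_binned_sketch(kernel, arrays, pair_i, pair_j, bin_id, num_bins, kernel_kwargs=None):
    """Hypothetical vmap + segment_sum binned average; assumes every bin is non-empty."""
    kw = {} if kernel_kwargs is None else kernel_kwargs

    @jax.jit
    def core(arrays, pair_i, pair_j, bin_id):
        # one kernel call per (t0, t1) pair; every output leaf gains a leading (P,) axis
        vals = jax.vmap(lambda t0, t1: kernel(arrays, t0, t1, **kw))(pair_i, pair_j)
        # sum per-pair values that share a bin id, giving leaves shaped (B, ...)
        sums = jax.tree_util.tree_map(
            lambda v: jax.ops.segment_sum(v, bin_id, num_segments=num_bins), vals
        )
        counts = jax.ops.segment_sum(
            jnp.ones(bin_id.shape, jnp.float32), bin_id, num_segments=num_bins
        )
        mean = jax.tree_util.tree_map(
            lambda s: s / counts.reshape(counts.shape + (1,) * (s.ndim - 1)), sums
        )
        return sums, counts, mean

    return core(arrays, jnp.asarray(pair_i), jnp.asarray(pair_j), jnp.asarray(bin_id))


def sq_disp_kernel(arrays, t0, t1):
    """Toy kernel: squared displacement averaged over particles."""
    d = arrays["x"][t1] - arrays["x"][t0]
    return jnp.mean(d**2)


# T = 5 frames, 3 particles that all move one unit per frame
T, taus = 5, [1, 2]
x = jnp.tile(jnp.arange(T, dtype=jnp.float32)[:, None], (1, 3))
pair_i = np.concatenate([np.arange(T - tau) for tau in taus])
pair_j = np.concatenate([np.arange(T - tau) + tau for tau in taus])
bin_id = np.concatenate([np.full(T - tau, b) for b, tau in enumerate(taus)])
sums, counts, mean = evaluate_binned_sketch(
    sq_disp_kernel, {"x": x}, pair_i, pair_j, bin_id, num_bins=2
)
```

Since every particle moves at unit speed, each bin mean equals tau², i.e. 1 and 4 for lags 1 and 2, with 4 and 3 contributing pairs respectively.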
- jaxdem.analysis.msd_kernel(arrays: Mapping[str, Array], t0: Any, t1: Any) → Array
Mean-squared displacement between frames t0 and t1.
Works for both layouts:
- pos[t]: (N, d) → returns a () scalar
- pos[t]: (S, N, d) → returns an (S,) vector
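A kernel with this behaviour can reduce over the trailing axes only, so the same code handles both layouts. The sketch assumes positions are stored under the key "pos" (matching the pos field in the evaluate_binned parameter description) and that the MSD averages the squared displacement over particles; msd_kernel_sketch is hypothetical, not jaxdem's source.

```python
import jax.numpy as jnp


def msd_kernel_sketch(arrays, t0, t1):
    """Hypothetical MSD between frames t0 and t1 (assumes positions under key "pos").

    pos[t] of shape (N, d) gives a () scalar; (S, N, d) gives an (S,) vector.
    """
    disp = arrays["pos"][t1] - arrays["pos"][t0]  # (..., N, d)
    sq = jnp.sum(disp**2, axis=-1)                # squared displacement per particle: (..., N)
    return jnp.mean(sq, axis=-1)                  # average over particles: () or (S,)


# (T=3, N=4, d=2): every particle moves +2 along x between frames 0 and 2
pos = jnp.zeros((3, 4, 2)).at[2, :, 0].set(2.0)
single = msd_kernel_sketch({"pos": pos}, 0, 2)

# (T=3, S=2, N=4, d=2): the second sample moves twice as far
pos_s = jnp.stack([pos, 2 * pos], axis=1)
batched = msd_kernel_sketch({"pos": pos_s}, 0, 2)
```

Because only the last two axes are reduced, the kernel plugs into a vmap-over-pairs engine for either array layout without changes.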