jaxdem.analysis
Post-processing / analysis utilities.
This subpackage contains a minimal, JAX-friendly binned-accumulation engine:
- Bin specifications (time bins, lag bins) in jaxdem.analysis.bins
- Flattening bins to index-pairs in jaxdem.analysis.pairs
- A JAX engine (vmap + segment_sum) in jaxdem.analysis.engine
- Example kernels in jaxdem.analysis.kernels
- class jaxdem.analysis.BinSpec(T: int, timestep: ndarray | None = None)
Bases: object
Abstract bin specification.
- Parameters:
T (int) – Number of frames (time steps).
timestep (np.ndarray or None, optional) – Physical timestep labels of shape (T,). If absent, defaults to np.arange(T).
- class jaxdem.analysis.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)
Bases: object
Binned accumulation output.
- sums
Pytree with each leaf shaped (B, …).
- Type: Any
- counts
Array of shape (B,), float32.
- Type: jax.Array
- mean
Pytree with each leaf shaped (B, …).
- Type: Any
- pairs
Flattened pair representation used for the run (host arrays).
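How the three numeric fields relate can be illustrated with a small NumPy sketch. It assumes that mean is sums divided by counts; the per-pair values and bin ids are made up, and the library's JAX path (including its handling of empty bins) may differ.

```python
import numpy as np

# Hypothetical per-pair kernel outputs (P = 5 pairs) and their bin ids (B = 3 bins).
values = np.array([1.0, 3.0, 2.0, 4.0, 6.0], dtype=np.float32)
bin_id = np.array([0, 0, 1, 2, 2])
B = 3

# Segment sum: add every per-pair value into its bin.
sums = np.zeros(B, dtype=np.float32)
np.add.at(sums, bin_id, values)

# Number of pairs per bin, stored as float32 like the documented `counts`.
counts = np.bincount(bin_id, minlength=B).astype(np.float32)

# Guard against empty bins (an assumption, not confirmed library behaviour).
mean = sums / np.maximum(counts, 1.0)
```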
- class jaxdem.analysis.LagBinsExact(T: int, taus: Sequence[int] | ndarray[Any, Any], *, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: BinSpec
Bins for a provided set of exact physical time lags (taus).
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) -> LagBinsExact
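The idea of exact lag bins can be sketched in NumPy as follows: for each requested lag, collect every frame pair whose timestep labels differ by exactly that lag. The helper name exact_lag_pairs is hypothetical, and the sketch ignores the cap, sample and seed parameters, whose exact semantics are not described here.

```python
import numpy as np

def exact_lag_pairs(timestep, taus):
    """Return (pair_i, pair_j, bin_id) for frame pairs with timestep[j] - timestep[i] == tau."""
    timestep = np.asarray(timestep)
    # diff[i, j] holds the physical lag between frame i and frame j.
    diff = timestep[None, :] - timestep[:, None]
    pair_i, pair_j, bin_id = [], [], []
    for b, tau in enumerate(taus):
        i, j = np.nonzero(diff == tau)      # all realizations of this lag
        pair_i.append(i)
        pair_j.append(j)
        bin_id.append(np.full(i.shape, b))  # one bin per requested lag
    return np.concatenate(pair_i), np.concatenate(pair_j), np.concatenate(bin_id)

# Non-uniform timestep labels: a lag of 20 is realized by frames (0, 2) and (2, 3).
timestep = np.array([0, 10, 20, 40])
idx_i, idx_j, idx_bin = exact_lag_pairs(timestep, taus=[10, 20])
```

A lag that never occurs on the timestep grid simply yields an empty bin in this sketch.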
- class jaxdem.analysis.LagBinsLinear(T: int, dt_min: int | None = None, dt_max: int | None = None, *, step: int = 1, num_points: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: LagBinsExact
Linearly spaced lag bins between dt_min and dt_max on the timestep grid.
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) -> LagBinsLinear
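As a rough illustration of linearly spaced lags, the sketch below builds the lag values two ways: with a fixed step, and with a fixed number of points. It assumes both endpoints are inclusive and that lags are integers on the timestep grid; how the class actually combines step and num_points is not stated here.

```python
import numpy as np

dt_min, dt_max = 2, 10

# Fixed spacing between consecutive lags (endpoint assumed inclusive).
step = 4
taus_step = np.arange(dt_min, dt_max + 1, step)

# Fixed number of lags, rounded to integers and de-duplicated.
num_points = 3
taus_points = np.unique(np.round(np.linspace(dt_min, dt_max, num_points)).astype(int))
```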
- class jaxdem.analysis.LagBinsLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, num_bins: int | None = None, num_per_decade: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: LagBinsExact
Log-spaced lag bins on the realizable timestep grid.
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) -> LagBinsLog
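One common way to obtain log-spaced lags on an integer grid is to round log-spaced values and drop duplicates, as in the sketch below. The helper log_lags is hypothetical, and snapping by rounding is an assumption; the class may select realizable lags differently.

```python
import numpy as np

def log_lags(dt_min, dt_max, num_bins):
    """Log-spaced lags between dt_min and dt_max, snapped to unique integers."""
    raw = np.logspace(np.log10(dt_min), np.log10(dt_max), num_bins)
    # Rounding can map neighbouring values to the same integer, so de-duplicate.
    return np.unique(np.round(raw).astype(int))

taus = log_lags(1, 100, 5)
```

Because of the de-duplication, the number of lags returned can be smaller than num_bins when dt_min is small.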
- class jaxdem.analysis.LagBinsPseudoLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, digits: Sequence[int] = (1, 2, 3, 4, 5, 6, 7, 8, 9), cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)
Bases: LagBinsExact
Pseudo-log lag bins using digits * powers of ten on the timestep grid.
- classmethod from_source(source: Any, *args: Any, **kwargs: Any) -> LagBinsPseudoLog
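The "digits * powers of ten" construction can be sketched as below. The helper pseudo_log_lags is hypothetical and assumes inclusive bounds; it only illustrates which lag values such a scheme produces.

```python
import numpy as np

def pseudo_log_lags(dt_min, dt_max, digits=(1, 2, 3, 4, 5, 6, 7, 8, 9)):
    """Lags of the form digit * 10**p that fall inside [dt_min, dt_max]."""
    taus = set()
    power = 1
    while power <= dt_max:
        for d in digits:
            tau = d * power
            if dt_min <= tau <= dt_max:
                taus.add(tau)
        power *= 10
    return np.array(sorted(taus), dtype=int)

# With digits (1, 2, 5) the lags follow the familiar 1-2-5 sequence.
taus = pseudo_log_lags(1, 300, digits=(1, 2, 5))
```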
- class jaxdem.analysis.Pairs(pair_i: ndarray, pair_j: ndarray, bin_id: ndarray, counts_per_bin: ndarray)
Bases: object
Flat representation of bin tuples, suitable for JAX execution.
- pair_i
Int array of shape (P,).
- Type: numpy.ndarray
- pair_j
Int array of shape (P,).
- Type: numpy.ndarray
- bin_id
Int array of shape (P,) with values in [0, B).
- Type: numpy.ndarray
- counts_per_bin
Int array of shape (B,) (number of tuples per bin).
- Type: numpy.ndarray
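The documented shapes imply a few invariants: the three per-pair arrays share length P, bin ids lie in [0, B), and counts_per_bin is the histogram of bin_id. The stand-in below (PairsSketch and make_pairs are hypothetical names, not the library class) checks them explicitly.

```python
from typing import NamedTuple

import numpy as np

class PairsSketch(NamedTuple):
    """Illustrative stand-in mirroring the documented fields."""
    pair_i: np.ndarray
    pair_j: np.ndarray
    bin_id: np.ndarray
    counts_per_bin: np.ndarray

def make_pairs(pair_i, pair_j, bin_id, num_bins):
    pair_i = np.asarray(pair_i, dtype=np.int64)
    pair_j = np.asarray(pair_j, dtype=np.int64)
    bin_id = np.asarray(bin_id, dtype=np.int64)
    if not (pair_i.shape == pair_j.shape == bin_id.shape):
        raise ValueError("pair_i, pair_j and bin_id must all have shape (P,)")
    if bin_id.size and (bin_id.min() < 0 or bin_id.max() >= num_bins):
        raise ValueError("bin_id values must lie in [0, B)")
    # Number of tuples that fall into each bin, including empty bins.
    counts = np.bincount(bin_id, minlength=num_bins)
    return PairsSketch(pair_i, pair_j, bin_id, counts)

pairs = make_pairs([0, 1, 0], [1, 2, 2], [0, 0, 2], num_bins=3)
```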
- class jaxdem.analysis.TimeBins(T: int, t_min: int | None = None, t_max: int | None = None, *, timestep: ndarray | None = None)
Bases: BinSpec
One bin per time index (optionally restricted to a timestep range).
- jaxdem.analysis.build_pairs(binspec: BinSpec) -> Pairs
Build (pair_i, pair_j, bin_id) arrays from a BinSpec.
- jaxdem.analysis.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: dict[str, Any] | None = None, jit: bool = True, chunk_size: int | None = None) -> Binned
Run a kernel over bins and average in JAX.
- Parameters:
kernel – pure function called as kernel(arrays, t0, t1, **kernel_kwargs).
arrays – mapping of field name -> array with leading time axis, e.g. pos: (T,N,d) or (T,S,N,d)
binspec – bin specification (host-side); defines which indices to use.
kernel_kwargs – passed to kernel.
jit – whether to jit the core compute.
chunk_size – optional maximum number of pairs to evaluate per chunk. When None (the default), all pairs are processed in a single jax.vmap call, identical to the previous behaviour. Set to a positive integer to process pairs in chunks via jax.lax.scan, which keeps peak device memory proportional to chunk_size rather than to the total number of pairs.
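The effect of chunk_size can be illustrated with a NumPy stand-in: per-pair kernel values are accumulated into bins either in one pass or chunk by chunk, and both give the same result. This is only a sketch under stated assumptions; evaluate_binned_sketch and lag_sq are hypothetical names, a Python loop replaces jax.vmap and jax.lax.scan, and the kernel is assumed to return a scalar.

```python
import numpy as np

def evaluate_binned_sketch(kernel, arrays, pair_i, pair_j, bin_id, num_bins, chunk_size=None):
    """Accumulate scalar kernel outputs per bin, optionally chunk by chunk."""
    num_pairs = len(pair_i)
    step = max(1, num_pairs if chunk_size is None else chunk_size)
    sums = np.zeros(num_bins)
    counts = np.zeros(num_bins)
    for start in range(0, num_pairs, step):
        stop = start + step
        # Only `step` kernel outputs are alive at once, bounding peak memory.
        vals = np.array([kernel(arrays, t0, t1)
                         for t0, t1 in zip(pair_i[start:stop], pair_j[start:stop])])
        np.add.at(sums, bin_id[start:stop], vals)
        np.add.at(counts, bin_id[start:stop], 1.0)
    mean = sums / np.maximum(counts, 1.0)  # empty bins stay at zero in this sketch
    return sums, counts, mean

def lag_sq(arrays, t0, t1):
    # Toy kernel: squared change of a scalar signal between two frames.
    return (arrays["x"][t1] - arrays["x"][t0]) ** 2

arrays = {"x": np.arange(4.0)}
pair_i = np.array([0, 1, 2, 0, 1])
pair_j = np.array([1, 2, 3, 2, 3])
bin_id = np.array([0, 0, 0, 1, 1])  # bin 0: lag 1, bin 1: lag 2

full = evaluate_binned_sketch(lag_sq, arrays, pair_i, pair_j, bin_id, num_bins=2)
chunked = evaluate_binned_sketch(lag_sq, arrays, pair_i, pair_j, bin_id, num_bins=2, chunk_size=2)
```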
- jaxdem.analysis.msd_kernel(arrays: Mapping[str, Array], t0: Any, t1: Any) -> Array
Mean-squared displacement.
Works for both:
- pos[t]: (N,d) -> returns () scalar
- pos[t]: (S,N,d) -> returns (S,) vector
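A kernel with this shape behaviour can be sketched in NumPy as follows: sum the squared displacement over the last (spatial) axis, then average over the particle axis. The function msd_sketch is a hypothetical stand-in, assuming the positions are stored under the key "pos"; the library version operates on JAX arrays.

```python
import numpy as np

def msd_sketch(arrays, t0, t1):
    """Mean-squared displacement between frames t0 and t1."""
    dr = arrays["pos"][t1] - arrays["pos"][t0]   # (N, d) or (S, N, d)
    sq = np.sum(dr ** 2, axis=-1)                # (N,) or (S, N)
    return np.mean(sq, axis=-1)                  # () or (S,)

# Two frames, two particles in 2D: one particle moves by (3, 4), the other stays put.
pos = np.zeros((2, 2, 2))
pos[1, 0] = [3.0, 4.0]
single = msd_sketch({"pos": pos}, 0, 1)

# Adding a leading sample axis S: the second sample has doubled displacements.
pos_s = np.stack([pos, 2.0 * pos], axis=1)       # (T, S, N, d)
batched = msd_sketch({"pos": pos_s}, 0, 1)
```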
- jaxdem.analysis.isf_self_isotropic_kernel(arrays: Mapping[str, Array], t0: Any, t1: Any, *, k: Any) -> Array
Self intermediate scattering function (isotropic average).
For isotropic averaging:
- 2D: Fs(k, t) = < J0(k * |dr|) >
- 3D: Fs(k, t) = < sin(k|dr|)/(k|dr|) >
Supports:
- pos[t]: (N,d) -> returns () if k scalar else (K,)
- pos[t]: (S,N,d) -> returns (S,) if k scalar else (S,K)
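The 3D formula and the documented output shapes can be reproduced with a short NumPy sketch. The function isf_self_3d is a hypothetical stand-in that takes the two position frames directly; it relies on np.sinc(x) = sin(πx)/(πx), so sin(kr)/(kr) is np.sinc(kr/π) and the zero-displacement case evaluates to 1. The 2D branch is omitted because it needs a Bessel function J0.

```python
import numpy as np

def isf_self_3d(pos0, pos1, k):
    """Isotropically averaged self-ISF in 3D: mean over particles of sin(kr)/(kr)."""
    r = np.linalg.norm(pos1 - pos0, axis=-1)     # (N,) or (S, N)
    k = np.asarray(k, dtype=float)
    kr = r[..., None] * k                        # trailing axis indexes the k values
    fs = np.mean(np.sinc(kr / np.pi), axis=-2)   # average over the particle axis
    return fs[..., 0] if k.ndim == 0 else fs

# Two particles: one displaced by pi along x, one at rest.
pos0 = np.zeros((2, 3))
pos1 = np.array([[np.pi, 0.0, 0.0], [0.0, 0.0, 0.0]])

fs_scalar = isf_self_3d(pos0, pos1, k=1.0)        # sin(pi)/pi is ~0, rest particle gives 1
fs_vector = isf_self_3d(pos0, pos1, k=[1.0, 2.0])
```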
- jaxdem.analysis.unwrap_angles_2d(q_w: Array, q_xyz: Array) -> Array
Convert a quaternion trajectory, given as q_w of shape (T, N, 1) and q_xyz of shape (T, N, 3), to an unwrapped cumulative angle of shape (T, N).
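One way such a conversion can work is sketched below in NumPy. It assumes the 2D rotation is about the z axis, so that the angle is 2·atan2(q_z, q_w), and it removes 2π jumps along the time axis with np.unwrap. The function name and these conventions are assumptions; the library implementation may differ.

```python
import numpy as np

def unwrap_angles_2d_sketch(q_w, q_xyz):
    """Quaternion trajectory (rotation about z assumed) to cumulative angle (T, N)."""
    # Angle of a z-axis rotation quaternion (cos(a/2), 0, 0, sin(a/2)).
    theta = 2.0 * np.arctan2(q_xyz[..., 2], q_w[..., 0])   # (T, N), wrapped
    # Remove jumps of multiples of 2*pi between consecutive frames.
    return np.unwrap(theta, axis=0)

# One particle rotating at 0.5 rad per frame for 20 frames (more than one full turn).
angles = 0.5 * np.arange(20)
q_w = np.cos(angles / 2.0)[:, None, None]                  # (T, 1, 1)
q_xyz = np.zeros((20, 1, 3))
q_xyz[:, 0, 2] = np.sin(angles / 2.0)                      # (T, 1, 3)

theta = unwrap_angles_2d_sketch(q_w, q_xyz)
```

The unwrapping is only reliable when the rotation between consecutive frames stays below π in magnitude.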
- jaxdem.analysis.msad_kernel_2d(arrays: dict[str, Array], t0: int, t1: int) -> Array
Mean-squared angular displacement on unwrapped cumulative angle.
- jaxdem.analysis.isf_angular_kernel_2d(arrays: dict[str, Array], t0: int, t1: int, *, theta_0: float) -> Array
Angular ISF: <cos(θ₀ · Δθ)>.
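Both angular observables reduce to simple averages over the angle increment Δθ = θ(t1) − θ(t0). The NumPy sketch below takes the unwrapped angle array directly rather than an arrays mapping, since the key name used by the library is not given here; msad_sketch and isf_angular_sketch are hypothetical names.

```python
import numpy as np

def msad_sketch(theta, t0, t1):
    """Mean-squared angular displacement over particles."""
    dtheta = theta[t1] - theta[t0]               # (N,)
    return np.mean(dtheta ** 2, axis=-1)

def isf_angular_sketch(theta, t0, t1, theta_0):
    """Angular ISF: particle average of cos(theta_0 * dtheta)."""
    dtheta = theta[t1] - theta[t0]
    return np.mean(np.cos(theta_0 * dtheta), axis=-1)

# Unwrapped angles for T = 2 frames and N = 2 particles.
theta = np.array([[0.0, 0.0],
                  [1.0, 2.0]])

msad = msad_sketch(theta, 0, 1)
fs_theta = isf_angular_sketch(theta, 0, 1, theta_0=1.0)
```

Using the unwrapped angle matters for the MSAD, which would otherwise saturate once particles complete a full turn; the cosine in the angular ISF is insensitive to 2π shifts only when theta_0 is an integer.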
Modules