jaxdem.analysis

Post-processing / analysis utilities.

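This subpackage contains a minimal, JAX-friendly binned-accumulation engine, along with kernels for common trajectory observables (mean-squared displacement, intermediate scattering functions, and their angular counterparts in 2D).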

class jaxdem.analysis.BinSpec(T: int, timestep: ndarray | None = None)

Bases: object

Abstract bin specification.

Parameters:
  • T (int) – Number of frames (time steps).

  • timestep (np.ndarray or None, optional) – Physical timestep labels of shape (T,). If None, defaults to np.arange(T).

num_bins() → int
bins() → Iterable[int]
value_of_bin(b: int) → int | float | tuple[Any, ...]
values() → ndarray
weight(b: int) → int

A cheap estimate of work for bin b (e.g., number of pairs).

iter_tuples(b: int) → Iterator[list[int]]

Yield index tuples (lists of ints) that belong to bin b.
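
To make this contract concrete, here is a minimal sketch of a duck-typed class exposing the same methods. ToyLagSpec is hypothetical and does not subclass jaxdem's BinSpec; it assumes one bin per integer lag and the default timestep labels np.arange(T).

```python
from typing import Iterable, Iterator, List

import numpy as np


class ToyLagSpec:
    """Hypothetical duck-typed bin spec (not part of jaxdem).

    Bin b holds every frame pair (t0, t0 + lag) for one integer lag,
    assuming the default timestep labels np.arange(T).
    """

    def __init__(self, T: int, lags: List[int]) -> None:
        self.T = T
        self.timestep = np.arange(T)
        # Keep only lags that fit inside T frames.
        self.lags = [int(lag) for lag in lags if 0 < lag < T]

    def num_bins(self) -> int:
        return len(self.lags)

    def bins(self) -> Iterable[int]:
        return range(self.num_bins())

    def value_of_bin(self, b: int) -> int:
        return self.lags[b]

    def values(self) -> np.ndarray:
        return np.asarray([self.value_of_bin(b) for b in self.bins()])

    def weight(self, b: int) -> int:
        # Pair count for this lag, known without enumerating the pairs.
        return self.T - self.lags[b]

    def iter_tuples(self, b: int) -> Iterator[List[int]]:
        lag = self.lags[b]
        for t0 in range(self.T - lag):
            yield [t0, t0 + lag]


spec = ToyLagSpec(T=5, lags=[1, 3])
print(spec.values())               # [1 3]
print(list(spec.iter_tuples(1)))   # [[0, 3], [1, 4]]
```

In this toy, weight(b) happens to be the exact pair count; a real bin spec only needs a cheap estimate of the work.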

class jaxdem.analysis.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)

Bases: object

Binned accumulation output.

sums

Pytree whose leaves each have shape (B, …).

Type:

Any

counts

float32 array of shape (B,).

Type:

jax.Array

mean

Pytree whose leaves each have shape (B, …).

Type:

Any

pairs

Flattened pair representation used for the run (host-side arrays).

Type:

jaxdem.analysis.pairs.Pairs

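
The relation between sums, counts and mean can be sketched as a per-leaf division along the bin axis. This assumes mean is simply sums divided by counts, with empty bins mapped to zero; the field names are hypothetical, and jaxdem's own implementation may differ.

```python
import jax
import jax.numpy as jnp

# Toy sums/counts for B = 3 bins; the field names "msd" and "isf" are hypothetical.
sums = {
    "msd": jnp.array([2.0, 6.0, 0.0]),                        # leaf shape (B,)
    "isf": jnp.array([[1.0, 0.5], [2.0, 1.0], [0.0, 0.0]]),   # leaf shape (B, K)
}
counts = jnp.array([2.0, 3.0, 0.0], dtype=jnp.float32)


def bin_mean(sums, counts):
    """Divide every (B, ...) leaf of `sums` by the (B,) `counts`."""
    # Assumption: an empty bin (count 0) should yield 0 instead of NaN.
    safe_counts = jnp.where(counts > 0, counts, 1.0)

    def divide(leaf):
        # Reshape counts to (B, 1, ..., 1) so it broadcasts over trailing axes.
        c = safe_counts.reshape((-1,) + (1,) * (leaf.ndim - 1))
        return leaf / c

    return jax.tree_util.tree_map(divide, sums)


mean = bin_mean(sums, counts)
print(mean["msd"])   # [1. 2. 0.]
```
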
class jaxdem.analysis.KernelFn(*args, **kwargs)

Bases: Protocol
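
evaluate_binned calls its kernel as kernel(arrays, t0, t1, **kernel_kwargs). A Protocol expressing that call convention might look like the following; KernelLike and mean_displacement are hypothetical names, and jaxdem's actual KernelFn definition is not reproduced here.

```python
from typing import Any, Mapping, Protocol

import jax.numpy as jnp


class KernelLike(Protocol):
    """Hypothetical structural type for kernel(arrays, t0, t1, **kernel_kwargs)."""

    def __call__(
        self, arrays: Mapping[str, Any], t0: Any, t1: Any, **kwargs: Any
    ) -> Any: ...


def mean_displacement(
    arrays: Mapping[str, Any], t0: Any, t1: Any, **kwargs: Any
) -> Any:
    """Mean Euclidean displacement of all particles between frames t0 and t1."""
    dr = arrays["pos"][t1] - arrays["pos"][t0]   # (N, d)
    return jnp.mean(jnp.linalg.norm(dr, axis=-1))


kernel: KernelLike = mean_displacement   # accepted structurally by type checkers
```

Because Protocol matching is structural, any plain function with a compatible signature satisfies it; no inheritance is needed.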

class jaxdem.analysis.LagBinsExact(T: int, taus: Sequence[int] | ndarray[Any, Any], *, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: BinSpec

Bins for a provided set of exact physical time lags (taus).

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsExact
num_bins() → int
value_of_bin(b: int) → int
weight(b: int) → int

A cheap estimate of work for bin b (e.g., number of pairs).

iter_tuples(b: int) → Iterator[list[int]]

Yield index tuples (lists of ints) that belong to bin b.
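
Matching exact lags on a possibly non-uniform timestep grid can be sketched with a binary search per frame. exact_lag_pairs is a hypothetical helper, and its cap handling (an evenly strided subset) is only an assumption about what sample='stride' does.

```python
from typing import List, Optional

import numpy as np


def exact_lag_pairs(timestep: np.ndarray, tau: int,
                    cap: Optional[int] = None) -> List[List[int]]:
    """Hypothetical sketch: all (i, j) with timestep[j] - timestep[i] == tau.

    Assumes `timestep` is strictly increasing. The `cap` handling (keep at most
    `cap` evenly strided pairs) is a guess at what sample='stride' means.
    """
    timestep = np.asarray(timestep)
    targets = timestep + tau
    # Binary search for each partner label; out-of-range hits are masked below.
    j = np.searchsorted(timestep, targets)
    j_safe = np.minimum(j, len(timestep) - 1)
    valid = (j < len(timestep)) & (timestep[j_safe] == targets)
    i = np.nonzero(valid)[0]
    pairs = [[int(a), int(b)] for a, b in zip(i, j[i])]
    if cap is not None and len(pairs) > cap:
        keep = np.rint(np.linspace(0, len(pairs) - 1, cap)).astype(int)
        pairs = [pairs[k] for k in keep]
    return pairs


ts = np.array([0, 1, 2, 4, 5, 8])     # non-uniform timestep labels
print(exact_lag_pairs(ts, 4))         # [[0, 3], [1, 4], [3, 5]]
```

Only lags that actually occur between two labels produce pairs, which is what makes the lag "exact" on an irregular grid.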

class jaxdem.analysis.LagBinsLinear(T: int, dt_min: int | None = None, dt_max: int | None = None, *, step: int = 1, num_points: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: LagBinsExact

Linearly spaced lag bins between [dt_min, dt_max] on the timestep grid.

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsLinear
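
A sketch of how linearly spaced lags could be generated, assuming timestep = np.arange(T) so that every integer lag from 1 to T - 1 is realizable. linear_lags, its defaults, and the precedence of num_points over step are assumptions, not LagBinsLinear's documented behaviour.

```python
from typing import Optional

import numpy as np


def linear_lags(T: int, dt_min: Optional[int] = None, dt_max: Optional[int] = None,
                *, step: int = 1, num_points: Optional[int] = None) -> np.ndarray:
    """Hypothetical sketch, assuming timestep = np.arange(T) (realizable lags 1..T-1)."""
    lo = 1 if dt_min is None else dt_min          # assumed default
    hi = T - 1 if dt_max is None else dt_max      # assumed default
    if num_points is None:
        lags = np.arange(lo, hi + 1, step)
    else:
        # Assumption: num_points takes precedence over step when given.
        lags = np.rint(np.linspace(lo, hi, num_points)).astype(int)
    # Drop lags that cannot be realized on T frames, then sort and de-duplicate.
    return np.unique(lags[(lags >= 1) & (lags <= T - 1)])


print(linear_lags(10, 2, 8, step=3).tolist())   # [2, 5, 8]
```
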
class jaxdem.analysis.LagBinsLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, num_bins: int | None = None, num_per_decade: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: LagBinsExact

Log-spaced lag bins on the realizable timestep grid.

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsLog
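
Log spacing on an integer grid can be sketched by rounding np.geomspace and dropping duplicates, which is why such a scheme may return fewer lags than requested at short times. log_lags and its default of 10 points per decade are hypothetical, and timestep = np.arange(T) is assumed.

```python
from typing import Optional

import numpy as np


def log_lags(T: int, dt_min: Optional[int] = None, dt_max: Optional[int] = None,
             *, num_bins: Optional[int] = None,
             num_per_decade: Optional[int] = None) -> np.ndarray:
    """Hypothetical sketch of log-spaced integer lags on timestep = np.arange(T)."""
    lo = 1 if dt_min is None else dt_min          # assumed default
    hi = T - 1 if dt_max is None else dt_max      # assumed default
    if num_bins is None:
        per_decade = 10 if num_per_decade is None else num_per_decade   # assumed default
        num_bins = int(np.ceil(np.log10(hi / lo) * per_decade)) + 1
    raw = np.geomspace(lo, hi, num_bins)
    # Snap to the integer grid; repeated small lags collapse into one.
    lags = np.unique(np.rint(raw).astype(int))
    return lags[(lags >= 1) & (lags <= T - 1)]


print(log_lags(1001, 1, 1000, num_bins=4).tolist())   # [1, 10, 100, 1000]
```
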
class jaxdem.analysis.LagBinsPseudoLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, digits: Sequence[int] = (1, 2, 3, 4, 5, 6, 7, 8, 9), cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: LagBinsExact

Pseudo-log lag bins built from each entry of digits times powers of ten (d · 10^k), on the timestep grid.

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsPseudoLog
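
The pseudo-log scheme can be sketched as enumerating d · 10^k for each d in digits and keeping the values inside [dt_min, dt_max]. pseudo_log_lags is hypothetical and assumes integer lags on the grid timestep = np.arange(T).

```python
from typing import Optional, Sequence

import numpy as np


def pseudo_log_lags(T: int, dt_min: Optional[int] = None, dt_max: Optional[int] = None,
                    *, digits: Sequence[int] = (1, 2, 3, 4, 5, 6, 7, 8, 9)) -> np.ndarray:
    """Hypothetical sketch: lags d * 10**k for d in digits, within [dt_min, dt_max]."""
    lo = 1 if dt_min is None else dt_min      # assumed default
    hi = T - 1 if dt_max is None else dt_max  # assumed default
    lags = []
    power = 1
    while power <= hi:                        # stop once 10**k alone exceeds dt_max
        for d in digits:
            if lo <= d * power <= hi:
                lags.append(d * power)
        power *= 10
    return np.unique(np.asarray(lags, dtype=int))


print(pseudo_log_lags(1001, digits=(1, 2, 5)).tolist())
# [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
```

Unlike rounded geometric spacing, these lags are "round" numbers, which keeps bin labels readable across decades.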
class jaxdem.analysis.Pairs(pair_i: ndarray, pair_j: ndarray, bin_id: ndarray, counts_per_bin: ndarray)

Bases: object

Flat representation of bin tuples, suitable for JAX execution.

pair_i

Integer array of shape (P,) holding the first index of each pair.

Type:

numpy.ndarray

pair_j

Integer array of shape (P,) holding the second index of each pair.

Type:

numpy.ndarray

bin_id

Integer array of shape (P,) with values in [0, B), giving the bin of each pair.

Type:

numpy.ndarray

counts_per_bin

Integer array of shape (B,) holding the number of tuples per bin.

Type:

numpy.ndarray

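
One way to reduce this flat layout on device is to evaluate one value per pair with a single gather and then sum by bin_id using jax.ops.segment_sum. This is a sketch with hypothetical toy arrays, not jaxdem's engine code.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical Pairs-like arrays: bin 0 = lag 1, bin 1 = lag 2, on T = 4 frames.
pair_i = np.array([0, 1, 2, 0, 1])
pair_j = np.array([1, 2, 3, 2, 3])
bin_id = np.array([0, 0, 0, 1, 1])
counts_per_bin = np.bincount(bin_id, minlength=2)     # tuples per bin: [3 2]

pos = jnp.array([0.0, 1.0, 3.0, 6.0])                 # toy 1D trajectory, one particle

# Gather both endpoints for every pair at once, then reduce by bin label.
sq_disp = (pos[pair_j] - pos[pair_i]) ** 2            # [1, 4, 9, 9, 25]
sums = jax.ops.segment_sum(sq_disp, jnp.asarray(bin_id), num_segments=2)
mean = sums / counts_per_bin                          # [14/3, 17]
```

Keeping all bins in one flat array means the whole reduction is a fixed-shape computation, which suits jit and vmap.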
class jaxdem.analysis.TimeBins(T: int, t_min: int | None = None, t_max: int | None = None, *, timestep: ndarray | None = None)

Bases: BinSpec

One bin per time index (optionally restricted to a timestep range).

classmethod from_source(source: Any, t_min: int | None = None, t_max: int | None = None) → TimeBins
num_bins() → int
value_of_bin(b: int) → int
weight(b: int) → int

A cheap estimate of work for bin b (e.g., number of pairs).

iter_tuples(b: int) → Iterator[list[int]]

Yield index tuples (lists of ints) that belong to bin b.

jaxdem.analysis.build_pairs(binspec: BinSpec) → Pairs

Build (pair_i, pair_j, bin_id) arrays from a BinSpec.
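
Flattening per-bin tuples into the Pairs layout can be sketched as a single host-side loop over bins. flatten_bins and toy_tuples are hypothetical, and the sketch assumes each tuple has exactly two indices (t0, t1).

```python
from typing import Callable, Iterable, Sequence, Tuple

import numpy as np


def flatten_bins(
    num_bins: int, iter_tuples: Callable[[int], Iterable[Sequence[int]]]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical sketch: per-bin (i, j) tuples -> flat pair_i, pair_j, bin_id, counts."""
    pair_i, pair_j, bin_id = [], [], []
    counts = np.zeros(num_bins, dtype=np.int64)
    for b in range(num_bins):
        for i, j in iter_tuples(b):     # assumes 2-element tuples
            pair_i.append(i)
            pair_j.append(j)
            bin_id.append(b)
            counts[b] += 1
    return (np.asarray(pair_i, dtype=np.int64), np.asarray(pair_j, dtype=np.int64),
            np.asarray(bin_id, dtype=np.int64), counts)


T = 4


def toy_tuples(b: int):
    # Bin b holds every (t0, t0 + lag) pair with lag = b + 1.
    lag = b + 1
    for t0 in range(T - lag):
        yield [t0, t0 + lag]


pair_i, pair_j, bin_id, counts_per_bin = flatten_bins(2, toy_tuples)
```
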

jaxdem.analysis.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: dict[str, Any] | None = None, jit: bool = True, chunk_size: int | None = None) → Binned

Run a kernel over every index pair of every bin and average the results per bin in JAX.

Parameters:
  • kernel – pure function called as kernel(arrays, t0, t1, **kernel_kwargs).

  • arrays – mapping of field name -> array with leading time axis, e.g. pos: (T,N,d) or (T,S,N,d)

  • binspec – bin specification (host-side); defines which indices to use.

  • kernel_kwargs – passed to kernel.

  • jit – whether to jit the core compute.

  • chunk_size – optional maximum number of pairs to evaluate per chunk. When None (the default), all pairs are processed in a single jax.vmap call. Set it to a positive integer to process pairs in chunks via jax.lax.scan, which keeps peak device memory proportional to chunk_size rather than to the total number of pairs.
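
The memory trade-off behind chunk_size can be sketched as follows: the vmap version evaluates every pair at once, while the chunked version pads the pairs to a multiple of chunk_size and scans over chunks, carrying only the (B,) running sums. Both helpers are hypothetical, assume a scalar-valued kernel, and are not jaxdem's internals.

```python
import jax
import jax.numpy as jnp


def binned_sums_vmap(kernel, arrays, pair_i, pair_j, bin_id, num_bins):
    # Every pair in one vmap: intermediate memory scales with the number of pairs P.
    vals = jax.vmap(lambda t0, t1: kernel(arrays, t0, t1))(pair_i, pair_j)
    return jax.ops.segment_sum(vals, bin_id, num_segments=num_bins)


def binned_sums_chunked(kernel, arrays, pair_i, pair_j, bin_id, num_bins, chunk_size):
    # Only chunk_size pairs are evaluated per scan step; the carry holds (B,) sums.
    P = pair_i.shape[0]
    num_chunks = -(-P // chunk_size)               # ceiling division
    pad = num_chunks * chunk_size - P

    def chunked(x):
        return jnp.pad(x, (0, pad)).reshape(num_chunks, chunk_size)

    valid = chunked(jnp.ones(P, dtype=bool))       # padded slots become False
    xs = (chunked(pair_i), chunked(pair_j), chunked(bin_id), valid)

    def step(sums, chunk):
        i, j, b, ok = chunk
        vals = jax.vmap(lambda t0, t1: kernel(arrays, t0, t1))(i, j)
        vals = jnp.where(ok, vals, 0.0)            # padded pairs contribute nothing
        return sums + jax.ops.segment_sum(vals, b, num_segments=num_bins), None

    sums, _ = jax.lax.scan(step, jnp.zeros(num_bins), xs)
    return sums


def sq_disp_kernel(arrays, t0, t1):
    # Scalar per pair: mean squared displacement over particles.
    dr = arrays["pos"][t1] - arrays["pos"][t0]
    return jnp.mean(jnp.sum(dr * dr, axis=-1))


arrays = {"pos": jnp.array([0.0, 1.0, 3.0, 6.0]).reshape(4, 1, 1)}   # (T, N, d)
pair_i = jnp.array([0, 1, 2, 0, 1])
pair_j = jnp.array([1, 2, 3, 2, 3])
bin_id = jnp.array([0, 0, 0, 1, 1])
full = binned_sums_vmap(sq_disp_kernel, arrays, pair_i, pair_j, bin_id, 2)
chunks = binned_sums_chunked(sq_disp_kernel, arrays, pair_i, pair_j, bin_id, 2, chunk_size=2)
```

Both paths give the same per-bin sums; the chunked one trades a sequential loop over chunks for a bounded working set.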

jaxdem.analysis.msd_kernel(arrays: Mapping[str, Array], t0: Any, t1: Any) → Array

Mean-squared displacement.

Works for both input layouts:

  • pos[t]: (N,d) -> returns a () scalar

  • pos[t]: (S,N,d) -> returns an (S,) vector
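
A minimal JAX sketch of a kernel with this shape behaviour: summing over the last axis (d) and averaging over the particle axis (N) handles both layouts without branching. msd_sketch is a hypothetical re-implementation, not the library's source.

```python
import jax.numpy as jnp


def msd_sketch(arrays, t0, t1):
    dr = arrays["pos"][t1] - arrays["pos"][t0]            # (N, d) or (S, N, d)
    return jnp.mean(jnp.sum(dr * dr, axis=-1), axis=-1)   # sum over d, mean over N


pos_single = jnp.zeros((2, 3, 2)).at[1].set(1.0)          # (T, N, d)
pos_multi = jnp.zeros((2, 4, 3, 2)).at[1].set(1.0)        # (T, S, N, d)
print(msd_sketch({"pos": pos_single}, 0, 1))              # 2.0
print(msd_sketch({"pos": pos_multi}, 0, 1).shape)         # (4,)
```
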

jaxdem.analysis.isf_self_isotropic_kernel(arrays: Mapping[str, Array], t0: Any, t1: Any, *, k: Any) → Array

Self intermediate scattering function (isotropic average).

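For isotropic averaging, with dr the displacement of each particle between t0 and t1 and < > the average over particles:

  • 2D: Fs(k, t) = < J0(k|dr|) >

  • 3D: Fs(k, t) = < sin(k|dr|) / (k|dr|) >

Supports:

  • pos[t]: (N,d) -> returns () if k is a scalar, else (K,)

  • pos[t]: (S,N,d) -> returns (S,) if k is a scalar, else (S,K)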
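
The 3D branch can be sketched with jnp.sinc, which is the normalized sinc sin(πx)/(πx), so sinc(k|dr|/π) equals sin(k|dr|)/(k|dr|) and is 1 at zero displacement. The 2D branch would need J0, which this subpackage provides in its bessel module; it is omitted here. isf_self_3d_sketch is hypothetical.

```python
import jax.numpy as jnp


def isf_self_3d_sketch(arrays, t0, t1, *, k):
    dr = arrays["pos"][t1] - arrays["pos"][t0]        # (..., N, 3)
    r = jnp.linalg.norm(dr, axis=-1)                  # (..., N)
    k = jnp.asarray(k)
    kr = r[..., None] * jnp.atleast_1d(k)             # (..., N, K)
    # sinc(kr / pi) = sin(kr) / kr, with the kr = 0 limit equal to 1.
    fs = jnp.mean(jnp.sinc(kr / jnp.pi), axis=-2)     # average over particles -> (..., K)
    return fs[..., 0] if k.ndim == 0 else fs          # drop the K axis for scalar k


pos = jnp.zeros((2, 4, 3)).at[1, :, 0].set(jnp.pi / 2)   # every particle moves pi/2 along x
print(isf_self_3d_sketch({"pos": pos}, 0, 1, k=[1.0, 2.0]).shape)   # (2,)
```
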

jaxdem.analysis.unwrap_angles_2d(q_w: Array, q_xyz: Array) → Array

Convert a quaternion trajectory, given as the scalar part q_w of shape (T, N, 1) and the vector part q_xyz of shape (T, N, 3), into the unwrapped cumulative rotation angle of shape (T, N).
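
A sketch of the unwrapping, assuming the rotation is about the z axis so that q = (cos(θ/2), 0, 0, sin(θ/2)) and q_xyz[..., 2] holds the z component: recover the wrapped angle with arctan2, then remove 2π jumps along time with jnp.unwrap. unwrap_angles_2d_sketch is hypothetical.

```python
import jax.numpy as jnp


def unwrap_angles_2d_sketch(q_w, q_xyz):
    # Assumption: planar rotation about z, so q = (cos(theta/2), 0, 0, sin(theta/2)).
    theta_wrapped = 2.0 * jnp.arctan2(q_xyz[..., 2], q_w[..., 0])   # (T, N)
    # Remove the 2*pi jumps between consecutive frames along the time axis.
    return jnp.unwrap(theta_wrapped, axis=0)


# Two particles spinning at constant (hypothetical) rates over 20 frames.
t = jnp.arange(20.0)[:, None]
theta_true = t * jnp.array([0.9, -0.5])        # (T, N), passes well beyond 2*pi
q_w = jnp.cos(theta_true / 2)[..., None]       # (T, N, 1)
half = jnp.sin(theta_true / 2)
q_xyz = jnp.stack([jnp.zeros_like(half), jnp.zeros_like(half), half], axis=-1)   # (T, N, 3)
theta = unwrap_angles_2d_sketch(q_w, q_xyz)
```

This only recovers the true cumulative angle if frames are sampled finely enough that each particle turns by less than π between consecutive frames.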

jaxdem.analysis.msad_kernel_2d(arrays: dict[str, Array], t0: int, t1: int) → Array

Mean-squared angular displacement computed from the unwrapped cumulative angle.
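
Given an unwrapped angle array, such a kernel reduces to a particle average of the squared angle change. The "theta" key is a hypothetical name for the (T, N) output of unwrap_angles_2d; the real kernel's expected keys are not documented here.

```python
import jax.numpy as jnp


def msad_sketch(arrays, t0, t1):
    dtheta = arrays["theta"][t1] - arrays["theta"][t0]   # (N,) angle change per particle
    return jnp.mean(dtheta ** 2)


theta = jnp.array([[0.0, 0.0], [1.0, 3.0]])   # (T, N)
print(msad_sketch({"theta": theta}, 0, 1))    # 5.0
```
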

jaxdem.analysis.isf_angular_kernel_2d(arrays: dict[str, Array], t0: int, t1: int, *, theta_0: float) → Array

Angular ISF: <cos(θ₀ · Δθ)>.
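
A sketch of the angular ISF as a particle average of cos(θ₀ · Δθ). Here "theta" is a hypothetical key holding the unwrapped (T, N) angle, and isf_angular_sketch is not the library's implementation.

```python
import jax.numpy as jnp


def isf_angular_sketch(arrays, t0, t1, *, theta_0):
    dtheta = arrays["theta"][t1] - arrays["theta"][t0]   # (N,) angle change per particle
    return jnp.mean(jnp.cos(theta_0 * dtheta))


theta = jnp.array([[0.0, 0.0], [0.0, jnp.pi]])   # particle 1 still, particle 2 turns by pi
```

With theta_0 = 1 the two particles contribute cos(0) = 1 and cos(π) = -1, so the average is 0; with theta_0 = 2 both contribute 1.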

Modules

bessel

Bessel functions for JAX.

bins

engine

kernels

Kernel typing + small helpers for JAX analysis.

pairs