jaxdem.analysis

Post-processing / analysis utilities.

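This subpackage contains a minimal, JAX-friendly binned-accumulation engine. Bin specifications (BinSpec and its subclasses) define which frame-index tuples belong to each bin, build_pairs flattens those tuples into a Pairs object, and evaluate_binned runs a kernel (such as msd_kernel) on every tuple and averages the results per bin, returning a Binned result.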

class jaxdem.analysis.BinSpec(T: int, timestep: ndarray | None = None)

Bases: object

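Abstract bin specification. Subclasses decide how many bins there are, which value labels each bin, and which index tuples belong to each bin.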

Parameters:
  • T (int) – Number of frames (time steps).

  • timestep (np.ndarray or None, optional) – Physical timestep labels of shape (T,). If None, defaults to np.arange(T).

num_bins() → int
bins() → Iterable[int]
value_of_bin(b: int) → int | float | Tuple[Any, ...]
values() → ndarray
weight(b: int) → int

A cheap estimate of work for bin b (e.g., number of pairs).

iter_tuples(b: int) → Iterator[List[int]]

Yield index tuples (lists of ints) that belong to bin b.
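
To make the BinSpec contract concrete, here is a minimal pure-Python stand-in. It is a hypothetical sketch, not jaxdem code: ToyBinSpec is an illustrative name, and the rule that bin b holds the single tuple [b] is an assumption chosen for simplicity. Only the method names and the np.arange(T) timestep default come from the documentation above.

```python
from typing import Iterable, Iterator, List, Optional

import numpy as np


class ToyBinSpec:
    """Hypothetical BinSpec-like class (illustration only, not jaxdem code).

    Assumption: bin b contains exactly one index tuple, [b].
    """

    def __init__(self, T: int, timestep: Optional[np.ndarray] = None) -> None:
        self.T = T
        # Documented default: timestep labels are np.arange(T) when omitted.
        self.timestep = np.arange(T) if timestep is None else np.asarray(timestep)

    def num_bins(self) -> int:
        return self.T

    def bins(self) -> Iterable[int]:
        return range(self.num_bins())

    def value_of_bin(self, b: int) -> int:
        # The physical label reported for bin b.
        return int(self.timestep[b])

    def values(self) -> np.ndarray:
        return np.array([self.value_of_bin(b) for b in self.bins()])

    def weight(self, b: int) -> int:
        # Cheap work estimate: one tuple per bin under this toy rule.
        return 1

    def iter_tuples(self, b: int) -> Iterator[List[int]]:
        yield [b]


spec = ToyBinSpec(4, timestep=np.array([0, 10, 20, 30]))
total_work = sum(spec.weight(b) for b in spec.bins())
```

A driver can sum weight(b) over bins() to budget work before flattening any tuples; here total_work is 4.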

class jaxdem.analysis.TimeBins(T: int, t_min: int | None = None, t_max: int | None = None, *, timestep: ndarray | None = None)

Bases: BinSpec

One bin per time index (optionally restricted to a timestep range).

classmethod from_source(source: Any, t_min: int | None = None, t_max: int | None = None) → TimeBins
num_bins() → int
value_of_bin(b: int) → int
weight(b: int) → int

A cheap estimate of work for bin b (e.g., number of pairs).

iter_tuples(b: int) → Iterator[List[int]]

Yield index tuples (lists of ints) that belong to bin b.
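
TimeBins can be restricted to a timestep range. A minimal sketch of that selection follows, assuming t_min and t_max are inclusive bounds on the timestep labels (not frame indices) and that each kept frame forms its own bin holding a single-index tuple. Both points are assumptions, and time_bin_frames and iter_time_tuples are hypothetical helpers.

```python
from typing import Iterator, List, Optional, Sequence


def time_bin_frames(
    timestep: Sequence[int],
    t_min: Optional[int] = None,
    t_max: Optional[int] = None,
) -> List[int]:
    """Frame indices whose timestep label lies in [t_min, t_max].

    Assumption: bounds are inclusive labels; None leaves that side open.
    """
    lo = float("-inf") if t_min is None else t_min
    hi = float("inf") if t_max is None else t_max
    return [i for i, ts in enumerate(timestep) if lo <= ts <= hi]


def iter_time_tuples(frames: List[int], b: int) -> Iterator[List[int]]:
    # One bin per kept frame: bin b yields the single index tuple [frames[b]].
    yield [frames[b]]


frames = time_bin_frames([0, 5, 10, 15, 20], t_min=5, t_max=15)
```

Here frames is [1, 2, 3], so the restricted specification would have three bins.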

class jaxdem.analysis.LagBinsExact(T: int, taus: Sequence[int] | ndarray[Any, Any], *, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: BinSpec

Bins for a provided set of exact physical time lags (taus).

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsExact
num_bins() → int
value_of_bin(b: int) → int
weight(b: int) → int

A cheap estimate of work for bin b (e.g., number of pairs).

iter_tuples(b: int) → Iterator[List[int]]

Yield index tuples (lists of ints) that belong to bin b.
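
One plausible way to enumerate the tuples of a single exact-lag bin is to pair every frame i with the frame j whose timestep label equals timestep[i] + tau, which also works on non-uniform grids. The sketch below assumes strictly increasing timestep labels and reads cap as "keep at most cap pairs, taken at a regular stride", which is only a guess at what sample='stride' does. exact_lag_pairs is a hypothetical helper, not the jaxdem implementation.

```python
from typing import List, Optional, Sequence

import numpy as np


def exact_lag_pairs(
    timestep: Sequence[int], tau: int, cap: Optional[int] = None
) -> List[List[int]]:
    """All [i, j] with timestep[j] - timestep[i] == tau (hypothetical helper).

    Assumes strictly increasing labels and, if given, cap >= 1.
    """
    ts = np.asarray(timestep)
    target = ts + tau
    # Position where each target label would sit in the sorted grid.
    j = np.searchsorted(ts, target)
    in_range = j < len(ts)
    i_idx = np.nonzero(in_range)[0]
    j_idx = j[in_range]
    # Keep only starts whose target label actually exists on the grid.
    hit = ts[j_idx] == target[in_range]
    i_idx, j_idx = i_idx[hit], j_idx[hit]
    if cap is not None and len(i_idx) > cap:
        # Regular stride, giving at most `cap` pairs spread over the run.
        stride = -(-len(i_idx) // cap)  # ceiling division
        i_idx, j_idx = i_idx[::stride], j_idx[::stride]
    return [[int(a), int(b)] for a, b in zip(i_idx, j_idx)]
```

For timestep = [0, 1, 2, 4, 8] and tau = 4, only frames (0, 3) and (3, 4) qualify, so the helper returns [[0, 3], [3, 4]].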

class jaxdem.analysis.LagBinsLinear(T: int, dt_min: int | None = None, dt_max: int | None = None, *, step: int = 1, num_points: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: LagBinsExact

Linearly spaced lag bins between [dt_min, dt_max] on the timestep grid.

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsLinear
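
The linear spacing can be sketched as follows: build candidate lags with either a fixed step or num_points evenly spaced values, then keep only lags that some pair of frames can realize. The defaults (dt_min as the smallest realizable lag, dt_max as the largest) and the rounding of num_points values to integers are assumptions; linear_lags and realizable_lags are hypothetical helpers.

```python
from typing import Optional, Sequence

import numpy as np


def realizable_lags(timestep: Sequence[int]) -> np.ndarray:
    """Sorted unique positive differences between timestep labels."""
    ts = np.asarray(timestep)
    diffs = (ts[None, :] - ts[:, None]).ravel()
    return np.unique(diffs[diffs > 0])


def linear_lags(
    timestep: Sequence[int],
    dt_min: Optional[int] = None,
    dt_max: Optional[int] = None,
    *,
    step: int = 1,
    num_points: Optional[int] = None,
) -> np.ndarray:
    real = realizable_lags(timestep)
    lo = int(real[0]) if dt_min is None else dt_min
    hi = int(real[-1]) if dt_max is None else dt_max
    if num_points is not None:
        # Evenly spaced values, rounded onto integer lags (assumption).
        cand = np.unique(np.round(np.linspace(lo, hi, num_points)).astype(int))
    else:
        cand = np.arange(lo, hi + 1, step)
    # Drop lags that no pair of frames can realize on this grid.
    return np.intersect1d(cand, real)
```

On the grid np.arange(0, 20, 2), linear_lags(ts, 2, 10) returns [2, 4, 6, 8, 10]: the odd lags are requested with step=1 but dropped because no frame pair realizes them.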
class jaxdem.analysis.LagBinsLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, num_bins: int | None = None, num_per_decade: int | None = None, cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: LagBinsExact

Log-spaced lag bins on the realizable timestep grid.

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsLog
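
A log-spaced grid can be sketched by drawing geometrically spaced targets between dt_min and dt_max, snapping each target to the nearest realizable lag, and merging duplicates, which is why fewer bins than requested can come back at short lags. The snapping rule, the fallback of 10 points per decade, and the helper names are assumptions, not jaxdem's documented behavior.

```python
from typing import Optional, Sequence

import numpy as np


def realizable_lags(timestep: Sequence[int]) -> np.ndarray:
    """Sorted unique positive differences between timestep labels."""
    ts = np.asarray(timestep)
    diffs = (ts[None, :] - ts[:, None]).ravel()
    return np.unique(diffs[diffs > 0])


def log_lags(
    timestep: Sequence[int],
    dt_min: Optional[int] = None,
    dt_max: Optional[int] = None,
    *,
    num_bins: Optional[int] = None,
    num_per_decade: Optional[int] = None,
) -> np.ndarray:
    real = realizable_lags(timestep)
    lo = int(real[0]) if dt_min is None else dt_min
    hi = int(real[-1]) if dt_max is None else dt_max
    if num_bins is None:
        per_decade = 10 if num_per_decade is None else num_per_decade  # assumed fallback
        num_bins = max(2, int(np.ceil(per_decade * np.log10(hi / lo))) + 1)
    targets = np.geomspace(lo, hi, num_bins)
    cand = real[(real >= lo) & (real <= hi)]
    # Snap every target to its nearest realizable lag, then merge duplicates.
    nearest = np.abs(cand[None, :] - targets[:, None]).argmin(axis=1)
    return np.unique(cand[nearest])
```

For frames 0 to 1000 with num_bins=4, the targets 1, 10, 100 and 1000 are all realizable, so they come back unchanged.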
class jaxdem.analysis.LagBinsPseudoLog(T: int, dt_min: int | None = None, dt_max: int | None = None, *, digits: Sequence[int] = (1, 2, 3, 4, 5, 6, 7, 8, 9), cap: int | None = None, sample: str = 'stride', seed: int = 0, timestep: ndarray | None = None)

Bases: LagBinsExact

Pseudo-log lag bins on the timestep grid, with lags of the form d × 10^k for each d in digits (with the default digits: 1, 2, ..., 9, 10, 20, ..., 90, 100, ...).

classmethod from_source(source: Any, *args: Any, **kwargs: Any) → LagBinsPseudoLog
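
The pseudo-log grid can be enumerated directly: multiply each entry of digits by successive powers of ten, keep values inside [dt_min, dt_max], and drop lags that no frame pair realizes. This is a hedged sketch. Requiring explicit dt_min and dt_max and applying the realizability filter are assumptions, and pseudo_log_lags is a hypothetical helper.

```python
from typing import List, Sequence


def pseudo_log_lags(
    timestep: Sequence[int],
    dt_min: int,
    dt_max: int,
    digits: Sequence[int] = (1, 2, 3, 4, 5, 6, 7, 8, 9),
) -> List[int]:
    """Lags of the form digit * 10**k within [dt_min, dt_max] (sketch)."""
    # Every positive gap between two labels (labels assumed increasing).
    realizable = {b - a for i, a in enumerate(timestep) for b in timestep[i + 1:]}
    lags = set()
    power = 1
    while power <= dt_max:
        for d in digits:
            lag = d * power
            if dt_min <= lag <= dt_max and lag in realizable:
                lags.add(lag)
        power *= 10
    return sorted(lags)
```

With digits=(1, 2, 5) on frames 0 to 100, this yields [1, 2, 5, 10, 20, 50, 100].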
class jaxdem.analysis.Pairs(pair_i: ndarray, pair_j: ndarray, bin_id: ndarray, counts_per_bin: ndarray)

Bases: object

Flat representation of bin tuples, suitable for JAX execution.

Attributes:
  • pair_i (numpy.ndarray) – Shape (P,) int array, where P is the total number of tuples.

  • pair_j (numpy.ndarray) – Shape (P,) int array.

  • bin_id (numpy.ndarray) – Shape (P,) int array with values in [0, B), where B is the number of bins.

  • counts_per_bin (numpy.ndarray) – Shape (B,) int array (number of tuples per bin).
jaxdem.analysis.build_pairs(binspec: BinSpec) → Pairs

Flatten the index tuples of every bin in a BinSpec into a Pairs object (pair_i, pair_j, bin_id, counts_per_bin).
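
The flattening step can be sketched as a single host-side loop that appends each tuple's indices together with its bin number and counts tuples per bin. This assumes every tuple is a pair [i, j] and that bins are numbered 0..B-1. FlatPairs, flatten_pairs and TwoBinDemo are hypothetical names mirroring the Pairs fields, not the jaxdem source.

```python
from dataclasses import dataclass
from typing import Any, Iterator, List

import numpy as np


@dataclass
class FlatPairs:
    pair_i: np.ndarray
    pair_j: np.ndarray
    bin_id: np.ndarray
    counts_per_bin: np.ndarray


def flatten_pairs(binspec: Any) -> FlatPairs:
    """Flatten a BinSpec-like object's tuples into parallel int arrays."""
    num_bins = binspec.num_bins()
    pair_i, pair_j, bin_id = [], [], []
    counts = np.zeros(num_bins, dtype=np.int64)
    for b in range(num_bins):
        for i, j in binspec.iter_tuples(b):  # assumes 2-element tuples
            pair_i.append(i)
            pair_j.append(j)
            bin_id.append(b)
            counts[b] += 1
    return FlatPairs(
        np.asarray(pair_i, dtype=np.int64),
        np.asarray(pair_j, dtype=np.int64),
        np.asarray(bin_id, dtype=np.int64),
        counts,
    )


class TwoBinDemo:
    """Hypothetical two-bin spec used only to exercise flatten_pairs."""

    def num_bins(self) -> int:
        return 2

    def iter_tuples(self, b: int) -> Iterator[List[int]]:
        yield from ([[0, 1], [1, 2]] if b == 0 else [[0, 2]])


pairs = flatten_pairs(TwoBinDemo())
```

Because every bin ends up in one flat array of length P, a kernel can be evaluated over all pairs in one vectorized pass, and counts_per_bin always sums to P.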

class jaxdem.analysis.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)

Bases: object

Binned accumulation output.

Attributes:
  • sums (Any) – Pytree with each leaf shaped (B, …).

  • counts (jax.Array) – Array of shape (B,), dtype float32.

  • mean (Any) – Pytree with each leaf shaped (B, …).

  • pairs (jaxdem.analysis.pairs.Pairs) – Flattened pair representation used for the run (host arrays).
jaxdem.analysis.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: Dict[str, Any] | None = None, jit: bool = True) → Binned

Evaluate a kernel on the index tuples of every bin and average the results per bin in JAX.

Parameters:
  • kernel (callable) – Pure function called as kernel(arrays, t0, t1, **kernel_kwargs).

  • arrays (Mapping[str, Any]) – Mapping of field name to array with leading time axis, e.g. pos: (T, N, d) or pos: (T, S, N, d).

  • binspec (BinSpec) – Bin specification (host-side); defines which indices to use.

  • kernel_kwargs (dict or None, optional) – Extra keyword arguments passed to kernel.

  • jit (bool, optional) – Whether to JIT-compile the core compute.
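
The accumulation described above can be sketched in NumPy: evaluate the kernel once per flattened pair, scatter-add the results into per-bin sums by bin_id, count tuples per bin, and divide. This is an illustrative reduction under stated assumptions (a single array output instead of a pytree, empty bins reported as 0), and binned_mean is a hypothetical name, not the jaxdem engine. In JAX the per-pair loop would typically become jax.vmap and the scatter-add jax.ops.segment_sum.

```python
from typing import Any, Callable, Mapping, Tuple

import numpy as np


def binned_mean(
    kernel: Callable[..., Any],
    arrays: Mapping[str, Any],
    pair_i: np.ndarray,
    pair_j: np.ndarray,
    bin_id: np.ndarray,
    num_bins: int,
    **kernel_kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # One kernel call per flattened (t0, t1) pair; results share a leading P axis.
    values = np.stack([np.asarray(kernel(arrays, t0, t1, **kernel_kwargs))
                       for t0, t1 in zip(pair_i, pair_j)])
    sums = np.zeros((num_bins,) + values.shape[1:], dtype=values.dtype)
    np.add.at(sums, bin_id, values)  # unbuffered scatter-add by bin
    counts = np.bincount(bin_id, minlength=num_bins).astype(np.float32)
    # Broadcast counts over trailing axes; clamp to 1 so empty bins give 0, not NaN.
    denom = np.maximum(counts, 1.0).reshape((num_bins,) + (1,) * (values.ndim - 1))
    return sums, counts, sums / denom


arrays = {"x": np.array([0.0, 1.0, 3.0, 6.0])}
sums, counts, mean = binned_mean(
    lambda a, t0, t1: a["x"][t1] - a["x"][t0],
    arrays,
    pair_i=np.array([0, 1, 0]),
    pair_j=np.array([1, 2, 2]),
    bin_id=np.array([0, 0, 1]),
    num_bins=3,  # bin 2 is left empty on purpose
)
```

Here bin 0 averages the displacements 1 and 2 to 1.5, bin 1 holds the single value 3, and the empty bin 2 reports a mean of 0 with a count of 0.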

class jaxdem.analysis.KernelFn(*args, **kwargs)

Bases: Protocol
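
KernelFn is listed only as a typing.Protocol, and its exact signature is not shown here. Going by the call convention documented for evaluate_binned, kernel(arrays, t0, t1, **kernel_kwargs), a compatible protocol and kernel might look like the sketch below. KernelLike, mean_displacement and its axis keyword are hypothetical names.

```python
from typing import Any, Mapping, Protocol


class KernelLike(Protocol):
    """Hypothetical protocol for kernels called as kernel(arrays, t0, t1)."""

    def __call__(self, arrays: Mapping[str, Any], t0: Any, t1: Any) -> Any:
        ...


def mean_displacement(arrays: Mapping[str, Any], t0: int, t1: int, *, axis: int = 0) -> float:
    """Average displacement of field "pos" along one axis between two frames."""
    p0, p1 = arrays["pos"][t0], arrays["pos"][t1]
    return sum(b[axis] - a[axis] for a, b in zip(p0, p1)) / len(p0)


# Extra options have defaults, so the function still matches the 3-argument protocol.
kernel: KernelLike = mean_displacement
pos = [[[0.0, 0.0], [1.0, 1.0]], [[2.0, 0.0], [3.0, 3.0]]]  # (T=2, N=2, d=2)
kernel_kwargs = {"axis": 1}
result = mean_displacement({"pos": pos}, 0, 1, **kernel_kwargs)
```

Giving every keyword option a default keeps a kernel callable with just (arrays, t0, t1) while still allowing kernel_kwargs to tune it; here result is 1.0.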

jaxdem.analysis.msd_kernel(arrays: Mapping[str, Array], t0: Any, t1: Any) → Array

Mean-squared displacement.

Works for both layouts:
  • pos[t] of shape (N, d) returns a scalar of shape ().

  • pos[t] of shape (S, N, d) returns a vector of shape (S,).
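
In NumPy terms, this shape behavior follows from summing the squared displacement over the last (spatial) axis and then averaging over particles, so any leading system axis S survives. The sketch assumes the kernel reads the field name "pos", as in the evaluate_binned example; msd_numpy is a hypothetical stand-in rather than the jaxdem source.

```python
from typing import Any, Mapping

import numpy as np


def msd_numpy(arrays: Mapping[str, Any], t0: int, t1: int) -> np.ndarray:
    """Mean-squared displacement between frames t0 and t1 (sketch)."""
    disp = np.asarray(arrays["pos"][t1]) - np.asarray(arrays["pos"][t0])
    # Sum over the d spatial components, then average over the N particles.
    return np.asarray(np.mean(np.sum(disp ** 2, axis=-1), axis=-1))


pos = np.zeros((2, 2, 3))  # (T, N, d)
pos[1] = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
single = msd_numpy({"pos": pos}, 0, 1)  # squared displacements 1 and 4
batched = msd_numpy({"pos": np.stack([pos, 2 * pos], axis=1)}, 0, 1)  # (T, S, N, d)
```

Here single is a 0-d array equal to 2.5 (the mean of 1 and 4), and batched has shape (2,) with values 2.5 and 10.0.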

Modules

  • bessel – Bessel functions for JAX.

  • bins

  • engine

  • kernels – Kernel typing + small helpers for JAX analysis.

  • pairs