jaxdem.analysis.engine

Functions

evaluate_binned(kernel, arrays, binspec, *)

Run a kernel over bins and average in JAX.

Classes

Binned(sums, counts, mean, pairs)

Binned accumulation output.

class jaxdem.analysis.engine.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)

Bases: object

Binned accumulation output.

sums

pytree with each leaf shaped (B, …)

Type:

Any

counts

float32 array of shape (B,)

Type:

jax.Array

mean

pytree with each leaf shaped (B, …)

Type:

Any

pairs

flattened pair representation used for the run (host arrays)

Type:

jaxdem.analysis.pairs.Pairs

sums: Any
counts: Array
mean: Any
pairs: Pairs
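
The relationship between the three numeric fields can be illustrated with a minimal sketch in plain JAX. It does not call jaxdem. The leaf names `msd` and `disp`, the helper `leaf_mean`, and the choice to report 0 for empty bins are assumptions made for illustration, not documented behaviour of `Binned`.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-bin sums for B = 3 bins: a pytree whose leaves are shaped (B, ...).
sums = {
    "msd": jnp.array([0.0, 6.0, 8.0]),   # leaf shaped (B,)
    "disp": jnp.full((3, 2), 4.0),       # leaf shaped (B, 2)
}
counts = jnp.array([0.0, 3.0, 4.0], dtype=jnp.float32)  # shape (B,)


def leaf_mean(leaf):
    # Reshape counts to (B, 1, ..., 1) so it broadcasts against a (B, ...) leaf.
    c = counts.reshape((-1,) + (1,) * (leaf.ndim - 1))
    # Assumption: empty bins yield 0 instead of NaN.
    return jnp.where(c > 0, leaf / jnp.maximum(c, 1.0), 0.0)


# Apply the same per-bin division to every leaf of the pytree.
mean = jax.tree_util.tree_map(leaf_mean, sums)
```

Each leaf of `mean` keeps the shape of the matching leaf of `sums`, which is consistent with both attributes being described as pytrees with leaves shaped (B, …).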
jaxdem.analysis.engine.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: dict[str, Any] | None = None, jit: bool = True, chunk_size: int | None = None) -> Binned

Run a kernel over bins and average in JAX.

Parameters:
  • kernel – pure function called as kernel(arrays, t0, t1, **kernel_kwargs).

  • arrays – mapping of field name -> array with leading time axis, e.g. pos: (T,N,d) or (T,S,N,d)

  • binspec – bin specification (host-side); defines which indices to use.

  • kernel_kwargs – passed to kernel.

  • jit – whether to JIT-compile the core computation.

  • chunk_size – optional maximum number of pairs to evaluate per chunk. When None (the default), all pairs are processed in a single jax.vmap call, identical to the previous behaviour. Set it to a positive integer to process pairs in chunks via jax.lax.scan, which keeps peak device memory proportional to chunk_size rather than to the total number of pairs.
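
The difference between the two evaluation modes described for chunk_size can be sketched in plain JAX. This is not the jaxdem implementation and does not use BinSpec. The names `msd_kernel`, `binned_mean`, `bin_id` and `num_bins` are hypothetical, and the sketch assumes a kernel that returns one scalar per (t0, t1) pair.

```python
import numpy as np
import jax
import jax.numpy as jnp


def msd_kernel(arrays, t0, t1):
    # Hypothetical kernel: mean squared displacement between frames t0 and t1.
    d = arrays["pos"][t1] - arrays["pos"][t0]
    return jnp.mean(jnp.sum(d * d, axis=-1))


def binned_mean(kernel, arrays, t0, t1, bin_id, num_bins, chunk_size=None):
    per_pair = lambda a, b: kernel(arrays, a, b)
    seg = lambda x, ids: jax.ops.segment_sum(x, ids, num_segments=num_bins)
    if chunk_size is None:
        # Single vmap over all pairs: memory grows with the number of pairs.
        vals = jax.vmap(per_pair)(t0, t1)
        sums, counts = seg(vals, bin_id), seg(jnp.ones_like(vals), bin_id)
    else:
        # Pad to a multiple of chunk_size; padded pairs get weight 0.
        n = t0.shape[0]
        pad = (-n) % chunk_size
        split = lambda x: jnp.pad(x, (0, pad)).reshape(-1, chunk_size)
        xs = (split(t0), split(t1), split(bin_id), split(jnp.ones(n, jnp.float32)))

        def step(carry, chunk):
            s, c = carry
            a, b, ids, w = chunk
            vals = jax.vmap(per_pair)(a, b) * w  # only chunk_size pairs are live
            return (s + seg(vals, ids), c + seg(w, ids)), None

        init = (jnp.zeros(num_bins), jnp.zeros(num_bins))
        (sums, counts), _ = jax.lax.scan(step, init, xs)
    return sums / jnp.maximum(counts, 1.0), counts


# Toy trajectory pos: (T, N, d), where every coordinate at frame t equals t.
T, N, d = 6, 4, 3
arrays = {"pos": jnp.arange(T, dtype=jnp.float32)[:, None, None] * jnp.ones((T, N, d))}

# Host-side pair list: all t0 < t1, binned by lag (bin k holds lag k + 1).
i0, i1 = np.triu_indices(T, k=1)
t0 = jnp.asarray(i0, dtype=jnp.int32)
t1 = jnp.asarray(i1, dtype=jnp.int32)
bin_id = t1 - t0 - 1

mean_full, counts_full = binned_mean(msd_kernel, arrays, t0, t1, bin_id, T - 1)
mean_chunked, counts_chunked = binned_mean(
    msd_kernel, arrays, t0, t1, bin_id, T - 1, chunk_size=4
)
```

Both calls produce the same per-bin means and counts. In the chunked path only chunk_size kernel evaluations are materialised at a time, at the cost of running the chunks sequentially inside the scan.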