jaxdem.analysis.engine
Functions
- evaluate_binned: Run a kernel over bins and average in JAX.
Classes
- Binned: Binned accumulation output.
- class jaxdem.analysis.engine.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)
Bases: object
Binned accumulation output.
- sums
pytree with each leaf shaped (B, …)
- Type: Any
- counts
array of shape (B,), dtype float32
- Type: jax.Array
- mean
pytree with each leaf shaped (B, …)
- Type: Any
- pairs
flattened pair representation used for the run (host arrays)
- Type: Pairs
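The attribute descriptions suggest that mean is sums divided by counts, bin by bin, although this page does not state it outright. Assuming that relationship, the per-bin division over a pytree whose leaves are shaped (B, …) can be sketched as below. mean_from_sums is a hypothetical helper for illustration, not part of jaxdem.

```python
import jax
import jax.numpy as jnp


def mean_from_sums(sums, counts):
    # Hypothetical helper (not jaxdem API). Assumption: Binned.mean == sums / counts.
    # Divides every leaf of shape (B, ...) by counts of shape (B,).
    # Empty bins (sum 0, count 0) come out as NaN.
    def _div(leaf):
        # Reshape counts from (B,) to (B, 1, ..., 1) so it broadcasts over trailing axes.
        c = counts.reshape(counts.shape + (1,) * (leaf.ndim - 1))
        return leaf / c

    return jax.tree_util.tree_map(_div, sums)


# Two bins, one scalar-per-bin leaf and one vector-per-bin leaf.
sums = {"msd": jnp.array([2.0, 9.0]), "vec": jnp.array([[1.0, 2.0], [3.0, 6.0]])}
counts = jnp.array([2.0, 3.0], dtype=jnp.float32)
mean = mean_from_sums(sums, counts)
# mean["msd"] -> [1.0, 3.0]; mean["vec"] -> [[0.5, 1.0], [1.0, 2.0]]
```

Keeping sums and counts alongside mean means results from several runs could be merged by adding sums and counts before dividing, rather than averaging means.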
- jaxdem.analysis.engine.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: dict[str, Any] | None = None, jit: bool = True, chunk_size: int | None = None) -> Binned
Run a kernel over bins and average in JAX.
- Parameters:
kernel – pure function called as kernel(arrays, t0, t1, **kernel_kwargs).
arrays – mapping from field name to an array with a leading time axis, e.g. pos with shape (T, N, d) or (T, S, N, d).
binspec – bin specification (host-side); defines which indices to use.
kernel_kwargs – passed to kernel.
jit – whether to JIT-compile the core computation.
chunk_size – optional maximum number of pairs to evaluate per chunk. When None (the default), all pairs are processed in a single jax.vmap call, identical to the previous behaviour. Set it to a positive integer to process pairs in chunks via jax.lax.scan, which keeps peak device memory proportional to chunk_size rather than to the total number of pairs.
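The kernel call convention and the chunk_size trade-off can be sketched in plain JAX. This is an illustrative sketch, not jaxdem's implementation: msd_kernel and binned_sums are hypothetical, the flat arrays t0s, t1s and bin_ids stand in for whatever pair representation a BinSpec produces, and the kernel is assumed to return a scalar (the documented API allows pytree outputs).

```python
import jax
import jax.numpy as jnp


def msd_kernel(arrays, t0, t1):
    # Hypothetical kernel following the documented convention
    # kernel(arrays, t0, t1, **kernel_kwargs); assumes arrays["pos"] is (T, N, d).
    pos = arrays["pos"]
    disp = pos[t1] - pos[t0]                    # (N, d) displacement
    return jnp.mean(jnp.sum(disp**2, axis=-1))  # scalar mean squared displacement


def binned_sums(kernel, arrays, t0s, t1s, bin_ids, num_bins, chunk_size=None):
    # Hypothetical helper (not jaxdem API): evaluate kernel on every (t0, t1)
    # pair and sum the results per bin. Assumes a scalar-valued kernel.
    per_pair = jax.vmap(lambda a, b: kernel(arrays, a, b))

    if chunk_size is None:
        # Single vmap over all P pairs: peak memory scales with P.
        vals = per_pair(t0s, t1s)
        return jax.ops.segment_sum(vals, bin_ids, num_segments=num_bins)

    # Pad P up to a multiple of chunk_size so every scan step has the same shape;
    # padded pairs reuse index 0 and get weight 0, so they contribute nothing.
    num_pairs = t0s.shape[0]
    n_chunks = -(-num_pairs // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - num_pairs
    weight = (jnp.arange(n_chunks * chunk_size) < num_pairs).astype(jnp.float32)
    t0c, t1c, bc = (
        jnp.pad(x, (0, pad)).reshape(n_chunks, chunk_size) for x in (t0s, t1s, bin_ids)
    )
    wc = weight.reshape(n_chunks, chunk_size)

    def step(acc, chunk):
        a, b, ids, w = chunk
        vals = per_pair(a, b) * w  # only chunk_size pairs are evaluated per step
        return acc + jax.ops.segment_sum(vals, ids, num_segments=num_bins), None

    init = jnp.zeros(num_bins, dtype=jnp.float32)
    sums, _ = jax.lax.scan(step, init, (t0c, t1c, bc, wc))
    return sums


# Usage: every particle moves by (1, 1) per frame; bins group pairs by lag.
T, N, d = 5, 3, 2
pos = jnp.cumsum(jnp.ones((T, N, d)), axis=0)  # pos[t] == t + 1 in every component
arrays = {"pos": pos}
t0s = jnp.array([0, 1, 2, 0, 1])
t1s = jnp.array([1, 2, 3, 2, 3])
bin_ids = t1s - t0s - 1                        # bin 0: lag 1, bin 1: lag 2
full = binned_sums(msd_kernel, arrays, t0s, t1s, bin_ids, num_bins=2)
chunked = binned_sums(msd_kernel, arrays, t0s, t1s, bin_ids, num_bins=2, chunk_size=2)
# full and chunked are both [6.0, 16.0]
```

Padding the last chunk with zero-weight pairs keeps every scan step the same shape, which jax.lax.scan requires; the cost is at most chunk_size - 1 wasted kernel evaluations.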