jaxdem.analysis.engine
Functions
- evaluate_binned(): Run a kernel over bins and average in JAX.
- Deprecated alias for evaluate_binned().
Classes
- Binned: Binned accumulation output.
- class jaxdem.analysis.engine.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)
Bases: object
Binned accumulation output.
- sums
Pytree with each leaf shaped (B, …).
- Type: Any
- counts
Array of shape (B,), dtype float32.
- Type: jax.Array
- mean
Pytree with each leaf shaped (B, …).
- Type: Any
- pairs
Flattened pair representation used for the run (host arrays).
- sums: Any
- counts: Array
- mean: Any
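The shape notes above (leaves shaped (B, …), counts shaped (B,)) suggest how a per-bin mean can be derived from sums and counts. The following is a minimal sketch of that relationship and not the library's implementation: the leaf names, the sample values, and the choice to return 0 for empty bins are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: B = 3 bins, two leaves with different trailing shapes.
sums = {
    "msd": jnp.array([2.0, 6.0, 0.0]),                         # shape (B,)
    "disp": jnp.array([[1.0, 2.0], [3.0, 9.0], [0.0, 0.0]]),   # shape (B, 2)
}
counts = jnp.array([2.0, 3.0, 0.0], dtype=jnp.float32)         # shape (B,)


def leaf_mean(leaf_sum, counts):
    # Reshape counts to (B, 1, ..., 1) so it broadcasts against (B, ...).
    c = counts.reshape((-1,) + (1,) * (leaf_sum.ndim - 1))
    # Assumption: empty bins yield 0 instead of NaN; the library may differ.
    return jnp.where(c > 0, leaf_sum / jnp.maximum(c, 1.0), 0.0)


# Apply the same per-bin division to every leaf of the pytree.
mean = jax.tree_util.tree_map(lambda s: leaf_mean(s, counts), sums)
```

Each leaf of mean keeps the shape of the matching leaf in sums, so downstream code can treat the two pytrees interchangeably.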
- jaxdem.analysis.engine.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: Dict[str, Any] | None = None, jit: bool = True) → Binned
Run a kernel over bins and average in JAX.
- Parameters:
kernel (callable) – Pure function called as kernel(arrays, t0, t1, **kernel_kwargs).
arrays (Mapping[str, Any]) – Mapping of field name to array with leading time axis, e.g. pos: (T, N, d) or pos: (T, S, N, d).
binspec (BinSpec) – Bin specification (host-side); defines which indices to use.
kernel_kwargs (dict or None, optional) – Extra keyword arguments passed to kernel.
jit (bool, optional) – Whether to JIT-compile the core compute.
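To make the kernel calling convention concrete, here is a self-contained sketch in plain JAX of what "run a kernel over bins and average" could look like. It does not call jaxdem: msd_kernel, binned_mean, the scale keyword, and the explicit t0, t1, bin_id arrays are hypothetical stand-ins for whatever a BinSpec provides, and the accumulation via jax.ops.segment_sum is an assumed technique, not a confirmed description of the library internals.

```python
import jax
import jax.numpy as jnp


def msd_kernel(arrays, t0, t1, *, scale=1.0):
    # Hypothetical kernel following kernel(arrays, t0, t1, **kernel_kwargs):
    # mean squared displacement between frames t0 and t1.
    pos = arrays["pos"]                      # shape (T, N, d)
    dr = pos[t1] - pos[t0]                   # shape (N, d)
    return scale * jnp.mean(jnp.sum(dr * dr, axis=-1))


def binned_mean(kernel, arrays, t0, t1, bin_id, num_bins, **kernel_kwargs):
    # Evaluate the kernel once per (t0, t1) pair.
    values = jax.vmap(lambda a, b: kernel(arrays, a, b, **kernel_kwargs))(t0, t1)
    # Accumulate per-bin sums and counts, then average.
    sums = jax.ops.segment_sum(values, bin_id, num_segments=num_bins)
    ones = jnp.ones_like(values, dtype=jnp.float32)
    counts = jax.ops.segment_sum(ones, bin_id, num_segments=num_bins)
    mean = sums / jnp.maximum(counts, 1.0)   # guard against empty bins
    return sums, counts, mean


# Toy trajectory: T = 4 frames, N = 2 particles, d = 1; every particle sits at x = t.
arrays = {"pos": jnp.arange(4.0)[:, None, None] * jnp.ones((4, 2, 1))}

# All frame pairs with t0 < t1, binned by lag (lag 1 -> bin 0, and so on).
t0 = jnp.array([0, 1, 2, 0, 1, 0])
t1 = jnp.array([1, 2, 3, 2, 3, 3])
bin_id = t1 - t0 - 1

# Optional JIT of the core compute; arrays and the kernel are closed over.
run = jax.jit(lambda a, b, c: binned_mean(msd_kernel, arrays, a, b, c, 3))
sums, counts, mean = run(t0, t1, bin_id)
```

In this toy setup the displacement between two frames equals the lag, so each bin's mean is the squared lag. Keeping the kernel pure and indexed only by (t0, t1) is what lets it be vectorized with jax.vmap and compiled with jax.jit.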