jaxdem.analysis.engine#

Functions

evaluate_binned(kernel, arrays, binspec, *)

Run a kernel over bins and average in JAX.

run_binned_jax(*args, **kwargs)

Deprecated alias for evaluate_binned().

Classes

Binned(sums, counts, mean, pairs)

Binned accumulation output.

BinnedResult

Alias of Binned.

class jaxdem.analysis.engine.Binned(sums: Any, counts: Array, mean: Any, pairs: Pairs)[source]#

Bases: object

Binned accumulation output.

sums#

Pytree of accumulated sums; each leaf has shape (B, …).

Type:

Any

counts#

Array of shape (B,) with dtype float32.

Type:

jax.Array

mean#

Pytree of means; each leaf has shape (B, …).

Type:

Any

pairs#

Flattened pair representation used for the run (host arrays).

Type:

jaxdem.analysis.pairs.Pairs

sums: Any#
counts: Array#
mean: Any#
pairs: Pairs#
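
The relationship between the three array-valued fields can be sketched in a few lines of JAX. This is an illustration only, not the library's code: it assumes that B is the number of bins, that mean is sums divided by counts leaf by leaf, and that empty bins are guarded by clamping the count to 1. The dictionary keys and the helper name divide_by_counts are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical accumulated sums for B = 3 bins: one scalar leaf and one vector leaf.
sums = {
    "scalar": jnp.array([3.0, 8.0, 0.0]),                      # shape (B,)
    "vector": jnp.array([[3.0, 6.0], [2.0, 4.0], [0.0, 0.0]]),  # shape (B, 2)
}
counts = jnp.array([3.0, 2.0, 0.0], dtype=jnp.float32)         # shape (B,)


def divide_by_counts(leaf):
    # Reshape counts to (B, 1, ..., 1) so it broadcasts against a (B, ...) leaf.
    c = counts.reshape((-1,) + (1,) * (leaf.ndim - 1))
    # Assumption: empty bins are reported as 0 rather than NaN.
    return leaf / jnp.maximum(c, 1.0)


# Apply the division to every leaf of the pytree.
mean = jax.tree_util.tree_map(divide_by_counts, sums)
```

Because every leaf keeps B as its leading axis, a single counts array of shape (B,) is enough to normalise leaves of any trailing shape.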
jaxdem.analysis.engine.evaluate_binned(kernel: Any, arrays: Mapping[str, Any], binspec: BinSpec, *, kernel_kwargs: Dict[str, Any] | None = None, jit: bool = True) → Binned[source]#

Run a kernel over bins and average in JAX.

Parameters:
  • kernel (callable) – Pure function called as kernel(arrays, t0, t1, **kernel_kwargs).

  • arrays (Mapping[str, Any]) – Mapping from field name to an array with a leading time axis, e.g. pos: (T, N, d) or pos: (T, S, N, d).

  • binspec (BinSpec) – Bin specification (host-side); defines which indices to use.

  • kernel_kwargs (dict or None, optional) – Extra keyword arguments passed to kernel. Defaults to None.

  • jit (bool, optional) – Whether to JIT-compile the core computation. Defaults to True.
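
The kernel call convention above can be illustrated with a self-contained JAX sketch. It does not call jaxdem and is not the library's implementation: it assumes each (t0, t1) index pair carries a bin index, that kernel outputs are summed per bin, and that the mean is the sum divided by the pair count. The names msd_kernel, binned_mean_sketch, and bin_id are hypothetical, and the pair and bin arrays are built by hand in place of a real BinSpec.

```python
import jax
import jax.numpy as jnp


def msd_kernel(arrays, t0, t1):
    # Hypothetical kernel with the documented call shape kernel(arrays, t0, t1):
    # mean squared displacement between frames t0 and t1.
    disp = arrays["pos"][t1] - arrays["pos"][t0]   # (N, d)
    return jnp.mean(jnp.sum(disp**2, axis=-1))     # scalar per pair


def binned_mean_sketch(kernel, arrays, t0, t1, bin_id, num_bins):
    # Evaluate the kernel once per (t0, t1) pair.
    values = jax.vmap(lambda a, b: kernel(arrays, a, b))(t0, t1)          # (P,)
    # Accumulate kernel outputs and pair counts into their bins.
    sums = jax.ops.segment_sum(values, bin_id, num_segments=num_bins)     # (B,)
    ones = jnp.ones(bin_id.shape, dtype=jnp.float32)
    counts = jax.ops.segment_sum(ones, bin_id, num_segments=num_bins)     # (B,)
    # Assumption: guard empty bins by clamping the count to 1.
    mean = sums / jnp.maximum(counts, 1.0)
    return sums, counts, mean


# Toy trajectory with shape (T, N, d) = (4, 2, 1); every particle sits at x = t.
pos = jnp.arange(4, dtype=jnp.float32)[:, None, None] * jnp.ones((1, 2, 1), dtype=jnp.float32)
arrays = {"pos": pos}

# All pairs with t0 < t1, binned by time lag (lag 1 -> bin 0, lag 2 -> bin 1, ...).
t0 = jnp.array([0, 1, 2, 0, 1, 0])
t1 = jnp.array([1, 2, 3, 2, 3, 3])
bin_id = t1 - t0 - 1

# The kernel and the number of bins are static under jit; the arrays are traced.
run = jax.jit(binned_mean_sketch, static_argnums=(0, 5))
sums, counts, mean = run(msd_kernel, arrays, t0, t1, bin_id, 3)
# mean is [1., 4., 9.]: the squared lag, since every particle moves one unit per frame.
```

In this sketch the pair indices and bin ids are plain arrays prepared ahead of time, which loosely mirrors the host-side role the documentation gives to binspec, while the per-pair kernel evaluation and the reduction run inside the compiled function.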

jaxdem.analysis.engine.run_binned_jax(*args: Any, **kwargs: Any) → Binned[source]#

Deprecated alias for evaluate_binned().

jaxdem.analysis.engine.BinnedResult#

alias of Binned