# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
from __future__ import annotations
"""Bin specifications for time-series and lagged analyses.
Minimal initial scope:
- Operates on in-memory arrays (dict of arrays with leading time axis).
- Bins are generated on the host (NumPy/Python), while compute is done in JAX.
`BinSpec.iter_tuples(b)` yields a list of integer indices.
For lag bins, this is typically `[t0, t1]`.
For time bins, this is typically `[t]`.
"""
from dataclasses import dataclass
from typing import Any, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
import numpy as np
[docs]
class BinSpec:
"""Abstract bin specification.
Parameters
----------
T : int
Number of frames (time steps).
timestep : np.ndarray or None, optional
Physical timestep labels of shape ``(T,)``. If absent,
defaults to ``np.arange(T)``.
"""
def __init__(self, T: int, timestep: Optional[np.ndarray] = None):
self.T = int(T)
if self.T < 0:
raise ValueError("T must be non-negative")
if timestep is None:
self.timestep = np.arange(self.T, dtype=np.int64)
else:
ts = np.asarray(timestep).squeeze()
if ts.ndim != 1:
raise ValueError("timestep must be a 1D array")
if int(ts.size) != self.T:
raise ValueError("timestep length must equal T")
self.timestep = ts.astype(np.int64, copy=False)
[docs]
def num_bins(self) -> int:
raise NotImplementedError
[docs]
def bins(self) -> Iterable[int]:
return range(self.num_bins())
[docs]
def value_of_bin(self, b: int) -> Union[int, float, Tuple[Any, ...]]:
return b
[docs]
def values(self) -> np.ndarray:
B = self.num_bins()
return np.asarray([self.value_of_bin(b) for b in range(B)])
[docs]
def weight(self, b: int) -> int:
"""A cheap estimate of work for bin `b` (e.g., number of pairs)."""
raise NotImplementedError
[docs]
def iter_tuples(self, b: int) -> Iterator[List[int]]:
"""Yield index tuples (lists of ints) that belong to bin `b`."""
raise NotImplementedError
def _infer_timestep_and_T_from_source(source: Any) -> Tuple[np.ndarray, int]:
"""Infer (timestep, T) from simple in-memory sources.
Supported:
- int: interpreted as T
- dict: inferred from the first array's leading dimension; optionally
uses `source["timestep"]` if present.
"""
if isinstance(source, int):
T = int(source)
return np.arange(T, dtype=np.int64), T
if isinstance(source, dict):
if "timestep" in source:
ts = np.asarray(source["timestep"]).astype(np.int64, copy=False)
if ts.ndim != 1:
raise ValueError("source['timestep'] must be 1D")
return ts, int(ts.shape[0])
if not source:
raise ValueError("Empty dict provided; cannot infer T")
first = next(iter(source.values()))
arr = np.asarray(first)
if arr.ndim == 0:
raise ValueError("Arrays must have at least one dimension to infer T")
T = int(arr.shape[0])
return np.arange(T, dtype=np.int64), T
num_frames_fn = getattr(source, "num_frames", None)
if callable(num_frames_fn):
T = int(num_frames_fn())
return np.arange(T, dtype=np.int64), T
raise TypeError(
"Unsupported source type for inferring timestep/T. "
"Provide an int, a dict of arrays, or an object with num_frames()."
)
[docs]
class TimeBins(BinSpec):
"""One bin per time index (optionally restricted to a timestep range)."""
def __init__(
self,
T: int,
t_min: Optional[int] = None,
t_max: Optional[int] = None,
*,
timestep: Optional[np.ndarray] = None,
):
super().__init__(T, timestep=timestep)
ts = self.timestep
self._time_to_index = {int(t): i for i, t in enumerate(ts)}
if t_min is None:
lo_idx = 0
else:
key = int(t_min)
if key not in self._time_to_index:
raise ValueError("t_min not found in timestep array")
lo_idx = int(self._time_to_index[key])
if t_max is None:
hi_idx = self.T - 1
else:
key = int(t_max)
if key not in self._time_to_index:
raise ValueError("t_max not found in timestep array")
hi_idx = int(self._time_to_index[key])
if self.T == 0:
self._indices = np.empty((0,), dtype=np.int64)
else:
if not (0 <= lo_idx <= hi_idx <= self.T - 1):
raise ValueError("Invalid time range for TimeBins")
self._indices = np.arange(lo_idx, hi_idx + 1, dtype=np.int64)
[docs]
@classmethod
def from_source(
cls, source: Any, t_min: Optional[int] = None, t_max: Optional[int] = None
) -> TimeBins:
timestep, T = _infer_timestep_and_T_from_source(source)
return cls(T, t_min=t_min, t_max=t_max, timestep=timestep)
[docs]
def num_bins(self) -> int:
return int(self._indices.size)
[docs]
def value_of_bin(self, b: int) -> int:
idx = int(self._indices[int(b)])
return int(self.timestep[idx])
[docs]
def weight(self, b: int) -> int:
return 1
[docs]
def iter_tuples(self, b: int) -> Iterator[List[int]]:
idx = int(self._indices[int(b)])
yield [idx]
def _deterministic_subset(
n: int,
cap: Optional[int],
method: str = "stride",
seed: int = 0,
tag: int = 0,
) -> np.ndarray:
"""Deterministic selection of indices from [0, n).
Included for completeness (useful later), but in the current minimal
implementation we typically keep `cap=None`.
"""
if cap is None or cap >= n:
return np.arange(n, dtype=np.int64)
if cap <= 0:
return np.empty(0, dtype=np.int64)
if method == "stride":
if cap == 1:
return np.array([0], dtype=np.int64)
m = int(cap - 1)
n1 = int(n - 1)
ks = np.arange(cap, dtype=np.float64)
idx = np.floor((ks * n1) / m).astype(np.int64)
if idx.size > 1:
idx = np.unique(idx)
return idx
if method == "rng":
mix = (
(np.uint64(seed) * np.uint64(0x9E3779B97F4A7C15))
^ (np.uint64(tag) * np.uint64(0xBF58476D1CE4E5B9))
) & np.uint64(0xFFFFFFFFFFFFFFFF)
rng = np.random.default_rng(int(mix))
return np.sort(rng.choice(n, size=cap, replace=False).astype(np.int64))
raise ValueError(f"Unknown sampling method: {method!r}")
def _count_pairs_for_tau(timestep: np.ndarray, tau: int) -> int:
ts = timestep
T = int(ts.shape[0])
i = 0
j = 0
cnt = 0
while i < T and j < T:
while j < T and (int(ts[j]) - int(ts[i])) < tau:
j += 1
if j >= T:
break
diff = int(ts[j]) - int(ts[i])
if diff == tau:
cnt += 1
i += 1
if j < i:
j = i
elif diff > tau:
i += 1
if j < i:
j = i
return cnt
[docs]
class LagBinsExact(BinSpec):
"""Bins for a provided set of exact physical time lags (taus)."""
def __init__(
self,
T: int,
taus: Sequence[int] | np.ndarray[Any, Any],
*,
cap: Optional[int] = None,
sample: str = "stride",
seed: int = 0,
timestep: Optional[np.ndarray] = None,
):
super().__init__(T, timestep=timestep)
taus_arr = np.asarray(taus, dtype=np.int64)
if taus_arr.size == 0:
raise ValueError("taus must be non-empty")
max_tau = int(self.timestep[-1] - self.timestep[0]) if self.T > 0 else 0
if max_tau < 1:
raise ValueError("Insufficient time range for lag bins")
taus_arr = taus_arr[(taus_arr >= 1) & (taus_arr <= max_tau)]
taus_arr = np.unique(taus_arr)
if taus_arr.size == 0:
raise ValueError("No valid integer time-lags produced; adjust parameters")
kept: List[int] = []
counts: List[int] = []
for tau in taus_arr:
c = _count_pairs_for_tau(self.timestep, int(tau))
if c > 0:
kept.append(int(tau))
counts.append(int(c))
if not kept:
raise ValueError("All tau bins are empty; adjust parameters or data")
self._taus = np.asarray(kept, dtype=np.int64)
self._pairs_per_bin = np.asarray(counts, dtype=np.int64)
self.cap = None if cap is None else int(cap)
self.sample = str(sample)
self.seed = int(seed)
[docs]
@classmethod
def from_source(
cls,
source: Any,
*args: Any,
**kwargs: Any,
) -> LagBinsExact:
if "taus" in kwargs:
taus = kwargs.pop("taus")
extra_args: Tuple[Any, ...] = args
elif args:
taus = args[0]
extra_args = args[1:]
else:
raise TypeError("from_source() missing required argument: 'taus'")
cap = kwargs.pop("cap", None)
sample = kwargs.pop("sample", "stride")
seed = kwargs.pop("seed", 0)
if extra_args or kwargs:
raise TypeError("from_source() received unexpected arguments")
timestep, T = _infer_timestep_and_T_from_source(source)
return cls(T, taus, cap=cap, sample=sample, seed=seed, timestep=timestep)
[docs]
def num_bins(self) -> int:
return int(self._taus.size)
[docs]
def value_of_bin(self, b: int) -> int:
return int(self._taus[int(b)])
[docs]
def weight(self, b: int) -> int:
pairs = int(self._pairs_per_bin[int(b)])
if pairs <= 0:
return 0
return pairs if self.cap is None else min(pairs, self.cap)
[docs]
def iter_tuples(self, b: int) -> Iterator[List[int]]:
tau = int(self._taus[int(b)])
n_pairs = int(self._pairs_per_bin[int(b)])
if n_pairs <= 0:
return
sel = _deterministic_subset(n_pairs, self.cap, self.sample, self.seed, tag=tau)
ts = self.timestep
i = 0
j = 0
k = 0 # index within all matching pairs
p = 0 # pointer into sel
sel_size = int(sel.size)
while i < self.T and j < self.T and p < sel_size:
while j < self.T and (ts[j] - ts[i]) < tau:
j += 1
if j >= self.T:
break
diff = int(ts[j] - ts[i])
if diff == tau:
if k == int(sel[p]):
yield [int(i), int(j)]
p += 1
k += 1
i += 1
if j < i:
j = i
elif diff > tau:
i += 1
if j < i:
j = i
[docs]
class LagBinsLinear(LagBinsExact):
"""Linearly spaced lag bins between [dt_min, dt_max] on the timestep grid."""
def __init__(
self,
T: int,
dt_min: Optional[int] = None,
dt_max: Optional[int] = None,
*,
step: int = 1,
num_points: Optional[int] = None,
cap: Optional[int] = None,
sample: str = "stride",
seed: int = 0,
timestep: Optional[np.ndarray] = None,
):
ts = (
np.arange(T, dtype=np.int64)
if timestep is None
else np.asarray(timestep, dtype=np.int64)
)
max_tau = int(ts[-1] - ts[0]) if T > 0 else 0
if dt_min is None:
dt_min = 1
if dt_max is None:
dt_max = max_tau
if max_tau < 1:
raise ValueError("Insufficient time range for lag bins")
if not (1 <= int(dt_min) <= int(dt_max) <= max_tau):
raise ValueError("Invalid tau range for LagBinsLinear")
if step <= 0:
raise ValueError("step must be positive")
dts = np.arange(int(dt_min), int(dt_max) + 1, int(step), dtype=np.int64)
if num_points is not None:
m = int(num_points)
if m <= 0:
raise ValueError("num_points must be positive if provided")
n = int(dts.size)
if m < n:
if m == 1:
dts = dts[[0]]
else:
ks = np.arange(m, dtype=np.float64)
idx = np.floor(ks * (n - 1) / (m - 1)).astype(np.int64)
dts = dts[idx]
super().__init__(T, dts, cap=cap, sample=sample, seed=seed, timestep=ts)
[docs]
@classmethod
def from_source(
cls,
source: Any,
*args: Any,
**kwargs: Any,
) -> LagBinsLinear:
dt_min = kwargs.pop("dt_min", None)
dt_max = kwargs.pop("dt_max", None)
step = kwargs.pop("step", 1)
num_points = kwargs.pop("num_points", None)
cap = kwargs.pop("cap", None)
sample = kwargs.pop("sample", "stride")
seed = kwargs.pop("seed", 0)
if args:
dt_min = args[0]
if len(args) > 1:
dt_max = args[1]
if len(args) > 2 or kwargs:
raise TypeError("from_source() received unexpected arguments")
timestep, T = _infer_timestep_and_T_from_source(source)
return cls(
T,
dt_min=dt_min,
dt_max=dt_max,
step=step,
num_points=num_points,
cap=cap,
sample=sample,
seed=seed,
timestep=timestep,
)
[docs]
class LagBinsLog(LagBinsExact):
"""Log-spaced lag bins on the realizable timestep grid."""
def __init__(
self,
T: int,
dt_min: Optional[int] = None,
dt_max: Optional[int] = None,
*,
num_bins: Optional[int] = None,
num_per_decade: Optional[int] = None,
cap: Optional[int] = None,
sample: str = "stride",
seed: int = 0,
timestep: Optional[np.ndarray] = None,
):
ts = (
np.arange(T, dtype=np.int64)
if timestep is None
else np.asarray(timestep, dtype=np.int64)
)
max_tau = int(ts[-1] - ts[0]) if T > 0 else 0
if max_tau < 1:
raise ValueError("Insufficient time range for lag bins")
if dt_min is None:
dt_min = 1
if dt_max is None:
dt_max = max_tau
dt_min = max(1, int(dt_min))
dt_max = min(int(dt_max), max_tau)
diffs = np.diff(ts).astype(np.int64)
u = int(np.gcd.reduce(diffs)) if diffs.size > 0 else 1
if u <= 0:
u = 1
m_min = int(np.ceil(dt_min / u))
m_max = int(np.floor(dt_max / u))
if m_max < m_min:
raise ValueError("No realizable integer lags in the requested range")
if num_bins is None and num_per_decade is None:
# Simple default: ~10 bins per decade of m
span = max(1, int(np.ceil(10 * np.log10(max(1, m_max) / max(1, m_min)))))
num_bins = span
if num_bins is not None:
xs = np.logspace(np.log10(m_min), np.log10(m_max), int(num_bins))
m_vals = np.rint(xs).astype(np.int64)
else:
assert num_per_decade is not None
lo_dec = int(np.floor(np.log10(m_min)))
hi_dec = int(np.floor(np.log10(m_max)))
m_list: List[np.ndarray] = []
for k in range(lo_dec, hi_dec + 1):
left = max(m_min, int(10**k))
right = min(m_max, int(10 ** (k + 1)))
if right < left:
continue
xs = np.logspace(
np.log10(left), np.log10(right), int(num_per_decade), endpoint=False
)
m_list.append(np.rint(xs).astype(np.int64))
m_vals = np.concatenate(m_list) if m_list else np.array([], dtype=np.int64)
m_vals = m_vals[(m_vals >= m_min) & (m_vals <= m_max)]
m_vals = np.unique(m_vals)
if m_vals.size == 0:
raise ValueError("No valid log-spaced lags produced; adjust parameters")
taus = (m_vals * u).astype(np.int64)
super().__init__(T, taus, cap=cap, sample=sample, seed=seed, timestep=ts)
[docs]
@classmethod
def from_source(
cls,
source: Any,
*args: Any,
**kwargs: Any,
) -> LagBinsLog:
dt_min = kwargs.pop("dt_min", None)
dt_max = kwargs.pop("dt_max", None)
num_bins = kwargs.pop("num_bins", None)
num_per_decade = kwargs.pop("num_per_decade", None)
cap = kwargs.pop("cap", None)
sample = kwargs.pop("sample", "stride")
seed = kwargs.pop("seed", 0)
if args:
dt_min = args[0]
if len(args) > 1:
dt_max = args[1]
if len(args) > 2 or kwargs:
raise TypeError("from_source() received unexpected arguments")
timestep, T = _infer_timestep_and_T_from_source(source)
return cls(
T,
dt_min=dt_min,
dt_max=dt_max,
num_bins=num_bins,
num_per_decade=num_per_decade,
cap=cap,
sample=sample,
seed=seed,
timestep=timestep,
)
[docs]
class LagBinsPseudoLog(LagBinsExact):
"""Pseudo-log lag bins using digits * powers of ten on the timestep grid."""
def __init__(
self,
T: int,
dt_min: Optional[int] = None,
dt_max: Optional[int] = None,
*,
digits: Sequence[int] = tuple(range(1, 10)),
cap: Optional[int] = None,
sample: str = "stride",
seed: int = 0,
timestep: Optional[np.ndarray] = None,
):
ts = (
np.arange(T, dtype=np.int64)
if timestep is None
else np.asarray(timestep, dtype=np.int64)
)
max_tau = int(ts[-1] - ts[0]) if T > 0 else 0
if max_tau < 1:
raise ValueError("Insufficient time range for lag bins")
if dt_min is None:
dt_min = 1
if dt_max is None:
dt_max = max_tau
dt_min = max(1, int(dt_min))
dt_max = min(int(dt_max), max_tau)
digits = tuple(sorted(set(int(d) for d in digits if int(d) > 0)))
if not digits:
raise ValueError("digits must contain at least one positive integer")
diffs = np.diff(ts).astype(np.int64)
u = int(np.gcd.reduce(diffs)) if diffs.size > 0 else 1
if u <= 0:
u = 1
m_min = int(np.ceil(dt_min / u))
m_max = int(np.floor(dt_max / u))
if m_max < m_min:
raise ValueError("No realizable integer lags in the requested range")
lo_dec = int(np.floor(np.log10(m_min)))
hi_dec = int(np.floor(np.log10(m_max)))
m_vals: List[int] = []
for k in range(lo_dec, hi_dec + 1):
base = 10**k
for d in digits:
m = int(d * base)
if m_min <= m <= m_max:
m_vals.append(m)
m_arr = np.array(sorted(set(m_vals)), dtype=np.int64)
if m_arr.size == 0:
raise ValueError("No pseudo-log lags produced; adjust bounds/digits")
taus = (m_arr * u).astype(np.int64)
super().__init__(T, taus, cap=cap, sample=sample, seed=seed, timestep=ts)
[docs]
@classmethod
def from_source(
cls,
source: Any,
*args: Any,
**kwargs: Any,
) -> LagBinsPseudoLog:
dt_min = kwargs.pop("dt_min", None)
dt_max = kwargs.pop("dt_max", None)
digits = kwargs.pop("digits", tuple(range(1, 10)))
cap = kwargs.pop("cap", None)
sample = kwargs.pop("sample", "stride")
seed = kwargs.pop("seed", 0)
if args:
dt_min = args[0]
if len(args) > 1:
dt_max = args[1]
if len(args) > 2 or kwargs:
raise TypeError("from_source() received unexpected arguments")
timestep, T = _infer_timestep_and_T_from_source(source)
return cls(
T,
dt_min=dt_min,
dt_max=dt_max,
digits=digits,
cap=cap,
sample=sample,
seed=seed,
timestep=timestep,
)