Source code for jaxdem.analysis.pairs
# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
from __future__ import annotations
"""Helpers for converting bins into flat index-pairs.
The JAX engine operates on a flat list of pairs (t0, t1) and a `bin_id` per pair.
"""
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
from .bins import BinSpec
[docs]
@dataclass(frozen=True)
class Pairs:
"""Flat representation of bin tuples, suitable for JAX execution.
Attributes:
pair_i: shape (P,) int array
pair_j: shape (P,) int array
bin_id: shape (P,) int array in [0, B)
counts_per_bin: shape (B,) int array (number of tuples per bin)
"""
pair_i: np.ndarray
pair_j: np.ndarray
bin_id: np.ndarray
counts_per_bin: np.ndarray
[docs]
def build_pairs(binspec: BinSpec) -> Pairs:
"""Build (pair_i, pair_j, bin_id) arrays from a BinSpec."""
B = binspec.num_bins()
pair_i_list: List[int] = []
pair_j_list: List[int] = []
bin_id_list: List[int] = []
counts = np.zeros((B,), dtype=np.int64)
for b in range(B):
cnt = 0
for idxs in binspec.iter_tuples(b):
if not idxs:
continue
i = int(idxs[0])
j = int(idxs[-1])
pair_i_list.append(i)
pair_j_list.append(j)
bin_id_list.append(int(b))
cnt += 1
counts[b] = cnt
pair_i = np.asarray(pair_i_list, dtype=np.int32)
pair_j = np.asarray(pair_j_list, dtype=np.int32)
bin_id = np.asarray(bin_id_list, dtype=np.int32)
return Pairs(pair_i=pair_i, pair_j=pair_j, bin_id=bin_id, counts_per_bin=counts)
[docs]
def flatten_pairs(binspec: BinSpec) -> Pairs:
"""Deprecated alias for build_pairs()."""
import warnings
warnings.warn(
"jaxdem.analysis.flatten_pairs is deprecated; use build_pairs instead.",
DeprecationWarning,
stacklevel=2,
)
return build_pairs(binspec)
FlatPairs = Pairs # backwards-compatible alias