Source code for jaxdem.utils.rollout_schedules

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""
Utilities to generate step indices for trajectory logging.
"""

from __future__ import annotations

from typing import Optional
import numpy as np


[docs] def make_save_steps_linear( *, num_steps: int, save_freq: int, include_step0: bool = True, ) -> np.ndarray: num_steps = int(num_steps) save_freq = int(save_freq) if num_steps < 0: raise ValueError("num_steps must be >= 0") if save_freq < 1: raise ValueError("save_freq must be >= 1") start = 0 if include_step0 else save_freq if start > num_steps: return np.zeros((0,), dtype=np.int32) return np.arange(start, num_steps + 1, save_freq, dtype=np.int32)
[docs] def make_save_steps_pseudolog( *, num_steps: int, reset_save_decade: int, min_save_decade: int, decade: int = 10, include_step0: bool = True, cap: Optional[int] = None, ) -> np.ndarray: """ Pseudo-log schedule compatible with the BaseLogGroup logic. Parameters are interpreted on the integer timestep grid 0..num_steps (inclusive). """ num_steps = int(num_steps) reset = int(reset_save_decade) f0 = int(min_save_decade) decade = int(decade) if num_steps < 0: raise ValueError("num_steps must be >= 0") if reset < 1: raise ValueError("reset_save_decade must be >= 1") if f0 < 1: raise ValueError("min_save_decade must be >= 1") if decade < 2: raise ValueError("decade must be >= 2") out: list[int] = [] max_block = num_steps // reset for b in range(max_block + 1): base = b * reset off_min = 0 if (b == 0 and include_step0) else 1 off_max = min(reset, num_steps - base) if off_min > off_max: continue k = 0 while True: freq = f0 * (decade**k) region_end = min(off_max, f0 * (decade ** (k + 1))) region_start = off_min if k == 0 else max(off_min, f0 * (decade**k) + 1) if region_start <= region_end: first = ((region_start + freq - 1) // freq) * freq for off in range(first, region_end + 1, freq): out.append(base + off) if cap is not None and len(out) >= cap: return np.asarray(sorted(set(out)), dtype=np.int32) if region_end >= off_max: break k += 1 return np.asarray(sorted(set(out)), dtype=np.int32)