Source code for jaxdem.writers.vtkWriter

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""
Implementation of the high-level VTKWriter frontend.
"""

from __future__ import annotations

import jax
import jax.numpy as jnp

import math
import os
import tempfile
import threading
from pathlib import Path
import shutil
import concurrent.futures as cf
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Dict, Set, Optional, Sequence

import numpy as np
import xml.etree.ElementTree as ET

from . import VTKBaseWriter

if TYPE_CHECKING:
    from ..state import State
    from ..system import System


def _is_safe_to_clean(path: Path) -> bool:
    """
    Return True if and only if it is safe to delete the target directory.

    Cleaning is refused when `path` resolves to:
      - the current working directory,
      - any ancestor of the current working directory, or
      - the filesystem root (or drive root on Windows).

    Parameters
    ----------
    path : Path
        Directory to test.

    Returns
    -------
    bool
        True if the path is safe to delete; False otherwise.
    """
    p = path.resolve()
    cwd = Path.cwd().resolve()

    # never nuke CWD, any parent of CWD, or the root/drive
    if p == cwd or p in cwd.parents:
        return False
    if p == Path(p.anchor):  # '/' on POSIX, 'C:\\' on Windows
        return False
    return True


[docs] @dataclass(slots=True) class VTKWriter: """ High-level front end for writing simulation data to VTK files. This class orchestrates the conversion of JAX-based :class:`jaxdem.State` and :class:`jaxdem.System` pytrees into VTK files, handling batches, trajectories, and dispatch to registered :class:`jaxdem.VTKBaseWriter` subclasses. How leading axes are interpreted -------------------------------- Let particle positions have shape ``(..., N, dim)``, where ``N`` is the number of particles and ``dim`` is 2 or 3. Define ``L = state.pos.ndim - 2``, i.e., the number of *leading* axes before ``(N, dim)``. - ``L == 0`` — single snapshot The input is one frame. It is written directly into ``frames/batch_00000000/`` (no batching, no trajectory). - ``trajectory=False`` (default) All leading axes are treated as **batch** axes (not time). If multiple batch axes are present, they are **flattened** into a single batch axis: ``(B, N, dim)`` with ``B = prod(shape[:L])``. Each batch ``b`` is written as a single snapshot under its own subdirectory ``frames/batch_XXXXXXXX/``. No trajectory is implied. - Example: ``(B, N, dim)`` → B separate directories with one frame each. - Example: ``(B1, B2, N, dim)`` → flatten to ``(B1*B2, N, dim)`` and treat as above. - ``trajectory=True`` The axis given by ``trajectory_axis`` is **swapped to the front (axis 0)** and interpreted as **time** ``T``. Any remaining leading axes are batch axes. If more than one non-time leading axis exists, they are flattened into a single batch axis so the data becomes ``(T, B, N, dim)`` with ``B = prod(other leading axes)``. - If there is only time (``L == 1``): ``(T, N, dim)`` — a single batch directory ``frames/batch_00000000/`` contains a time series with ``T`` frames. - If there is time plus batching (``L >= 2``): ``(T, B, N, dim)`` — each batch ``b`` gets its own directory ``frames/batch_XXXXXXXX/`` containing a time series (``T`` frames) for that batch. After these swaps/reshapes, dispatch is: - ``(N, dim)`` → single snapshot - ``(B, N, dim)`` → batches (no time) - ``(T, N, dim)`` → single batch with a trajectory - ``(T, B, N, dim)`` → per-batch trajectories Concrete writers receive per-frame NumPy arrays; leaves in :class:`System` are sliced/broadcast consistently with the current frame/batch. """ writers: List[str] = field(default_factory=list) """ A list of strings specifying which registered :class:`VTKBaseWriter` subclasses should be used for writing. If `None`, all available `VTKBaseWriter` subclasses will be used. """ directory: Path = Path("./frames") """ The base directory where output VTK files will be saved. Subdirectories might be created within this path for batched outputs. Defaults to "frames". """ binary: bool = True """ If :obj:`True`, VTK files will be written in binary format. If :obj:`False`, files will be written in ASCII format. Defaults to :obj:`True`. """ clean: bool = True """ If :obj:`True`, the `directory` will be completely emptied before any files are written. Defaults to :obj:`True`. This is useful for starting a fresh set of output frames. """ save_every: int = 1 """ How often to write; writes on every ``save_every``-th call to :meth:`save`. """ max_queue_size: int = 512 """ The maximum number of scheduled writes allowed. ``0`` means unbounded. """ max_workers: Optional[int] = None """ Maximum number of worker threads for the internal thread pool. """ _counter: int = 0 """ Internal counter for how many times :meth:`save` has been called. Initialized to 0. """ _pool: cf.ThreadPoolExecutor = field( default_factory=cf.ThreadPoolExecutor, init=False, repr=False ) """ Internal :class:`concurrent.futures.ThreadPoolExecutor` used for asynchronous file writing, allowing I/O operations to run in the background. """ _writer_classes: List = field(default_factory=list) """ Concrete writer classes corresponding to the names in :attr:`writers`, resolved from :class:`VTKBaseWriter`'s registry. """ _manifest: Dict = field(default_factory=dict) """ In-memory manifest of written (or scheduled) frames and metadata. Structure: {batch: {writer: {frame: {frame, time, epoch}, "_pvd_epoch": int}}} Used to prevent stale publishes and to build PVD collections. """ _pending_futures: Set[cf.Future] = field( default_factory=set, init=False, repr=False ) """ Futures representing in-flight write tasks. Drained by :meth:`block_until_ready` or when the queue is throttled via :attr:`max_queue_size`. """ _lock: threading.Lock = field( default_factory=threading.Lock, init=False, repr=False ) """ Internal lock protecting access to shared state such as :attr:`_manifest` and :attr:`_pending_futures`. """ def __post_init__(self): """ Validate configuration, clean/create the output directory, resolve writer classes from the registry, and start the thread pool. """ self.save_every = int(self.save_every) self.max_queue_size = int(self.max_queue_size) if self.max_workers: self.max_workers = int(self.max_workers) self.directory = Path(self.directory) available = list(VTKBaseWriter._registry.keys()) if not self.writers: self.writers = available unknown = [w for w in self.writers if w not in available] if unknown: raise KeyError( f"Unknown VTK writers {unknown}. " f"Available: {available}" ) if self.clean and self.directory.exists(): if _is_safe_to_clean(self.directory): shutil.rmtree(self.directory) self.directory.mkdir(parents=True, exist_ok=True) self._writer_classes = [VTKBaseWriter._registry[name] for name in self.writers] self._pool = cf.ThreadPoolExecutor(max_workers=self.max_workers)
[docs] def close(self): """ Flush all pending tasks and shut down the internal thread pool. Safe to call multiple times. """ self.block_until_ready() if self._pool is not None: self._pool.shutdown(wait=True, cancel_futures=False)
def __del__(self): """ Destructor to ensure the thread pool is shut down and pending tasks have completed before object is garbage-collected. """ try: self.close() except Exception: pass def _publish_vtp_if_latest( self, batch: str, writer: str, frame: int, epoch: int, final_path: Path, tmp_path: Path, ) -> bool: """ Publish a completed .vtp write if its epoch matches the latest known epoch for the given (batch, writer, frame). Otherwise, discard temp. Parameters ---------- batch : str Batch directory name (e.g., 'batch_00000003'). writer : str Writer key (e.g., 'spheres'). frame : int Frame number (usually system.step_count). epoch : int Epoch recorded when the task was scheduled. final_path : Path Destination file path. tmp_path : Path Temporary file path to be atomically renamed into place. Returns ------- bool True if the file was published; False if discarded due to staleness. """ current = self._current_epoch_for_vtp(batch, writer, frame) if current != epoch: try: os.remove(tmp_path) except FileNotFoundError: pass return False return self._replace_atomic(final_path, tmp_path) def _append_manifest(self, directory: Path, system) -> None: """ Record (or update) the manifest entry for the current frame/time for all writers in the given batch directory. Also updates the per-writer PVD epoch when the set of frames changes. Parameters ---------- directory : Path Target batch directory (e.g., frames/batch_00000000). system : System System snapshot providing `step_count` and `time`. """ frame = int(system.step_count) t = float(system.time) bkey = directory.name with self._lock: per_batch = self._manifest.setdefault(bkey, {}) for name in self.writers: per_writer = per_batch.setdefault(name, {}) before = set(k for k in per_writer.keys() if isinstance(k, int)) per_writer[frame] = { "frame": frame, "time": t, "epoch": self._counter, } after = before | {frame} if after != before: per_writer["_pvd_epoch"] = self._counter def _append_manifest_batch( self, directory: Path, systems: "Sequence[System]" ) -> None: """ Record manifest entries for many frames under one lock. Each System must have step_count and time populated. """ bkey = directory.name with self._lock: per_batch = self._manifest.setdefault(bkey, {}) for name in self.writers: per_writer = per_batch.setdefault(name, {}) before = {k for k in per_writer.keys() if isinstance(k, int)} for sys in systems: f = int(sys.step_count) per_writer[f] = { "frame": f, "time": float(sys.time), "epoch": self._counter, } after = {k for k in per_writer.keys() if isinstance(k, int)} if after != before: per_writer["_pvd_epoch"] = self._counter def _current_epoch_for_vtp(self, batch: str, writer: str, frame: int) -> int: """ Get the current (latest) epoch recorded for a specific VTP frame. Parameters ---------- batch : str Batch directory name. writer : str Writer key. frame : int Frame number. Returns ------- int Epoch value, or None if unknown. """ with self._lock: return ( self._manifest.get(batch, {}) .get(writer, {}) .get(frame, {}) .get("epoch", None) ) def _current_epoch_for_pvd(self, batch: str, writer: str) -> int: """ Get the current (latest) epoch recorded for a writer's PVD collection. Parameters ---------- batch : str Batch directory name. writer : str Writer key. Returns ------- int Epoch value for the PVD, or None if unknown. """ with self._lock: return self._manifest.get(batch, {}).get(writer, {}).get("_pvd_epoch", None) @staticmethod def _replace_atomic(final_path: Path, tmp_path: Path) -> bool: """ Atomically replace `final_path` with `tmp_path`. On success, returns True. On failure, attempts to delete the temporary file and re-raises the exception. Parameters ---------- final_path : Path Destination path to replace. tmp_path : Path Temporary file to move into place. Returns ------- bool True if replaced successfully. Raises ------ Exception Any exception raised by `os.replace` after temporary cleanup. """ # Optional: add retry around PermissionError on Windows if needed try: os.replace(os.fspath(tmp_path), os.fspath(final_path)) return True except Exception: try: os.remove(tmp_path) except FileNotFoundError: pass raise def _publish_pvd_if_latest( self, batch: str, writer: str, epoch: int, final_path: Path, tmp_path: Path, ) -> bool: """ Publish a completed .pvd write if its epoch matches the latest known epoch for the given (batch, writer). Otherwise, discard temp. Parameters ---------- batch : str Batch directory name. writer : str Writer key. epoch : int Epoch recorded when the task was scheduled. final_path : Path Destination .pvd file path. tmp_path : Path Temporary .pvd file path to be atomically renamed into place. Returns ------- bool True if the file was published; False if discarded due to staleness. """ current = self._current_epoch_for_pvd(batch, writer) if current != epoch: try: os.remove(tmp_path) except FileNotFoundError: pass return False return self._replace_atomic(final_path, tmp_path)
[docs] def block_until_ready(self): """ Wait until all scheduled writer tasks complete. This will wait for all pending futures, propagate exceptions (if any), and clear the pending set. """ if self._pending_futures: cf.wait(self._pending_futures) for f in self._pending_futures: f.result() self._pending_futures.clear()
[docs] def save( self, state: "State", system: "System", *, trajectory: bool = False, trajectory_axis: int = 0, batch0: int = 0, ): """ Schedule writing of a :class:`jaxdem.State` / :class:`jaxdem.System` pair to VTK files. This public entry point interprets leading axes (batch vs. trajectory), performs any required axis swapping and flattening, and then writes the resulting frames using the registered writers. It also creates per-batch ParaView ``.pvd`` collections referencing the generated files. Parameters ---------- state : State The simulation :class:`jaxdem.State` object to be saved. Its array leaves must end with ``(N, dim)``. system : System The :class:`jaxdem.System` object corresponding to `state`. Leading axes must be compatible (or broadcastable) with those of `state`. trajectory : bool, optional If ``True``, interpret ``trajectory_axis`` as time and write a trajectory; if ``False``, interpret the leading axis as batch. trajectory_axis : int, optional The axis in `state`/`system` to treat as the trajectory (time) axis when ``trajectory=True``. This axis is swapped to the front prior to writing. batch0 : in Initial value of batch from where to start counting the batches. """ if self._counter % self.save_every != 0: self._counter += 1 return Ndim = state.pos.ndim # Make sure trajectory axis is axis 0 if trajectory: state = jax.tree_util.tree_map( lambda x: jnp.swapaxes(x, trajectory_axis, 0), state ) system = jax.tree_util.tree_map( lambda x: jnp.swapaxes(x, trajectory_axis, 0), system ) # flatten all batch dimensions if Ndim >= 4: L = Ndim - 2 if trajectory: state = jax.tree_util.tree_map( lambda x: x.reshape( ( x.shape[0], math.prod(x.shape[1:L]), ) + x.shape[L:] ), state, ) system = jax.tree_util.tree_map( lambda x: x.reshape( ( x.shape[0], math.prod(x.shape[1:L]), ) + x.shape[L:] ), system, ) else: state = jax.tree_util.tree_map( lambda x: x.reshape((math.prod(x.shape[:L]),) + x.shape[L:]), state ) system = jax.tree_util.tree_map( lambda x: x.reshape((math.prod(x.shape[:L]),) + x.shape[L:]), system ) state = jax.tree_util.tree_map( lambda x: x if isinstance(x, np.ndarray) and x.flags["C_CONTIGUOUS"] else np.asarray(x, order="C"), state, ) system = jax.tree_util.tree_map( lambda x: x if isinstance(x, np.ndarray) and x.flags["C_CONTIGUOUS"] else np.asarray(x, order="C"), system, ) match state.pos.ndim: case 2: directory = self.directory / Path(f"batch_{batch0:08d}") self._append_manifest(directory, system) self._schedule_frame_writes(state, system, directory) case 3: if trajectory: T, _, _ = state.pos.shape directory = self.directory / Path(f"batch_{batch0:08d}") sys_list = [ jax.tree_util.tree_map(lambda x, i=i: x[i], system) for i in range(T) ] self._append_manifest_batch(directory, sys_list) for i in range(T): st = jax.tree_util.tree_map(lambda x: x[i], state) sys = sys_list[i] self._schedule_frame_writes(st, sys, directory) else: B, _, _ = state.pos.shape for j in range(B): st = jax.tree_util.tree_map(lambda x, j=j: x[j], state) sys = jax.tree_util.tree_map(lambda x, j=j: x[j], system) directory = self.directory / Path(f"batch_{batch0+j:08d}") self._append_manifest(directory, sys) self._schedule_frame_writes(st, sys, directory) case 4: T, B, _, _ = state.pos.shape for j in range(B): directory = self.directory / Path(f"batch_{batch0+j:08d}") sys_list = [ jax.tree_util.tree_map(lambda x, i=i, j=j: x[i, j], system) for i in range(T) ] self._append_manifest_batch(directory, sys_list) for i in range(T): st = jax.tree_util.tree_map(lambda x, i=i, j=j: x[i, j], state) sys = sys_list[i] self._schedule_frame_writes(st, sys, directory) with self._lock: manifest_snapshot = { batch: { writer: { "_pvd_epoch": info.get("_pvd_epoch", None), "frames": sorted(k for k in info.keys() if isinstance(k, int)), } for writer, info in writers.items() } for batch, writers in self._manifest.items() } for batch, writers in manifest_snapshot.items(): for writer, info in writers.items(): if ( self.max_queue_size and len(self._pending_futures) >= self.max_queue_size ): _, self._pending_futures = cf.wait( self._pending_futures, return_when=cf.FIRST_COMPLETED ) self._pending_futures.add( self._pool.submit( self._build_pvd_one, batch, writer, info["frames"], info["_pvd_epoch"], ) ) self._counter += 1
def _schedule_frame_writes(self, state_np, system_np, directory: Path): """ Queue per-writer tasks for a single frame (non-blocking). Parameters ---------- state : State State snapshot (arrays converted to NumPy for VTK). system : System System snapshot (arrays converted to NumPy for VTK). directory : Path Directory where the per-writer frame files will be written. """ directory.mkdir(parents=True, exist_ok=True) batch = directory.name frame = int(system_np.step_count) for cls, writer_name in zip(self._writer_classes, self.writers): final_path = ( directory / f"{writer_name}_{int(system_np.step_count):08d}.vtp" ) epoch = self._current_epoch_for_vtp(batch, writer_name, frame) d = final_path.parent base = final_path.name fd, tmp_path = tempfile.mkstemp( prefix=f"temp_{base}", suffix=".tmp", dir=os.fspath(d) ) os.close(fd) # let VTK open the file by path def write_one_file( tmp_path: Path = Path(tmp_path), final_path: Path = Path(final_path), state=state_np, system=system_np, binary: bool = self.binary, batch: str = batch, writer_name: str = writer_name, frame: int = frame, epoch: int = epoch, cls=cls, ) -> bool: try: cls.write(state, system, tmp_path, binary) return self._publish_vtp_if_latest( batch, writer_name, frame, epoch, final_path, tmp_path ) except Exception: try: os.remove(tmp_path) except FileNotFoundError: pass raise if ( self.max_queue_size and len(self._pending_futures) >= self.max_queue_size ): _, self._pending_futures = cf.wait( self._pending_futures, return_when=cf.FIRST_COMPLETED ) self._pending_futures.add(self._pool.submit(write_one_file)) def _build_pvd_one( self, batch: str, writer: str, frames: List[int], epoch: int, time_format: str = ".12g", ) -> None: """ Build a ParaView ``.pvd`` time-collection file for one (batch, writer). Uses the internal manifest to list frames in sorted order and to populate timesteps from recorded simulation times. Parameters ---------- batch : str Batch directory name (e.g., 'batch_00000000'). writer : str Writer key (e.g., 'spheres'). frames : List[int] Sorted frame indices to include. epoch : int Latest epoch for this writer's PVD; used to avoid stale publishes. time_format : str Python format specifier applied to the timestep when writing the XML attribute. Returns ------- None """ vtk_file_element = ET.Element( "VTKFile", type="Collection", version="0.1", byte_order="LittleEndian", ) collection_element = ET.SubElement(vtk_file_element, "Collection") with self._lock: by_writer = self._manifest.get(batch, {}).get(writer, {}) for frame in frames: t = by_writer.get(frame, {}).get("time", 0.0) name = f"{writer}_{frame:08d}.vtp" ET.SubElement( collection_element, "DataSet", timestep=format(t, time_format), file=f"{batch}/{name}", ) pvd_path = self.directory / f"{batch}_{writer}.pvd" d = pvd_path.parent base = pvd_path.name fd, tmp_path = tempfile.mkstemp( prefix=f"temp_{base}", suffix=".tmp", dir=os.fspath(d) ) os.close(fd) ET.ElementTree(vtk_file_element).write( tmp_path, encoding="utf-8", xml_declaration=True ) self._publish_pvd_if_latest(batch, writer, epoch, pvd_path, Path(tmp_path))
__all__ = ["VTKWriter"]