Source code for jaxdem.writers.vtkSpheresWriter

# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project – https://github.com/cdelv/JaxDEM
"""VTK writer that exports spheres data."""

from __future__ import annotations

from dataclasses import dataclass, fields
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import vtk
import vtk.util.numpy_support as vtk_np

from . import VTKBaseWriter

if TYPE_CHECKING:  # pragma: no cover
    from ..state import State
    from ..system import System


@VTKBaseWriter.register("spheres")
@dataclass(slots=True, frozen=True)
class VTKSpheresWriter(VTKBaseWriter):
    """
    A :class:`VTKBaseWriter` that writes particle centers as VTK points and
    attaches per-particle :class:`State` fields as ``PointData`` attributes.

    For each particle, its position is written as a point. Relevant
    per-particle fields (e.g., ``vel``, ``rad``, ``mass``) are exported as
    arrays. Positions and 2-component vectors are padded to 3D as required by
    VTK. Boolean fields are stored as ``int8``, and per-particle fields with
    more than one trailing axis (e.g. shape ``(N, 3, 3)``) are flattened into
    a single multi-component array per particle.
    """

    @staticmethod
    def _as_point_data(arr: np.ndarray) -> np.ndarray:
        """
        Convert a per-particle array into a layout ``numpy_to_vtk`` accepts.

        ``numpy_to_vtk`` only handles arrays of rank <= 2, so any trailing axes
        are flattened into components. Booleans are cast to ``int8`` and
        2-component vectors are zero-padded to 3 components.

        Parameters
        ----------
        arr : np.ndarray
            Array whose leading axis indexes particles.

        Returns
        -------
        np.ndarray
            Array of rank 1 or 2 suitable for ``numpy_to_vtk``.
        """
        if arr.dtype == np.bool_:
            arr = arr.astype(np.int8)
        if arr.ndim > 2:
            # Explicit component count (not -1) so zero-particle arrays still reshape.
            arr = arr.reshape(arr.shape[0], int(np.prod(arr.shape[1:])))
        if arr.ndim == 2 and arr.shape[1] == 2:
            arr = np.pad(arr, ((0, 0), (0, 1)), "constant")
        return arr

    @classmethod
    def write(
        cls,
        state: "State",
        system: "System",
        filename: Path,
        binary: bool,
    ):
        """
        Write ``state`` as a VTK XML PolyData file of particle centers.

        Parameters
        ----------
        state : State
            Simulation state. ``state.pos`` supplies the point coordinates.
            Every other dataclass field that is a NumPy array whose leading
            dimension equals the number of particles is attached as
            ``PointData`` under the field's name.
        system : System
            Not used by this writer.
        filename : Path
            Destination file path.
        binary : bool
            If ``True``, write zlib-compressed appended data; otherwise write
            ASCII data.

        Raises
        ------
        RuntimeError
            If the VTK writer reports a failure.
        """
        pos = state.pos
        n = pos.shape[0]
        dim = pos.shape[-1]
        if dim < 3:
            # VTK points always carry three coordinates; zero-fill the missing ones.
            pad_width = [(0, 0)] * (pos.ndim - 1) + [(0, 3 - dim)]
            pos = np.pad(pos, pad_width, "constant")

        poly = vtk.vtkPolyData()
        points = vtk.vtkPoints()
        # deep=False avoids copying; numpy_to_vtk keeps a reference to the
        # source array so the (possibly temporary) buffer stays alive.
        points.SetData(vtk_np.numpy_to_vtk(pos, deep=False))
        poly.SetPoints(points)

        point_data = poly.GetPointData()
        for fld in fields(state):
            name = fld.name
            if name == "pos":
                continue  # already exported as the point coordinates
            arr = getattr(state, name)
            # Only NumPy arrays with one entry per particle map onto PointData.
            if not (
                isinstance(arr, np.ndarray) and arr.ndim >= 1 and arr.shape[0] == n
            ):
                continue
            vtk_arr = vtk_np.numpy_to_vtk(cls._as_point_data(arr), deep=False)
            vtk_arr.SetName(name)
            point_data.AddArray(vtk_arr)

        writer = vtk.vtkXMLPolyDataWriter()
        writer.SetFileName(str(filename))
        writer.SetInputData(poly)
        if binary:
            writer.SetDataModeToAppended()
            writer.SetCompressor(vtk.vtkZLibDataCompressor())
        else:
            writer.SetDataModeToAscii()

        ok = writer.Write()
        if ok != 1:
            raise RuntimeError("VTK spheres writer failed")
# Names exported by ``from ... import *``: only the writer class is public.
__all__ = [
    "VTKSpheresWriter",
]