# SPDX-License-Identifier: BSD-3-Clause
# Part of the JaxDEM project - https://github.com/cdelv/JaxDEM
"""
Utility functions to compute thermodynamic quantities.
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
from dataclasses import replace
from typing import TYPE_CHECKING, Optional, Tuple
from functools import partial
if TYPE_CHECKING:
from ..state import State
from ..system import System
[docs]
@jax.jit
@partial(
    jax.named_call, name="thermal.compute_translational_kinetic_energy_per_particle"
)
def compute_translational_kinetic_energy_per_particle(state: State) -> jax.Array:
    r"""
    Return the translational kinetic energy attributed to each particle.

    .. math::
        E_{trans} = \frac{1}{2} m |v|^2

    Notes
    -----
    - Every particle's mass is divided by the number of particles that share
      its ``clump_ID`` before the energy is evaluated, i.e. the energy of
      clump members is divided by the number of spheres in the clump.

    Parameters
    ----------
    state : State
        The current state; ``clump_ID``, ``mass``, ``vel`` and ``N`` are read.

    Returns
    -------
    jax.Array
        One translational kinetic energy value per particle (the last axis of
        ``state.vel`` is reduced).
    """
    # Number of particles carrying each clump id, then gathered back so each
    # particle knows the size of the clump it belongs to.
    clump_sizes = jnp.bincount(state.clump_ID, length=state.N)
    members_per_particle = clump_sizes[state.clump_ID]

    effective_mass = state.mass / members_per_particle
    speed_squared = jnp.sum(state.vel * state.vel, axis=-1)
    return 0.5 * effective_mass * speed_squared
[docs]
@jax.jit
@partial(jax.named_call, name="thermal.compute_rotational_kinetic_energy_per_particle")
def compute_rotational_kinetic_energy_per_particle(state: State) -> jax.Array:
    r"""
    Return the rotational kinetic energy attributed to each particle.

    .. math::
        E_{rot} = \frac{1}{2} \vec{\omega}^T I \vec{\omega}

    Notes
    -----
    - The energy of clump members is divided by the number of spheres in the
      clump (particles sharing the same ``clump_ID``).
    - ``state.inertia`` is multiplied element-wise with the angular velocity,
      so it is presumably stored as principal (diagonal) moments - TODO confirm.

    Parameters
    ----------
    state : State
        The current state; ``clump_ID``, ``N``, ``dim``, ``angVel``,
        ``inertia`` and (when ``dim != 2``) ``q`` are read.

    Returns
    -------
    jax.Array
        One rotational kinetic energy value per particle.
    """
    members_per_particle = jnp.bincount(state.clump_ID, length=state.N)[
        state.clump_ID
    ]

    # In 2D the angular velocity is used as stored; otherwise it is rotated
    # back into the body frame so it lines up with ``state.inertia``.
    omega_body = (
        state.angVel
        if state.dim == 2
        else state.q.rotate_back(state.q, state.angVel)
    )

    angular_momentum_body = state.inertia * omega_body
    return 0.5 * jnp.vecdot(omega_body, angular_momentum_body) / members_per_particle
[docs]
@jax.jit
@partial(jax.named_call, name="thermal.compute_translational_kinetic_energy")
def compute_translational_kinetic_energy(state: State) -> jax.Array:
    r"""
    Return the total translational kinetic energy of the system.

    .. math::
        E_{trans, total} = \sum_{i} \frac{1}{2} m_i |v_i|^2

    Parameters
    ----------
    state : State
        The current state of the system.

    Returns
    -------
    jax.Array
        Sum over all entries of the per-particle translational kinetic energy.
    """
    per_particle = compute_translational_kinetic_energy_per_particle(state)
    return jnp.sum(per_particle)
[docs]
@jax.jit
@partial(jax.named_call, name="thermal.compute_rotational_kinetic_energy")
def compute_rotational_kinetic_energy(state: State) -> jax.Array:
    r"""
    Return the total rotational kinetic energy of the system.

    .. math::
        E_{rot, total} = \sum_{i} \frac{1}{2} \vec{\omega}_i^T I_i \vec{\omega}_i

    Parameters
    ----------
    state : State
        The current state of the system.

    Returns
    -------
    jax.Array
        Sum over all entries of the per-particle rotational kinetic energy.
    """
    per_particle = compute_rotational_kinetic_energy_per_particle(state)
    return jnp.sum(per_particle)
[docs]
@jax.jit
@partial(jax.named_call, name="thermal.compute_potential_energy_per_particle")
def compute_potential_energy_per_particle(state: State, system: System) -> jax.Array:
    """
    Return the potential energy attributed to each particle.

    The result is the sum of two contributions: the one reported by
    ``system.force_manager.compute_potential_energy`` and the one reported by
    ``system.collider.compute_potential_energy``.

    Parameters
    ----------
    state : State
        The current state of the system.
    system : System
        The system definition providing ``force_manager`` and ``collider``.

    Returns
    -------
    jax.Array
        Sum of both contributions; expected to hold one value per particle.
    """
    from_force_manager = system.force_manager.compute_potential_energy(state, system)
    from_collider = system.collider.compute_potential_energy(state, system)
    return from_force_manager + from_collider
[docs]
@jax.jit
@partial(jax.named_call, name="thermal.compute_potential_energy")
def compute_potential_energy(state: State, system: System) -> jax.Array:
    r"""
    Return the total potential energy of the system.

    This is the sum over all entries of
    :func:`compute_potential_energy_per_particle`.

    .. math::
        E_{pot, total} = \sum_{i} U(r_i)

    Parameters
    ----------
    state : State
        The current state of the system.
    system : System
        The system definition, forwarded to the per-particle computation.

    Returns
    -------
    jax.Array
        The scalar sum of potential energy across all particles.
    """
    per_particle = compute_potential_energy_per_particle(state, system)
    return jnp.sum(per_particle)
[docs]
@jax.jit
@partial(jax.named_call, name="thermal.compute_energy")
def compute_energy(state: State, system: System) -> jax.Array:
    """
    Return the total mechanical energy of the system.

    .. math::
        E_{total} = E_{pot, total} + E_{trans, total} + E_{rot, total}

    Parameters
    ----------
    state : State
        The current state of the system.
    system : System
        The system definition, used for the potential energy term.

    Returns
    -------
    jax.Array
        Potential plus translational plus rotational kinetic energy.
    """
    potential = compute_potential_energy(state, system)
    kinetic_translational = compute_translational_kinetic_energy(state)
    kinetic_rotational = compute_rotational_kinetic_energy(state)
    return potential + kinetic_translational + kinetic_rotational
[docs]
def count_dynamic_dofs(
    state: State, can_rotate: bool, subtract_drift: bool
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """
    Count the dynamic degrees of freedom of a state.

    A clump is counted as free when at least one particle carries its id and
    none of its particles is flagged in ``state.fixed``.

    Parameters
    ----------
    state : State
        Current simulation state.
    can_rotate : bool
        Whether rotational degrees of freedom are counted.
    subtract_drift : bool
        Whether center-of-mass drift is removed; when true, one free clump's
        worth of translational degrees of freedom is subtracted.

    Returns
    -------
    Tuple[jax.Array, jax.Array, jax.Array]
        ``(n_dof, n_dof_v, n_dof_w)``: total, translational and rotational
        degrees of freedom, with ``n_dof = n_dof_v + n_dof_w``.
    """
    n_slots = state.N
    members = jnp.bincount(state.clump_ID, length=n_slots)
    fixed_members = jnp.bincount(
        state.clump_ID, weights=state.fixed.astype(jnp.int32), length=n_slots
    )

    is_free_clump = (members > 0) & (fixed_members == 0)
    n_free = jnp.sum(is_free_clump)

    # Components per clump are taken from the trailing axis of the velocity
    # and angular-velocity arrays.
    translational_components = state.vel.shape[-1]
    rotational_components = state.angVel.shape[-1]

    n_dof_v = (n_free - subtract_drift) * translational_components
    n_dof_w = n_free * rotational_components * can_rotate
    return n_dof_v + n_dof_w, n_dof_v, n_dof_w
def _assign_random_velocities(
    state: State, subtract_drift: bool, seed: Optional[int] = 0
) -> State:
    """
    Assign random translational and angular velocities.

    One velocity and one angular velocity are drawn per clump from a standard
    normal distribution and broadcast to every particle of that clump. Clumps
    that contain a fixed particle (and unused clump ids) receive zero velocity
    and zero angular velocity.

    Notes
    -----
    - When ``subtract_drift`` is true, the unweighted mean velocity of the
      *free* clumps is removed from the free clumps only, so fixed clumps stay
      at rest and the free clumps have zero mean velocity. Masses are not used
      in this average.
    - The angular velocity is drawn in the body frame and, unless
      ``state.dim == 2``, rotated to the lab frame with ``state.q``.

    Parameters
    ----------
    state : State
        Current simulation state.
    subtract_drift : bool
        Whether to remove the mean drift of the free clumps.
    seed : int, optional
        RNG seed; ``None`` is treated as ``0``.

    Returns
    -------
    State
        A copy of ``state`` with ``vel`` and ``angVel`` replaced.
    """
    key = jax.random.PRNGKey(0 if seed is None else seed)
    v_k, w_k = jax.random.split(key, 2)

    counts = jnp.bincount(state.clump_ID, length=state.N)
    fixed_counts = jnp.bincount(
        state.clump_ID, weights=state.fixed.astype(jnp.int32), length=state.N
    )
    # A clump is free when it exists and none of its particles is fixed.
    free_mask = (counts > 0) & (fixed_counts == 0)
    free_col = free_mask[:, None]

    v_clump = jax.random.normal(v_k, (state.N, state.dim)) * free_col
    if subtract_drift:
        # Average over, and subtract from, free clumps only: including fixed
        # clumps would hand them a non-zero velocity and leave a residual
        # drift on the free ones.
        num_free = jnp.sum(free_mask)
        v_clump_mean = jnp.sum(v_clump, axis=0) / jnp.maximum(num_free, 1)
        v_clump = v_clump - v_clump_mean * free_col
    vel = v_clump[state.clump_ID]

    w_clump = (
        jax.random.normal(w_k, (state.N, state.angVel.shape[-1])) * free_col
    )  # body frame
    w = w_clump[state.clump_ID]
    if state.dim == 2:
        angVel = w
    else:  # rotate to lab frame
        angVel = state.q.rotate(state.q, w)

    return replace(state, vel=vel, angVel=angVel)
[docs]
def compute_temperature(
    state: State, can_rotate: bool, subtract_drift: bool, k_B: float = 1.0
) -> float:
    """
    Compute the temperature of a state.

    The temperature is ``2 * KE / (k_B * n_dof)``, where ``KE`` is the
    translational kinetic energy plus, when ``can_rotate`` is true, the
    rotational kinetic energy, and ``n_dof`` comes from
    :func:`count_dynamic_dofs`.

    Parameters
    ----------
    state : State
        Current simulation state.
    can_rotate : bool
        Whether to include rigid body rotations.
    subtract_drift : bool
        Whether center-of-mass drift is removed (usually only relevant for
        small systems); forwarded to the degree-of-freedom count.
    k_B : float, optional
        Boltzmann constant (default is 1.0).

    Returns
    -------
    float
        The temperature (a scalar JAX value at runtime).
    """
    n_dof = count_dynamic_dofs(state, can_rotate, subtract_drift)[0]

    kinetic_translational = compute_translational_kinetic_energy(state)
    kinetic_rotational = (
        compute_rotational_kinetic_energy(state) if can_rotate else 0.0
    )
    kinetic_total = kinetic_translational + kinetic_rotational

    return 2 * kinetic_total / (k_B * n_dof)
[docs]
def set_temperature(
    state: State,
    target_temperature: float,
    can_rotate: bool,
    subtract_drift: bool,
    seed: Optional[int] = 0,
    k_B: float = 1.0,
) -> State:
    """
    Randomize the velocities of a state and rescale them to a temperature.

    Random velocities are assigned first, the resulting temperature is
    measured, and both ``vel`` and ``angVel`` are multiplied by
    ``sqrt(target_temperature / temperature)``. When ``can_rotate`` is false
    the angular velocity is additionally multiplied by zero.

    Parameters
    ----------
    state : State
        Current simulation state.
    target_temperature : float
        Desired target temperature.
    can_rotate : bool
        Whether to include rigid body rotations.
    subtract_drift : bool
        Whether to remove center-of-mass drift (usually only relevant for
        small systems).
    seed : int, optional
        RNG seed.
    k_B : float, optional
        Boltzmann constant (default is 1.0).

    Returns
    -------
    State
        A copy of the randomized state with rescaled ``vel`` and ``angVel``.
    """
    randomized = _assign_random_velocities(state, subtract_drift, seed)
    measured = compute_temperature(randomized, can_rotate, subtract_drift, k_B)

    # Kinetic energy is quadratic in velocity, hence the square root.
    factor = jnp.sqrt(target_temperature / measured)

    return replace(
        randomized,
        vel=randomized.vel * factor,
        angVel=randomized.angVel * factor * can_rotate,
    )
[docs]
def scale_to_temperature(
    state: State,
    target_temperature: float,
    can_rotate: bool,
    subtract_drift: bool,
    k_B: float = 1.0,
) -> State:
    """
    Scale the velocities of a state to a desired temperature.

    The existing velocities are kept (no randomization); optionally the mean
    drift is removed, then ``vel`` and ``angVel`` are multiplied by
    ``sqrt(target_temperature / temperature)``. When ``can_rotate`` is false
    the angular velocity is additionally multiplied by zero.

    Notes
    -----
    - The drift is the unweighted mean velocity of the particles that belong
      to free clumps (clumps without any fixed particle). It is subtracted
      from those particles only, so fixed particles are not given a drift
      velocity. Without fixed particles this equals the mean over all
      particles.

    Parameters
    ----------
    state : State
        Current simulation state.
    target_temperature : float
        Desired target temperature.
    can_rotate : bool
        Whether to include rigid body rotations.
    subtract_drift : bool
        Whether to remove center-of-mass drift (usually only relevant for
        small systems).
    k_B : float, optional
        Boltzmann constant (default is 1.0).

    Returns
    -------
    State
        A copy of ``state`` with rescaled ``vel`` and ``angVel``.
    """
    # Per-particle mask of membership in a free clump, using the same
    # free-clump definition as the degree-of-freedom count.
    counts = jnp.bincount(state.clump_ID, length=state.N)
    fixed_counts = jnp.bincount(
        state.clump_ID, weights=state.fixed.astype(jnp.int32), length=state.N
    )
    free_clump = (counts > 0) & (fixed_counts == 0)
    free_col = free_clump[state.clump_ID][:, None]

    # subtract drift (free particles only; multiplicative flag keeps this
    # usable when subtract_drift is not a Python bool)
    n_free = jnp.maximum(jnp.sum(free_col), 1)
    drift = jnp.sum(state.vel * free_col, axis=-2) / n_free
    vel = state.vel - drift * free_col * subtract_drift

    # compute temperature
    temperature = compute_temperature(
        replace(state, vel=vel), can_rotate, subtract_drift, k_B
    )

    # scale to temperature
    scale = jnp.sqrt(target_temperature / temperature)
    vel = vel * scale
    angVel = state.angVel * scale * can_rotate
    return replace(state, vel=vel, angVel=angVel)