jaxdem.domain

Interface for defining domains. A domain applies the boundary conditions: it transforms particle coordinates accordingly and computes displacement vectors between particles under those boundary conditions.

Classes

Domain(box_size, anchor)

The base interface for defining the simulation domain and the effect of its boundary conditions.

FreeDomain(box_size, anchor)

A Domain implementation representing an unbounded, "free" space.

PeriodicDomain(box_size, anchor)

A Domain implementation that enforces periodic boundary conditions.

ReflectDomain(box_size, anchor)

A Domain implementation that enforces reflective boundary conditions.

class jaxdem.domain.Domain(box_size: Array, anchor: Array)

Bases: Factory[Domain], ABC

The base interface for defining the simulation domain and the effect of its boundary conditions.

The Domain class defines how:
  • Relative displacement vectors between particles are calculated.

  • Particles’ positions are “shifted” or constrained to remain within the defined simulation boundaries based on the boundary condition type.

Notes

All concrete Domain implementations must support both 2D and 3D simulations. All methods must be JIT-compatible.

Example

To define a custom domain, inherit from Domain and implement its abstract methods:

>>> import jax
>>> from dataclasses import dataclass
>>> from jaxdem.domain import Domain
>>>
>>> @Domain.register("my_custom_domain")
... @jax.tree_util.register_dataclass
... @dataclass(slots=True)
... class MyCustomDomain(Domain):
...     ...
box_size: Array

Length of the simulation domain along each dimension. Defines the size of the simulation box.

anchor: Array

Anchor position of the simulation domain. This is the minimum coordinate of the domain in each dimension (e.g., the “lower-left corner” in 2D).

periodic: ClassVar[bool] = False

Whether or not the domain is periodic.

This is a class-level attribute that should be set to True for periodic boundary condition implementations.

abstractmethod static displacement(ri: Array, rj: Array, system: System) → Array

Computes the displacement vector between two particles \(r_i\) and \(r_j\), considering the domain’s boundary conditions.

Parameters:
  • ri (jax.Array) – Position vector of the first particle \(r_i\). Shape (dim,).

  • rj (jax.Array) – Position vector of the second particle \(r_j\). Shape (dim,).

  • system (System) – The configuration of the simulation, containing the domain instance.

Returns:

The displacement vector \(r_{ij} = r_i - r_j\), adjusted for boundary conditions. Shape (dim,).

Return type:

jax.Array

Raises:

NotImplementedError – This is an abstract method and must be implemented by subclasses.

Example
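A minimal sketch of the calling convention, assuming a stand-in function `direct_displacement` (hypothetical, not part of jaxdem) that follows the same per-pair contract: two position vectors of shape (dim,) in, one displacement of shape (dim,) out. Because the contract is per pair, `jax.vmap` can lift it to all pairs:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a Domain.displacement implementation. The
# `system` argument is omitted because this simple rule does not need it.
def direct_displacement(ri, rj):
    return ri - rj  # r_ij = r_i - r_j, shape (dim,)

pos = jnp.array([[0.0, 0.0], [1.0, 2.0], [4.0, 6.0]])  # (N, dim)

# Inner vmap maps over rj, outer vmap maps over ri: result is (N, N, dim).
pairwise = jax.vmap(
    jax.vmap(direct_displacement, in_axes=(None, 0)),
    in_axes=(0, None),
)
rij = pairwise(pos, pos)              # rij[i, j] = pos[i] - pos[j]
dist = jnp.linalg.norm(rij, axis=-1)  # (N, N) matrix of pair distances
```

Here `rij[2, 1]` is the vector from particle 1 to particle 2, so the sign convention matches \(r_{ij} = r_i - r_j\).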

abstractmethod static shift(state: State, system: System) → Tuple[State, System]

Applies boundary conditions to the particle state.

This method updates the state based on the domain’s rules, ensuring particles remain within the simulation box or handle interactions at boundaries appropriately (e.g., reflection, wrapping).

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

A tuple containing the updated State object adjusted by the boundary conditions and the System object.

Return type:

Tuple[State, System]

Raises:

NotImplementedError – This is an abstract method and must be implemented by subclasses.

Example
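A minimal sketch of the shape of a shift function, under loud assumptions: `ToyState` and `ToySystem` are hypothetical, simplified stand-ins for jaxdem's State and System, and the clamping rule is purely illustrative (it is not one of the domains documented here). What the sketch shows is the contract: a pure, JIT-compatible function that returns both objects, updated or not.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp

# Hypothetical stand-ins; the real State and System carry more fields.
class ToyState(NamedTuple):
    pos: jax.Array  # (N, dim)

class ToySystem(NamedTuple):
    anchor: jax.Array    # (dim,)
    box_size: jax.Array  # (dim,)

@jax.jit
def clamp_shift(state: ToyState, system: ToySystem) -> Tuple[ToyState, ToySystem]:
    # Illustrative rule only: clip positions into [anchor, anchor + box_size].
    lo = system.anchor
    hi = system.anchor + system.box_size
    new_state = state._replace(pos=jnp.clip(state.pos, lo, hi))
    return new_state, system  # return both, even when one is unchanged

state = ToyState(pos=jnp.array([[-1.0, 3.0], [12.0, 5.0]]))
system = ToySystem(anchor=jnp.zeros(2), box_size=jnp.array([10.0, 10.0]))
state, system = clamp_shift(state, system)
```

Returning new objects instead of mutating the inputs keeps the function compatible with `jax.jit`, which traces pure functions.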

classmethod create(key: str, /, **kw: Any) → T

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.

Example

Given a factory Foo with a subclass Bar registered under the key "bar":

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
classmethod register(key: str | None = None) → Callable[[Type[T]], Type[T]]

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
... class MyComponent:
...     ...

Register a class named “DefaultComponent” using its lowercase name (“defaultcomponent”) as the key:

>>> @MyFactory.register()
... class DefaultComponent:
...     ...
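The register/create behaviour described above can be sketched with a small registry-based factory. `ToyFactory` is a hypothetical illustration of the pattern, not jaxdem's actual Factory implementation; it only reproduces the documented behaviour (default lowercase key, ValueError on duplicates, KeyError on unknown keys, TypeError on mismatched keyword arguments).

```python
from typing import Callable, Dict, Optional

class ToyFactory:
    """Hypothetical registry-based factory; not jaxdem's Factory."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, key: Optional[str] = None) -> Callable[[type], type]:
        def decorator(sub: type) -> type:
            # Default key is the lowercase class name.
            k = key if key is not None else sub.__name__.lower()
            if k in cls._registry:
                raise ValueError(f"{k!r} is already registered")
            cls._registry[k] = sub
            return sub  # the class itself is returned unchanged
        return decorator

    @classmethod
    def create(cls, key: str, /, **kw):
        if key not in cls._registry:
            raise KeyError(f"unknown key {key!r}")
        # A **kw that does not match the constructor raises TypeError here.
        return cls._registry[key](**kw)

@ToyFactory.register("mycomp")
class MyComponent:
    def __init__(self, value: int = 0):
        self.value = value

@ToyFactory.register()  # registered under "defaultcomponent"
class DefaultComponent:
    pass

comp = ToyFactory.create("mycomp", value=42)
```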
class jaxdem.domain.FreeDomain(box_size: Array, anchor: Array)

Bases: Domain

A Domain implementation representing an unbounded, “free” space.

In a FreeDomain, there are no explicit boundary conditions applied to particles. Particles can move indefinitely in any direction, and the concept of a “simulation box” is only used to define the bounding box of the system.

Notes

  • The box_size and anchor attributes are dynamically updated in the shift method to encompass all particles, because some hashing tools require the domain size.

Example

>>> system = jaxdem.System.create(dim=2, domain_type="free", domain_kw=dict(box_size=jnp.array([10., 10.]), anchor=jnp.array([0., 0.])))
>>>
>>> # After a step the particles have moved; shift updates the domain's box_size and anchor.
>>> # `state` is assumed to be an existing State instance.
>>> state, system = system.domain.shift(state, system)
>>> print("Updated Domain Anchor:", system.domain.anchor)
>>> print("Updated Domain Box Size:", system.domain.box_size)
static displacement(ri: Array, rj: Array, _: System) → Array

Computes the displacement vector between two particles.

In a free domain, the displacement is simply the direct vector difference between the particle positions.

Parameters:
  • ri (jax.Array) – Position vector of the first particle \(r_i\).

  • rj (jax.Array) – Position vector of the second particle \(r_j\).

  • _ (System) – The system object.

Returns:

The direct displacement vector \(r_i - r_j\).

Return type:

jax.Array

static shift(state: State, system: System) → Tuple[State, System]

Updates the System’s domain anchor and box_size to encompass all particles. Does not apply any transformations to the state.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The current system configuration.

Returns:

The original State object (unchanged) and the System object with updated domain.anchor and domain.box_size.

Return type:

Tuple[State, System]
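A minimal sketch of the bounding-box update, assuming the box is the tight axis-aligned bounding box of the particle centers. Whether the actual implementation also pads by particle radii is not stated here, so treat this as an illustration of the idea only.

```python
import jax.numpy as jnp

pos = jnp.array([[2.0, -1.0], [5.0, 3.0], [-4.0, 0.5]])  # (N, dim)

# Assumption: tight bounding box of particle centers, no radius padding.
anchor = jnp.min(pos, axis=0)             # minimum coordinate per dimension
box_size = jnp.max(pos, axis=0) - anchor  # extent per dimension
```

The positions themselves are left untouched; only the two domain attributes change.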

anchor: jax.Array

Anchor position of the simulation domain. This is the minimum coordinate of the domain in each dimension (e.g., the “lower-left corner” in 2D).

box_size: jax.Array

Length of the simulation domain along each dimension. Defines the size of the simulation box.

classmethod create(key: str, /, **kw: Any) → T

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.

Example

Given a factory Foo with a subclass Bar registered under the key "bar":

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
periodic: ClassVar[bool] = False

Whether or not the domain is periodic.

This is a class-level attribute that should be set to True for periodic boundary condition implementations.

classmethod register(key: str | None = None) → Callable[[Type[T]], Type[T]]

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
... class MyComponent:
...     ...

Register a class named “DefaultComponent” using its lowercase name (“defaultcomponent”) as the key:

>>> @MyFactory.register()
... class DefaultComponent:
...     ...
class jaxdem.domain.ReflectDomain(box_size: Array, anchor: Array)

Bases: Domain

A Domain implementation that enforces reflective boundary conditions.

Particles that attempt to move beyond the defined box_size will have their positions reflected back into the box and their velocities reversed in the direction normal to the boundary.

Notes

  • The reflection occurs at the boundaries defined by anchor and anchor + box_size.

Example

>>> system = jaxdem.System.create(dim=2, domain_type="reflect", domain_kw=dict(box_size=jnp.array([10., 7.]), anchor=jnp.array([1., 0.])))
static displacement(ri: Array, rj: Array, _: System) → Array

Computes the displacement vector between two particles.

In a reflective domain, the displacement is simply the direct vector difference.

Parameters:
  • ri (jax.Array) – Position vector of the first particle \(r_i\).

  • rj (jax.Array) – Position vector of the second particle \(r_j\).

  • _ (System) – The system object.

Returns:

The direct displacement vector \(r_i - r_j\).

Return type:

jax.Array

static shift(state: State, system: System) → Tuple[State, System]

Applies reflective boundary conditions to particles.

Particles are checked against the domain boundaries. If a particle attempts to move beyond a boundary, its position is reflected back into the box, and its velocity component normal to that boundary is reversed.

\[\begin{split}l &= a + R \\ u &= a + B - R \\ v' &= \begin{cases} -v & \text{if } r < l \text{ or } r > u \\ v & \text{otherwise} \end{cases} \\ r' &= \begin{cases} 2l - r & \text{if } r < l \\ r & \text{otherwise} \end{cases} \\ r'' &= \begin{cases} 2u - r' & \text{if } r' > u \\ r' & \text{otherwise} \end{cases} \\ r &= r''\end{split}\]
where:
  • \(r\) is the current particle position (state.pos)

  • \(v\) is the current particle velocity (state.vel)

  • \(a\) is the domain anchor (system.domain.anchor)

  • \(B\) is the domain box size (system.domain.box_size)

  • \(R\) is the particle radius (state.rad)

  • \(l\) is the lower boundary for the particle center

  • \(u\) is the upper boundary for the particle center

TODO: Ensure correctness when adding different types of shapes and angular velocity.

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

The updated State object with reflected positions and velocities, and the System object.

Return type:

Tuple[State, System]
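The reflection rule above can be sketched directly in JAX. This is a standalone illustration that takes plain arrays instead of State and System objects; the function name `reflect` and its argument layout are assumptions made for the example, not jaxdem's API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def reflect(r, v, rad, anchor, box_size):
    # r, v: (N, dim); rad: (N,); anchor, box_size: (dim,)
    lower = anchor + rad[:, None]             # l = a + R
    upper = anchor + box_size - rad[:, None]  # u = a + B - R
    outside = (r < lower) | (r > upper)
    v_new = jnp.where(outside, -v, v)         # flip only the offending components
    r1 = jnp.where(r < lower, 2.0 * lower - r, r)     # r'
    r2 = jnp.where(r1 > upper, 2.0 * upper - r1, r1)  # r''
    return r2, v_new

r = jnp.array([[0.2, 5.0], [10.3, 5.0]])
v = jnp.array([[-1.0, 0.5], [2.0, 0.5]])
rad = jnp.array([0.5, 0.5])
r_new, v_new = reflect(r, v, rad, jnp.zeros(2), jnp.array([10.0, 10.0]))
```

With a box of size 10 and radius 0.5, the particle centers are confined to [0.5, 9.5]. The first particle is mirrored about the lower bound (0.2 becomes 0.8) and the second about the upper bound (10.3 becomes 8.7); only their x velocities change sign.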

anchor: jax.Array

Anchor position of the simulation domain. This is the minimum coordinate of the domain in each dimension (e.g., the “lower-left corner” in 2D).

box_size: jax.Array

Length of the simulation domain along each dimension. Defines the size of the simulation box.

classmethod create(key: str, /, **kw: Any) → T

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.

Example

Given a factory Foo with a subclass Bar registered under the key "bar":

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
periodic: ClassVar[bool] = False

Whether or not the domain is periodic.

This is a class-level attribute that should be set to True for periodic boundary condition implementations.

classmethod register(key: str | None = None) → Callable[[Type[T]], Type[T]]

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
... class MyComponent:
...     ...

Register a class named “DefaultComponent” using its lowercase name (“defaultcomponent”) as the key:

>>> @MyFactory.register()
... class DefaultComponent:
...     ...
class jaxdem.domain.PeriodicDomain(box_size: Array, anchor: Array)

Bases: Domain

A Domain implementation that enforces periodic boundary conditions.

Particles that move out of one side of the simulation box re-enter from the opposite side. The displacement vector between particles is computed using the minimum image convention.

Notes

  • This domain type is periodic (periodic = True).

  • The shift method wraps particles back into the primary simulation box.

Example

>>> system = jaxdem.System.create(dim=2, domain_type="periodic", domain_kw=dict(box_size=jnp.array([10., 7.]), anchor=jnp.array([1., 0.])))
anchor: jax.Array

Anchor position of the simulation domain. This is the minimum coordinate of the domain in each dimension (e.g., the “lower-left corner” in 2D).

box_size: jax.Array

Length of the simulation domain along each dimension. Defines the size of the simulation box.

classmethod create(key: str, /, **kw: Any) → T

Creates and returns an instance of a registered subclass.

This method looks up the subclass associated with the given key in the factory’s registry and then calls its constructor with the provided arguments.

Parameters:
  • key (str) – The registration key of the subclass to be created.

  • **kw (Any) – Arbitrary keyword arguments to be passed directly to the constructor of the registered subclass.

Returns:

An instance of the registered subclass.

Return type:

T

Raises:
  • KeyError – If the provided key is not found in the factory’s registry.

  • TypeError – If the provided **kw arguments do not match the signature of the registered subclass’s constructor.

Example

Given a factory Foo with a subclass Bar registered under the key "bar":

>>> bar_instance = Foo.create("bar", value=42)
>>> print(bar_instance)
Bar(value=42)
periodic: ClassVar[bool] = True

Whether or not the domain is periodic.

This is a class-level attribute that should be set to True for periodic boundary condition implementations.

classmethod register(key: str | None = None) → Callable[[Type[T]], Type[T]]

Registers a subclass with the factory’s registry.

This method returns a decorator that can be applied to a class to register it under a specific key.

Parameters:

key (str or None, optional) – The string key under which to register the subclass. If None, the lowercase name of the subclass itself will be used as the key.

Returns:

A decorator function that takes a class and registers it, returning the class unchanged.

Return type:

Callable[[Type[T]], Type[T]]

Raises:

ValueError – If the provided key (or the default class name) is already registered in the factory’s registry.

Example

Register a class named “MyComponent” under the key “mycomp”:

>>> @MyFactory.register("mycomp")
... class MyComponent:
...     ...

Register a class named “DefaultComponent” using its lowercase name (“defaultcomponent”) as the key:

>>> @MyFactory.register()
... class DefaultComponent:
...     ...
static displacement(ri: Array, rj: Array, system: System) → Array

Computes the minimum image displacement vector between two particles \(r_i\) and \(r_j\).

For periodic boundary conditions, the displacement is calculated as the shortest vector that connects \(r_j\) to \(r_i\), potentially by crossing periodic boundaries.

Parameters:
  • ri (jax.Array) – Position vector of the first particle \(r_i\).

  • rj (jax.Array) – Position vector of the second particle \(r_j\).

  • system (System) – The configuration of the simulation, containing the domain instance with anchor and box_size for periodicity.

Returns:

The minimum image displacement vector:

\[\begin{split}r_{ij} &= (r_i - a) - (r_j - a) \\ r_{ij} &= r_{ij} - B \cdot \text{round}(r_{ij}/B)\end{split}\]
where:
  • \(a\) is the domain anchor (system.domain.anchor)

  • \(B\) is the domain box size (system.domain.box_size)

Return type:

jax.Array
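The minimum image formula can be sketched as a standalone JAX function. The name `min_image` and the plain-array arguments are assumptions for illustration; the documented method reads anchor and box_size from system.domain instead.

```python
import jax
import jax.numpy as jnp

@jax.jit
def min_image(ri, rj, anchor, box_size):
    # The anchor cancels in the difference; it is kept to mirror the formula.
    rij = (ri - anchor) - (rj - anchor)
    # Subtract the nearest whole number of box lengths per dimension.
    return rij - box_size * jnp.round(rij / box_size)

anchor = jnp.array([1.0, 0.0])
box_size = jnp.array([10.0, 7.0])
ri = jnp.array([10.5, 1.0])
rj = jnp.array([1.5, 6.0])
rij = min_image(ri, rj, anchor, box_size)
```

The direct difference is (9, -5), but across the periodic boundaries the two particles are closer: the result is (-1, 2). Each component ends up within half a box length of zero. For a separation of exactly half a box length the two images are equidistant, so the sign of the result depends on the rounding mode.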

static shift(state: State, system: System) → Tuple[State, System]

Wraps particles back into the primary simulation box.

\[r = r - B \cdot \text{floor}((r - a)/B)\]
where:
  • \(a\) is the domain anchor (system.domain.anchor)

  • \(B\) is the domain box size (system.domain.box_size)

Parameters:
  • state (State) – The current state of the simulation.

  • system (System) – The configuration of the simulation.

Returns:

The updated State object with wrapped particle positions, and the System object.

Return type:

Tuple[State, System]
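A minimal sketch of the wrapping formula on plain arrays (the documented method operates on State and System objects instead; the variable names here are chosen for the example):

```python
import jax.numpy as jnp

anchor = jnp.array([1.0, 0.0])
box_size = jnp.array([10.0, 7.0])
pos = jnp.array([[12.0, -1.0], [0.5, 7.5], [5.0, 3.0]])  # (N, dim)

# r = r - B * floor((r - a) / B): subtract whole box lengths so that every
# coordinate lands in the half-open interval [a, a + B).
wrapped = pos - box_size * jnp.floor((pos - anchor) / box_size)
```

The first two particles are outside the box and are wrapped to (2, 6) and (10.5, 0.5); the third is already inside and is unchanged.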