Minimization of bidisperse spheres (or disks)

In this example, we’ll minimize the energy of a set of random configurations of bidisperse spheres (or disks) in a 3D (or 2D) periodic box.

The particles use a purely repulsive harmonic interaction potential, meaning that the potential energy is zero when the particles are not in contact, and is otherwise proportional to the square of the overlap distance.
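
As a minimal sketch of this potential (plain Python, independent of JaxDEM), the pair energy can be written in terms of the overlap between two particles. The stiffness `k` and the prefactor of 1/2 are assumptions made for illustration; the description above only fixes the proportionality to the squared overlap.

```python
def harmonic_pair_energy(dist, rad_i, rad_j, k=1.0):
    """Purely repulsive harmonic energy for two spheres (or disks).

    dist is the center-to-center distance. The prefactor k / 2 is an
    illustrative assumption, not necessarily JaxDEM's convention.
    """
    overlap = rad_i + rad_j - dist
    if overlap <= 0.0:
        return 0.0  # not in contact: no interaction
    return 0.5 * k * overlap**2


print(harmonic_pair_energy(1.5, 0.5, 0.7))  # separated pair
print(harmonic_pair_energy(1.0, 0.5, 0.7))  # pair overlapping by about 0.2
```

Because the energy vanishes identically once all overlaps are removed, a minimization at low packing fraction can end at exactly zero potential energy.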

The minimization is performed using the FIRE minimizer.
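
For readers unfamiliar with FIRE (Fast Inertial Relaxation Engine), the sketch below shows the general idea on a toy quadratic energy: integrate damped dynamics, steer the velocity toward the force while the system moves downhill, and stop and shrink the time step when it moves uphill. This is a simplified, generic version in plain Python with commonly quoted default parameters; it is not JaxDEM's `linearfire` implementation, and the names `fire_minimize` and `quadratic_grad` are hypothetical.

```python
import math


def fire_minimize(grad, x, dt=0.05, dt_max=0.2, max_steps=10_000, f_tol=1e-8,
                  n_min=5, f_inc=1.1, f_dec=0.5, alpha0=0.1, f_alpha=0.99):
    """Simplified FIRE for unit masses; returns (position, steps taken)."""
    x = list(x)
    v = [0.0] * len(x)
    alpha, n_pos = alpha0, 0
    for step in range(max_steps):
        f = [-g for g in grad(x)]
        f_norm = math.sqrt(sum(fi * fi for fi in f))
        if f_norm < f_tol:
            return x, step  # force is small enough: converged
        # Semi-implicit Euler: update the velocity first, then the position.
        v = [vi + dt * fi for vi, fi in zip(v, f)]
        power = sum(fi * vi for fi, vi in zip(f, v))
        if power > 0.0:
            # Moving downhill: mix the velocity toward the force direction.
            v_norm = math.sqrt(sum(vi * vi for vi in v))
            v = [(1.0 - alpha) * vi + alpha * v_norm * fi / f_norm
                 for vi, fi in zip(v, f)]
            n_pos += 1
            if n_pos > n_min:
                dt = min(dt * f_inc, dt_max)
                alpha *= f_alpha
        else:
            # Moving uphill: stop, shrink the time step, reset the mixing.
            v = [0.0] * len(x)
            dt *= f_dec
            alpha, n_pos = alpha0, 0
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x, max_steps


def quadratic_grad(x):
    """Gradient of the toy energy E = 0.5 * (x0**2 + 10 * x1**2)."""
    return [x[0], 10.0 * x[1]]


x_min, n_taken = fire_minimize(quadratic_grad, [1.0, 1.0])
print(x_min, n_taken)
```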

Imports

import jax
import jax.numpy as jnp
import jaxdem as jd

# We need to enable double precision to reach the necessary accuracy for our tolerances.
jax.config.update("jax_enable_x64", True)

Parameters

We’ll minimize 10 systems of 10 particles each in parallel. This highlights the utility of system-level parallelism in JaxDEM, although the parallelized run is only as fast as its slowest system. The particles are placed uniformly at random in a box whose size is chosen to give an initial packing fraction of 0.4.

N_systems = 10
N = 10
phi = 0.4
dim = 2
e_int = 1.0
dt = 1e-2

def build_microstate(i):
    # assign bidisperse radii
    rad = jnp.ones(N)
    rad = rad.at[: N // 2].set(0.5)
    rad = rad.at[N // 2:].set(0.7)

    # set the box size for the packing fraction and the radii
    volume = jnp.sum((jnp.pi ** (dim / 2) / jax.scipy.special.gamma(dim / 2 + 1)) * rad ** dim)
    L = (volume / phi) ** (1 / dim)
    box_size = jnp.ones(dim) * L

    # create microstate
    key = jax.random.PRNGKey(i)
    pos = jax.random.uniform(key, (N, dim), minval=0.0, maxval=L)
    mass = jnp.ones(N)
    mats = [jd.Material.create("elastic", young=e_int, poisson=0.5, density=1.0)]
    matcher = jd.MaterialMatchmaker.create("harmonic")
    mat_table = jd.MaterialTable.from_materials(mats, matcher=matcher)

    # create system and state
    state = jd.State.create(pos=pos, rad=rad, mass=mass, volume=volume)
    system = jd.System.create(
        state_shape=state.shape,
        dt=dt,
        linear_integrator_type="linearfire",
        rotation_integrator_type="",
        domain_type="periodic",
        force_model_type="spring",
        collider_type="naive",
        # collider_type="celllist",
        # collider_kw=dict(state=state),
        mat_table=mat_table,
        domain_kw=dict(
            box_size=box_size,
        ),
    )
    return state, system
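
The box-size step inside `build_microstate` can be checked in isolation. The plain-Python sketch below uses the same formula for the volume of a d-dimensional ball, r^d * pi^(d/2) / Gamma(d/2 + 1), and solves phi = V_particles / L^d for the box length L. The helper names `unit_ball_volume` and `box_length` are hypothetical and not part of JaxDEM.

```python
import math


def unit_ball_volume(dim):
    """Volume of the unit ball in `dim` dimensions: pi^(d/2) / Gamma(d/2 + 1)."""
    return math.pi ** (dim / 2) / math.gamma(dim / 2 + 1)


def box_length(radii, phi, dim):
    """Side length of a cubic (or square) box giving packing fraction phi."""
    particle_volume = sum(unit_ball_volume(dim) * r**dim for r in radii)
    return (particle_volume / phi) ** (1 / dim)


# Same bidisperse mixture as above: five disks of radius 0.5, five of radius 0.7.
radii = [0.5] * 5 + [0.7] * 5
L = box_length(radii, phi=0.4, dim=2)
print(L)  # roughly 5.39 for these disks
```

For `dim = 2` the prefactor reduces to pi (disk area) and for `dim = 3` to 4 * pi / 3 (sphere volume), so the same expression covers both cases.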

Run the Minimization for Multiple Systems

We’ll first create the systems and states using JAX’s vmap function, which builds all 10 states and systems in a single batched call. We could also use the State.stack method to join a list of states and systems.

state, system = jax.vmap(build_microstate)(jnp.arange(N_systems))

# We'll run the minimization for up to 1M steps
n_steps = 1_000_000

# The minimizer will run until either of the following conditions is met:
# 1. step_count >= max_steps
# 2. PE <= pe_tol (energy is low enough) and |PE / prev_PE - 1| < pe_diff_tol (energy has stopped changing)
# We set both pe_tol and pe_diff_tol to 1e-16.
# The minimizer will return the final state, system, number of steps taken, and the final potential energy.
minimize_batch = jax.vmap(
    lambda st, sy: jd.minimizers.minimize(
        st, sy, max_steps=n_steps, pe_tol=1e-16, pe_diff_tol=1e-16, initialize=True
    )
)
state, system, steps, final_pe = minimize_batch(state, system)

print(f"Final potential energy: {final_pe}")
print(f"Number of steps taken: {steps}")
Final potential energy: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Number of steps taken: [ 56  97  97  75  78  86  76  69 124  83]
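
The differing step counts illustrate the earlier remark that a batched run is only as fast as its slowest system. The toy sketch below shows the underlying JAX behavior, assuming the minimizer is built on a `jax.lax.while_loop`-style loop: under `jax.vmap`, the loop keeps iterating until every batch element has finished, while elements that are already done are held fixed. `halve_until_small` is a hypothetical stand-in for a minimizer, not JaxDEM code.

```python
import jax
import jax.numpy as jnp
from jax import lax


def halve_until_small(x):
    """Toy 'minimizer': halve x until it drops below 1, counting the steps."""

    def cond(carry):
        value, _ = carry
        return value >= 1.0

    def body(carry):
        value, count = carry
        return value / 2.0, count + 1

    return lax.while_loop(cond, body, (x, jnp.asarray(0, dtype=jnp.int32)))


# Each batch element reports its own step count, but the batched loop
# itself runs for as many iterations as the slowest element needs.
values, counts = jax.vmap(halve_until_small)(jnp.array([1.5, 10.0, 100.0]))
print(counts)  # [1 4 7]
```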

Run the Minimization for a Single System

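We can also run the minimization on a single system by calling the minimization function directly on an unbatched state and system, without vmap. Seed 0 corresponds to the first system of the batch above, so it converges in the same 56 steps.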

state, system = build_microstate(0)
state, system, steps, final_pe = jd.minimizers.minimize(
    state, system, max_steps=n_steps, pe_tol=1e-16, pe_diff_tol=1e-16, initialize=True
)

print(f"Final potential energy: {final_pe}")
print(f"Number of steps taken: {steps}")
Final potential energy: 0.0
Number of steps taken: 56

Total running time of the script: (0 minutes 1.744 seconds)