Vicsek Model

This is an implementation of the Vicsek model in 2D. A state is created with 200 particles in a periodic box, each interacting with the others via a pairwise, purely repulsive harmonic potential. Particles move with a constant speed `v0` in a direction set by the sum of two components: a random component proportional to `eta`, and a component proportional to the average velocity of all neighboring particles within a distance `neighbor_radius`. The noise is a random vector added to this average, rather than a random rotation of each particle's heading. That makes this the “extrinsic” (vectorial) noise variant of the Vicsek model, as opposed to the “intrinsic” (angular) noise variant. We use `trajectory_rollout` to run the dynamics in a jit-compiled loop that also returns the trajectory data. We then calculate the polarization order parameter (the norm of the particle-averaged velocity vector) for each saved frame.

import jax

jax.config.update("jax_enable_x64", True)  # type: ignore[no-untyped-call]
import jax.numpy as jnp
import jaxdem as jd
import numpy as np

from jaxdem.utils.randomSphereConfiguration import random_sphere_configuration

This function builds the initial particle state and the simulation system.

def build_microstate(
    N: int,
    phi: float,
    dim: int,
    dt: float,
    neighbor_radius: float,
    eta: float,
    v0: float,
    seed: int,
    *,
    max_neighbors: int = 64,
) -> tuple[jd.State, jd.System]:
    """Create the initial state and system for an extrinsic-noise Vicsek run.

    Particles are mono-disperse and are placed by
    ``random_sphere_configuration`` at packing fraction ``phi``. They live in a
    periodic domain whose box size is returned by that helper. Contacts use the
    ``"spring"`` force model with a ``"harmonic"`` material matchmaker.

    Args:
        N: Number of particles.
        phi: Packing fraction passed to ``random_sphere_configuration``.
        dim: Spatial dimension.
        dt: Integration timestep.
        neighbor_radius: Alignment radius passed to the ``vicsek_extrinsic``
            integrator.
        eta: Noise strength passed to the ``vicsek_extrinsic`` integrator.
        v0: Self-propulsion speed passed to the ``vicsek_extrinsic``
            integrator.
        seed: Seed used both for the initial configuration and for the
            system's own random state.
        max_neighbors: Maximum number of neighbors per particle given to the
            Vicsek integrator. It must be large enough that the neighbor
            search does not overflow. Defaults to 64.

    Returns:
        ``(state, system)``, ready to be passed to
        ``jd.System.trajectory_rollout``.
    """
    # Mono-disperse radii: a single size class holding all particles.
    particle_radii = jd.utils.dispersity.get_polydisperse_radii(N, [1.0], [1.0])

    pos, box_size = random_sphere_configuration(particle_radii.tolist(), phi, dim, seed)

    state = jd.State.create(
        pos=pos,
        rad=particle_radii,
        mass=jnp.ones(N),
    )

    mats = [jd.Material.create("elastic", young=1.0, poisson=0.5, density=1.0)]
    matcher = jd.MaterialMatchmaker.create("harmonic")
    mat_table = jd.MaterialTable.from_materials(mats, matcher=matcher)

    system = jd.System.create(
        state_shape=state.shape,
        dt=dt,
        # Vicsek integrator; scalar parameters are wrapped as arrays with a
        # float dtype so they are stored as array data.
        linear_integrator_type="vicsek_extrinsic",
        linear_integrator_kw={
            "neighbor_radius": jnp.asarray(neighbor_radius, dtype=float),
            "eta": jnp.asarray(eta, dtype=float),
            "v0": jnp.asarray(v0, dtype=float),
            "max_neighbors": max_neighbors,
        },
        # NOTE(review): an empty type presumably disables rotational
        # integration; confirm against jaxdem.
        rotation_integrator_type="",
        domain_type="periodic",
        domain_kw={"box_size": box_size},
        force_model_type="spring",
        mat_table=mat_table,
        # here, we use the naive (double for-loop) collider since the system is small;
        # if you were to use a larger system, we recommend using the StaticCellList
        # or potentially the NeighborList
        collider_type="naive",
        seed=seed,
    )
    return state, system

Build the initial state and system

# Self-propulsion speed. It is kept in a variable because it is used twice:
# to drive the dynamics and to normalize the polarization order parameter.
v0 = 1.0  # semi-arbitrary velocity
# Random seed for this run; inspect `seed` to reproduce a given trajectory.
seed = int(np.random.randint(0, int(1e9)))

state, system = build_microstate(
    N=200,  # 200 particles
    phi=0.65,  # just dense enough to order
    dim=2,  # 2D
    dt=1e-2,  # semi-arbitrary timestep
    neighbor_radius=1.0,  # distance within which particles align
    eta=0.2,  # small noise component
    v0=v0,
    seed=seed,
)

# Run the dynamics for 5K steps, saving every 50th
n_steps = 5_000
save_stride = 50
n_frames = n_steps // save_stride

state_f, system_f, (traj_state, traj_system) = jd.System.trajectory_rollout(
    state,
    system,
    n=n_frames,
    stride=save_stride,
)

# Polarization order parameter |<v>| / v0, computed per saved frame:
# - average the velocities over particles (axis -2; this assumes
#   traj_state.vel is (n_frames, N, dim), so confirm);
# - take the norm of that mean vector;
# - divide by v0, so a fully aligned state gives ~1 whatever v0 is
#   (assuming particle speeds stay close to v0).
polarization = jnp.linalg.norm(jnp.mean(traj_state.vel, axis=-2), axis=-1) / v0
print("Polarization:")
print(polarization)
Polarization:
[0.18903634 0.55963682 0.86249545 0.89739931 0.90308634 0.93341945
 0.94436458 0.92358189 0.90378029 0.94144596 0.94704681 0.9246329
 0.91693226 0.92554952 0.90254184 0.93865573 0.92635339 0.94355505
 0.95769323 0.93169443 0.93541103 0.95478515 0.94860139 0.95617203
 0.95201607 0.92659948 0.8861399  0.93257875 0.91067856 0.91691195
 0.89100899 0.92963837 0.92162171 0.90241514 0.90227092 0.89547257
 0.94591109 0.94993395 0.95186676 0.9535517  0.92099929 0.91966019
 0.92077374 0.89669679 0.91721626 0.92255779 0.92556122 0.92290508
 0.93959659 0.93763581 0.92242346 0.92717586 0.91482355 0.91678101
 0.94825667 0.93011606 0.94323684 0.92436341 0.93355585 0.94331983
 0.92102912 0.94818728 0.93132396 0.93084801 0.90116742 0.87637145
 0.93631454 0.95144618 0.89321811 0.9215996  0.93421132 0.92131456
 0.91924437 0.91756848 0.89859745 0.87777277 0.84559597 0.92722342
 0.92798148 0.94903841 0.93220077 0.91975218 0.92007786 0.88776152
 0.94021008 0.93053204 0.92398551 0.93719044 0.92289964 0.92988852
 0.88674989 0.93050976 0.87525304 0.93311452 0.92894261 0.90065289
 0.96816237 0.9413768  0.93985533 0.91921768]

Total running time of the script: (0 minutes 18.381 seconds)