jaxdem.utils.linalg

Utility functions to help with linear algebra.

Functions

cross(a, b)

Computes the cross product of two vectors, 'a' and 'b', along their last axis.

cross_3X3D_1X2D(w, r)

Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to obtain the tangential velocity v = w × r.

dot(a, b)

Dot product of vectors along the last axis.

norm(v)

Norm (magnitude) of vectors along the last axis.

norm2(v)

Squared norm of vectors along the last axis.

unit(v)

Normalize vectors to unit vectors along the last axis.

unit_and_norm(v)

Normalize vectors along the last axis and return both unit vectors and norms.

jaxdem.utils.linalg.cross(a: Array, b: Array) → Array

Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.

For 3D vectors (\(D=3\)), the result is a vector orthogonal to both ‘a’ and ‘b’:

\[\vec{c} = \vec{a} \times \vec{b} = (a_y b_z - a_z b_y) \mathbf{i} + (a_z b_x - a_x b_z) \mathbf{j} + (a_x b_y - a_y b_x) \mathbf{k}\]

For 2D vectors (\(D=2\)), both vectors are treated as lying in the xy-plane and the result is the z-component of their 3D cross product, a signed scalar rather than a magnitude:

\[c = a_x b_y - a_y b_x\]
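Embedding both vectors in the xy-plane with a zero z-component and applying the 3D component formula above shows why only this one term survives:

\[(a_x, a_y, 0) \times (b_x, b_y, 0) = (a_y \cdot 0 - 0 \cdot b_y,\; 0 \cdot b_x - a_x \cdot 0,\; a_x b_y - a_y b_x) = (0,\, 0,\, a_x b_y - a_y b_x)\]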
Parameters:
  • a (jax.Array) – First vector. Shape (…, D), where D is the dimension (2 or 3).

  • b (jax.Array) – Second vector. Shape (…, D).

Returns:

The cross product:
  • If D=3: shape (…, 3).
  • If D=2: shape (…, 1).

Return type:

jax.Array

Raises:

ValueError – If the last dimension is not 2 or 3, or if the dimensions of ‘a’ and ‘b’ do not match.
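The following is a minimal sketch of the documented behaviour written directly in `jax.numpy`; it is not jaxdem's source, and the name `cross_sketch` is hypothetical. It follows the documented return shapes (a trailing axis of size 1 in 2D) and raises `ValueError` under the documented conditions.

```python
import jax
import jax.numpy as jnp


def cross_sketch(a: jax.Array, b: jax.Array) -> jax.Array:
    """Hypothetical stand-in for jaxdem.utils.linalg.cross (illustration only)."""
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("last dimensions of 'a' and 'b' must match")
    d = a.shape[-1]
    if d == 3:
        # Standard 3D cross product along the last axis; leading axes broadcast.
        return jnp.cross(a, b)
    if d == 2:
        # z-component of (a_x, a_y, 0) x (b_x, b_y, 0), kept with shape (..., 1).
        z = a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
        return z[..., None]
    raise ValueError("last dimension must be 2 or 3")


# x-hat cross y-hat is z-hat in 3D; the 2D analogue is the z-component 1.0.
e_z = cross_sketch(jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0]))
c_2d = cross_sketch(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
```

Keeping the size-1 axis in 2D means the result still has one entry per input vector along the leading axes, so batched 2D inputs of shape (N, 2) yield shape (N, 1).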

jaxdem.utils.linalg.dot(a: Array, b: Array) → Array

Dot product of vectors along the last axis.

\[c = \vec{a} \cdot \vec{b} = \sum_{i} a_i b_i\]
Parameters:
  • a (jax.Array) – First vector. Shape (…, D).

  • b (jax.Array) – Second vector. Shape (…, D).

Returns:

The dot product. Shape (…).

Return type:

jax.Array
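A batched dot product reduces to an elementwise product summed over the last axis. A minimal `jax.numpy` sketch, using the hypothetical name `dot_sketch` (not the jaxdem source):

```python
import jax
import jax.numpy as jnp


def dot_sketch(a: jax.Array, b: jax.Array) -> jax.Array:
    """Hypothetical stand-in for jaxdem.utils.linalg.dot."""
    # Multiply componentwise, then sum over D; the result has shape (...).
    return jnp.sum(a * b, axis=-1)


# Leading axes broadcast: 4 vectors dotted with one vector give shape (4,).
a = jnp.arange(12.0).reshape(4, 3)
b = jnp.array([1.0, 0.0, 0.0])
proj = dot_sketch(a, b)  # picks out the x-component of each row of a
```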

jaxdem.utils.linalg.norm2(v: Array) → Array

Squared norm of vectors along the last axis.

\[\|v\|^2 = \vec{v} \cdot \vec{v} = \sum_{i} v_i^2\]
Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

The squared norm. Shape (…).

Return type:

jax.Array
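Because it skips the square root, the squared norm is convenient for distance tests: for non-negative quantities, \(\|v\| < c\) holds exactly when \(\|v\|^2 < c^2\). The sketch below uses a hypothetical `norm2_sketch` and made-up separation vectors and cutoff, purely for illustration.

```python
import jax
import jax.numpy as jnp


def norm2_sketch(v: jax.Array) -> jax.Array:
    """Hypothetical stand-in for jaxdem.utils.linalg.norm2."""
    # Sum of squared components over the last axis; no square root needed.
    return jnp.sum(v * v, axis=-1)


# Separation vectors between some point pairs (illustrative values).
r_ij = jnp.array([[3.0, 4.0], [0.5, 0.5], [1.0, 0.0]])
cutoff = 1.5
# Same answer as norm(r_ij) < cutoff, without computing square roots.
within = norm2_sketch(r_ij) < cutoff**2
```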

jaxdem.utils.linalg.norm(v: Array) → Array

Norm (magnitude) of vectors along the last axis.

\[\|v\| = \sqrt{\sum_{i} v_i^2}\]
Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

The norm. Shape (…).

Return type:

jax.Array
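The norm is the square root of the squared norm and, for real inputs, agrees with `jnp.linalg.norm(v, axis=-1)`. The hypothetical `norm_sketch` below (not the jaxdem source) also shows a caveat of this naive form: the derivative of the square root is unbounded at zero, so differentiating it at the zero vector gives a non-finite gradient.

```python
import jax
import jax.numpy as jnp


def norm_sketch(v: jax.Array) -> jax.Array:
    """Hypothetical stand-in for jaxdem.utils.linalg.norm."""
    # Square root of the sum of squared components over the last axis.
    return jnp.sqrt(jnp.sum(v * v, axis=-1))


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
lengths = norm_sketch(v)  # shape (2,): lengths 5 and 0
same = jnp.allclose(lengths, jnp.linalg.norm(v, axis=-1))

# Gradient at the zero vector: inf from sqrt times 0 from v*v, i.e. not finite.
g = jax.grad(norm_sketch)(jnp.zeros(3))
```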

jaxdem.utils.linalg.unit(v: Array) → Array

Normalize vectors to unit vectors along the last axis.

If a vector has zero norm, the result is the zero vector.

\[\begin{split}\hat{v} = \begin{cases} \frac{\vec{v}}{\|v\|} & \text{if } \|v\| > 0 \\ \vec{0} & \text{otherwise} \end{cases}\end{split}\]
Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

Unit vector. Shape (…, D).

Return type:

jax.Array
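A naive `v / norm(v)` evaluates 0/0 = NaN for zero vectors. One common way to get the documented zero-vector behaviour, sketched below, is the "double `where`" pattern: substitute a harmless value before the square root and division, then mask the result. This is an assumed technique, not necessarily how jaxdem implements `unit`, and `unit_sketch` is a hypothetical name. The inner `where` also keeps gradients finite at the zero vector; a single outer `where` would still propagate NaN through the masked division in the backward pass.

```python
import jax
import jax.numpy as jnp


def unit_sketch(v: jax.Array) -> jax.Array:
    """Hypothetical zero-safe stand-in for jaxdem.utils.linalg.unit."""
    n2 = jnp.sum(v * v, axis=-1, keepdims=True)  # shape (..., 1)
    nonzero = n2 > 0
    # Inner where: never take sqrt of, or divide by, zero (keeps gradients finite).
    safe_norm = jnp.sqrt(jnp.where(nonzero, n2, 1.0))
    # Outer where: zero vectors map to the zero vector, as documented.
    return jnp.where(nonzero, v / safe_norm, 0.0)


u = unit_sketch(jnp.array([[3.0, 4.0], [0.0, 0.0]]))  # rows (0.6, 0.8) and (0, 0)
g = jax.grad(lambda x: unit_sketch(x).sum())(jnp.zeros(2))  # finite gradient
```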

jaxdem.utils.linalg.unit_and_norm(v: Array) → tuple[Array, Array]

Normalize vectors along the last axis and return both unit vectors and norms.

Parameters:

v (jax.Array) – Input vector. Shape (…, D).

Returns:

A tuple of (unit vectors, norms).

Return type:

Tuple[jax.Array, jax.Array]
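When a caller needs both the direction and the length, computing them together lets the sum of squares and the square root be evaluated once. In the hypothetical `unit_and_norm_sketch` below, the norms keep a trailing axis of size 1 so that `unit * norm` broadcasts back to \(\vec{v}\); that shape is an assumption, since the shape of the returned norms is not stated above.

```python
import jax
import jax.numpy as jnp


def unit_and_norm_sketch(v: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical stand-in for jaxdem.utils.linalg.unit_and_norm."""
    n2 = jnp.sum(v * v, axis=-1, keepdims=True)  # assumed norm shape: (..., 1)
    nonzero = n2 > 0
    norm = jnp.sqrt(jnp.where(nonzero, n2, 1.0))  # substitute 1.0 for zero vectors
    unit = jnp.where(nonzero, v / norm, 0.0)
    # Report a norm of exactly 0 for zero vectors (the 1.0 above is internal only).
    return unit, jnp.where(nonzero, norm, 0.0)


v = jnp.array([[3.0, 4.0], [0.0, 0.0], [0.0, -2.0]])
u, n = unit_and_norm_sketch(v)
rebuilt = u * n  # recovers v, including the zero row
```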

jaxdem.utils.linalg.cross_3X3D_1X2D(w: Array, r: Array) → Array

Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to obtain the tangential velocity v = w × r.

For 3D vectors, the standard 3D cross product is used:

\[\vec{v} = \vec{w} \times \vec{r}\]

For 2D vectors, the angular velocity \(w\) is a scalar (its z-component) and the position \(\vec{r}\) is 2D:

\[\vec{v} = (-w \cdot r_y, \, w \cdot r_x)\]
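Treating the 2D case as a 3D product with \(\vec{w} = (0, 0, w)\) and \(\vec{r} = (r_x, r_y, 0)\), and applying the component formula given for `cross`, recovers this expression:

\[(0, 0, w) \times (r_x, r_y, 0) = (0 \cdot 0 - w\, r_y,\; w\, r_x - 0 \cdot 0,\; 0 \cdot r_y - 0 \cdot r_x) = (-w\, r_y,\; w\, r_x,\; 0)\]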
Parameters:
  • w (jax.Array) – Angular velocity. Shape (…, 3) in 3D, or (…, 1) or (…) in 2D.

  • r (jax.Array) – Position vector. Shape (…, 3) or (…, 2).

Returns:

Tangential velocity. Shape matches r.

Return type:

jax.Array

Raises:

ValueError – If dimensions are incompatible.
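A minimal `jax.numpy` sketch of the documented dispatch, using the hypothetical name `tangential_velocity_sketch` (not the jaxdem source). How it tells a 2D `w` of shape (…, 1) apart from one of shape (…) is an assumption: a trailing axis of size 1 is dropped only when `w` has as many dimensions as `r`.

```python
import jax
import jax.numpy as jnp


def tangential_velocity_sketch(w: jax.Array, r: jax.Array) -> jax.Array:
    """Hypothetical stand-in for jaxdem.utils.linalg.cross_3X3D_1X2D: v = w x r."""
    d = r.shape[-1]
    if d == 3:
        if w.ndim == 0 or w.shape[-1] != 3:
            raise ValueError("3D case needs w with shape (..., 3)")
        return jnp.cross(w, r)
    if d == 2:
        # Accept w of shape (..., 1) by dropping the trailing axis (assumed rule).
        if w.ndim == r.ndim and w.shape[-1] == 1:
            w = w[..., 0]
        # w is the z-component of the angular velocity: v = (-w r_y, w r_x).
        return jnp.stack([-w * r[..., 1], w * r[..., 0]], axis=-1)
    raise ValueError("r must have last dimension 2 or 3")


# The same rotation expressed in 2D and in 3D gives matching in-plane velocities.
v2 = tangential_velocity_sketch(jnp.array([2.0]), jnp.array([1.0, 2.0]))
v3 = tangential_velocity_sketch(jnp.array([0.0, 0.0, 2.0]), jnp.array([1.0, 2.0, 0.0]))
```

Both calls describe a rotation about the z-axis, so the 3D result has a zero z-component and its first two entries equal the 2D result, (-4, 2).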