jaxdem.utils.linalg

Utility functions for linear algebra on arrays of vectors, operating along the last axis.

Functions

cross(a, b)

Computes the cross product of two vectors, 'a' and 'b', along their last axis.

cross_3X3D_1X2D(w, r)

Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to compute the tangential velocity v = w x r.

dot(a, b)

Dot product of vectors along the last axis.

norm(v)

Norm of vectors along the last axis.

norm2(v)

Squared norm of vectors along the last axis.

unit(v)

Normalize vectors along the last axis.

unit_and_norm(v)

Normalize vectors along the last axis and return the norm.

jaxdem.utils.linalg.cross(a: Array, b: Array) → Array

Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.

For 3D vectors (D=3), the result is a vector orthogonal to both ‘a’ and ‘b’. For 2D vectors (D=2), the result is the z-component of the 3D cross product obtained by appending a zero third component to each vector, i.e. a_x b_y - a_y b_x. This scalar is the signed area of the parallelogram spanned by the vectors (positive when ‘b’ lies counterclockwise from ‘a’).

Parameters:

  • a (Array) – Shape (..., D), where D is the dimension (2 or 3).

  • b (Array) – Shape (..., D), where D must match the last dimension of a.

Returns:

  • A JAX Array representing the cross product.
    - If D=3, the shape is (…, 3).
    - If D=2, the shape is (…, 1) (a scalar wrapped in an array).

Raises:

ValueError – If the last dimension (D) is not 2 or 3, or if the last dimensions of ‘a’ and ‘b’ do not match.
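A minimal sketch of the documented behavior in plain `jax.numpy` follows; it is not jaxdem's source, and the name `cross_sketch` is hypothetical. It uses `jnp.cross` for D=3 and shows how the 2D branch reduces to the signed area a_x b_y - a_y b_x, keeping a trailing axis of size 1 as documented.

```python
import jax.numpy as jnp


def cross_sketch(a, b):
    """Hypothetical re-implementation of the documented cross() behavior.

    a, b: (..., D) with D in {2, 3}.
    Returns (..., 3) for D=3 and (..., 1) for D=2.
    """
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("last dimensions of a and b must match")
    d = a.shape[-1]
    if d == 3:
        return jnp.cross(a, b)
    if d == 2:
        # z-component of (a_x, a_y, 0) x (b_x, b_y, 0), i.e. the signed area
        z = a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
        return z[..., None]  # keep a trailing axis of size 1
    raise ValueError("last dimension must be 2 or 3")


# Usage: a batch of two 2D pairs, then a single 3D pair.
a2 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
b2 = jnp.array([[0.0, 1.0], [1.0, 0.0]])
c2 = cross_sketch(a2, b2)  # shape (2, 1), values 1.0 and -1.0
c3 = cross_sketch(jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0]))  # [0, 0, 1]
```

Because the last-axis size is a static shape under `jax.jit`, the branch on `d` is resolved at trace time, so raising `ValueError` from Python is compatible with compilation.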

jaxdem.utils.linalg.dot(a: Array, b: Array) → Array

Dot product of vectors along the last axis.

a, b: (…, D)
returns: (…), the dot product.
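As an illustration of the (…, D) → (…) contraction, here is a sketch in `jax.numpy`; `dot_sketch` is a hypothetical name and jaxdem's implementation may differ. Leading axes follow NumPy broadcasting, so a batch of vectors can be dotted with a single vector.

```python
import jax.numpy as jnp


def dot_sketch(a, b):
    # Elementwise product, then sum over the last (vector) axis: (..., D) -> (...)
    return jnp.sum(a * b, axis=-1)


# Usage: dot products of three 2D vector pairs at once.
a = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.0, 1.0]])
b = jnp.array([[5.0, 6.0], [1.0, 0.0], [1.0, 0.0]])
d = dot_sketch(a, b)  # shape (3,): 17, 3, 0

# Broadcasting a (3, 2) batch against one (2,) vector picks out x-components.
dx = dot_sketch(a, jnp.array([1.0, 0.0]))  # shape (3,): 1, 3, 0
```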

jaxdem.utils.linalg.norm2(v: Array) → Array

Squared norm of vectors along the last axis.

v: (…, D)
returns: (…), the squared norm.
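One common reason to expose a squared norm is to compare lengths against a squared threshold without taking square roots. The sketch below assumes nothing about jaxdem beyond the documented shapes; `norm2_sketch` and the `cutoff` value are hypothetical.

```python
import jax.numpy as jnp


def norm2_sketch(v):
    # Sum of squared components along the last axis: (..., D) -> (...)
    return jnp.sum(v * v, axis=-1)


# Usage: find displacement vectors shorter than a cutoff without any sqrt,
# by comparing the squared norm with cutoff**2.
dr = jnp.array([[0.3, 0.4], [1.0, 1.0], [0.0, 0.0]])
cutoff = 1.0
n2 = norm2_sketch(dr)      # 0.25, 2.0, 0.0
within = n2 < cutoff**2    # True, False, True
```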

jaxdem.utils.linalg.norm(v: Array) → Array

Norm of vectors along the last axis.

v: (…, D)
returns: (…), the norm.
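A naive sketch of the norm is shown below (`norm_sketch` is hypothetical; this page does not say how jaxdem computes it). It also demonstrates a JAX pitfall worth knowing when differentiating norms: at v = 0 the gradient of this form is NaN, since the derivative of sqrt(s) is infinite at s = 0 and gets multiplied by a zero inner gradient.

```python
import jax
import jax.numpy as jnp


def norm_sketch(v):
    # Euclidean norm along the last axis: (..., D) -> (...)
    return jnp.sqrt(jnp.sum(v * v, axis=-1))


v = jnp.array([[3.0, 4.0, 0.0], [2.0, 3.0, 6.0]])
n = norm_sketch(v)  # shape (2,): 5, 7

# Gradient of the naive form at the zero vector: every entry is NaN
# (inf from d/ds sqrt(s) at s = 0, times the zero gradient of sum(v * v)).
g0 = jax.grad(norm_sketch)(jnp.zeros(3))
```

Whether jaxdem guards against this case is not stated here; the `unit` entry's "zeros map to zeros" only describes forward values.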

jaxdem.utils.linalg.unit(v: Array) → Array

Normalize vectors along the last axis.

v: (…, D)
returns: (…, D), unit vectors; zeros map to zeros.
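One way to obtain "zeros map to zeros" while also keeping gradients finite is the "double where" pattern: substitute a safe value before the square root and division, then select zero afterwards. This is a sketch of that technique, not necessarily what jaxdem does; `unit_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def unit_sketch(v):
    """Hypothetical sketch: normalize along the last axis, zero vectors -> zero."""
    n2 = jnp.sum(v * v, axis=-1, keepdims=True)  # (..., 1)
    is_zero = n2 == 0
    # Replace zero squared norms by 1 *before* sqrt and division, so neither
    # the forward value nor the backward pass produces inf or NaN.
    safe_n = jnp.sqrt(jnp.where(is_zero, 1.0, n2))
    return jnp.where(is_zero, 0.0, v / safe_n)


# Usage: one ordinary vector and one zero vector.
u = unit_sketch(jnp.array([[3.0, 4.0], [0.0, 0.0]]))  # [[0.6, 0.8], [0, 0]]

# The gradient through the zero vector stays finite (all zeros here).
g0 = jax.grad(lambda x: unit_sketch(x).sum())(jnp.zeros((1, 2)))
```

A single `jnp.where` applied only to the output would give the right forward values but can still leak NaN into gradients, because both branches are differentiated.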

jaxdem.utils.linalg.unit_and_norm(v: Array) → tuple[Array, Array]

Normalize vectors along the last axis and return the norm.

v: (…, D)
returns: ((…, D), (…, 1)), unit vectors and their norms; zeros map to zeros.
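Unlike `norm`, this entry documents the norm with shape (…, 1). A plausible reason is that a trailing axis of size 1 broadcasts directly against (…, D), so `u * n` recovers `v` without reshaping. The sketch below illustrates that under the documented shapes only; `unit_and_norm_sketch` is hypothetical, and it does not address gradients at zero.

```python
import jax.numpy as jnp


def unit_and_norm_sketch(v):
    """Hypothetical sketch returning (unit vectors (..., D), norms (..., 1))."""
    n = jnp.sqrt(jnp.sum(v * v, axis=-1, keepdims=True))  # (..., 1)
    safe_n = jnp.where(n > 0, n, 1.0)  # avoid dividing by zero
    u = jnp.where(n > 0, v / safe_n, 0.0)  # zero vectors stay zero
    return u, n


v = jnp.array([[0.0, 2.0], [0.0, 0.0], [3.0, 4.0]])
u, n = unit_and_norm_sketch(v)  # n has shape (3, 1): 2, 0, 5
v_back = u * n                  # (3, 2) * (3, 1) broadcasts back to v
```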

jaxdem.utils.linalg.cross_3X3D_1X2D(w: Array, r: Array) → Array

Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to compute the tangential velocity v = w x r.

This function handles two scenarios based on the dimension of ‘r’:

  1. 3D Case (r.shape[-1] == 3):
     - w must be a 3D vector (w.shape[-1] == 3).
     - Computes the standard 3D cross product w x r.

  2. 2D Case (r.shape[-1] == 2):
     - w is treated as a scalar (the z-component of the angular velocity, w_z).
     - The computation is equivalent to (0, 0, w_z) x (r_x, r_y, 0) = (-w_z r_y, w_z r_x, 0).
     - The result is the 2D tangential velocity vector (v_x, v_y) = (-w_z r_y, w_z r_x) in the xy-plane.

Parameters:

  • w (Array) – In the 3D case, shape (..., 3). In the 2D case, shape (..., 1) or (...).

  • r (Array) – Shape (..., 3) or (..., 2).

Returns:

  • A JAX Array representing the tangential velocity (w x r).
    - If r is 3D, the output shape is (…, 3).
    - If r is 2D, the output shape is (…, 2).

Raises:

ValueError – If r is not 2D or 3D, or if the dimensions of w and r are incompatible.
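The two cases above can be sketched in `jax.numpy` as follows. This is an illustration of the documented behavior, not jaxdem's source; the name `cross_w_r_sketch` and the exact rule for accepting w_z as (...) or (..., 1) are assumptions.

```python
import jax.numpy as jnp


def cross_w_r_sketch(w, r):
    """Hypothetical sketch of v = w x r for 3D and 2D inputs."""
    w = jnp.asarray(w)
    r = jnp.asarray(r)
    d = r.shape[-1]
    if d == 3:
        if w.shape[-1] != 3:
            raise ValueError("3D case requires w with shape (..., 3)")
        return jnp.cross(w, r)
    if d == 2:
        # Assumed rule: w_z arrives as (...) or (..., 1); make it (..., 1).
        has_unit_axis = w.ndim == r.ndim and w.shape[-1] == 1
        wz = w if has_unit_axis else w[..., None]
        # (0, 0, w_z) x (r_x, r_y, 0) = (-w_z * r_y, w_z * r_x, 0)
        return jnp.concatenate([-wz * r[..., 1:2], wz * r[..., 0:1]], axis=-1)
    raise ValueError("r must have last dimension 2 or 3")


# Usage: spin rate w_z = 2 applied to two 2D positions.
r2 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
v_a = cross_w_r_sketch(jnp.array([2.0, 2.0]), r2)      # w_z shape (2,)
v_b = cross_w_r_sketch(jnp.array([[2.0], [2.0]]), r2)  # w_z shape (2, 1)
# Both give [[0, 2], [-2, 0]]; the 3D path agrees with the first row:
v3 = cross_w_r_sketch(jnp.array([0.0, 0.0, 2.0]), jnp.array([1.0, 0.0, 0.0]))  # [0, 2, 0]
```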