jaxdem.utils.linalg
Utility functions to help with linear algebra.
Functions
| Function | Description |
|---|---|
| cross(a, b) | Computes the cross product of two vectors, 'a' and 'b', along their last axis. |
| cross_3X3D_1X2D(w, r) | Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to find tangential velocity: v = w x r. |
| dot(a, b) | Dot product of vectors along the last axis. |
| norm(v) | Norm (magnitude) of vectors along the last axis. |
| norm2(v) | Squared norm of vectors along the last axis. |
| unit(v) | Normalize vectors to unit vectors along the last axis. |
| unit_and_norm(v) | Normalize vectors along the last axis and return both unit vectors and norms. |
- jaxdem.utils.linalg.cross(a: Array, b: Array) → Array
Computes the cross product of two vectors, ‘a’ and ‘b’, along their last axis.
For 3D vectors (\(D=3\)), the result is a vector orthogonal to both ‘a’ and ‘b’:
\[\vec{c} = \vec{a} \times \vec{b} = (a_y b_z - a_z b_y) \mathbf{i} + (a_z b_x - a_x b_z) \mathbf{j} + (a_x b_y - a_y b_x) \mathbf{k}\]
For 2D vectors (\(D=2\)), the vectors are treated as lying in the xy-plane, and the result is the z-component of their 3D cross product, a signed scalar:
\[c = a_x b_y - a_y b_x\]
- Parameters:
a (jax.Array) – First vector. Shape (…, D), where D is the dimension (2 or 3).
b (jax.Array) – Second vector. Shape (…, D).
- Returns:
The cross product:
  - If D=3: shape is (…, 3).
  - If D=2: shape is (…, 1).
- Return type:
jax.Array
- Raises:
ValueError – If the last dimension is not 2 or 3, or if the dimensions of ‘a’ and ‘b’ do not match.
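The behaviour described above can be sketched with `jax.numpy` alone, without calling jaxdem. The name `cross_sketch` is hypothetical, and its shape checks are an assumption modelled on the documented `ValueError` conditions; the library's actual implementation may differ.

```python
import jax
import jax.numpy as jnp


def cross_sketch(a: jax.Array, b: jax.Array) -> jax.Array:
    """Hypothetical re-implementation of the documented `cross` behaviour."""
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("a and b must have the same last dimension")
    if a.shape[-1] == 3:
        # Standard 3D cross product along the last axis.
        return jnp.cross(a, b)
    if a.shape[-1] == 2:
        # z-component of (a_x, a_y, 0) x (b_x, b_y, 0),
        # kept with a trailing axis of size 1 to match the documented (..., 1).
        c = a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
        return c[..., None]
    raise ValueError("the last dimension must be 2 or 3")


# 3D: x-hat cross y-hat is z-hat, i.e. (0, 0, 1).
z_hat = cross_sketch(jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0]))

# 2D: a batch of two vector pairs gives a result of shape (2, 1).
c2 = cross_sketch(jnp.array([[1.0, 0.0], [2.0, 3.0]]),
                  jnp.array([[0.0, 1.0], [4.0, 5.0]]))
```

Keeping the size-1 trailing axis in 2D means the result still broadcasts against per-particle arrays of shape (…, 1), such as a scalar angular quantity.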
- jaxdem.utils.linalg.dot(a: Array, b: Array) → Array
Dot product of vectors along the last axis.
\[c = \vec{a} \cdot \vec{b} = \sum_{i} a_i b_i\]
- Parameters:
a (jax.Array) – First vector. Shape (…, D).
b (jax.Array) – Second vector. Shape (…, D).
- Returns:
The dot product. Shape (…).
- Return type:
jax.Array
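For reference, a dot product along the last axis reduces to an element-wise product followed by a sum over that axis, which broadcasts over any leading batch dimensions. `dot_sketch` is a hypothetical name used only for illustration, not the library's function.

```python
import jax.numpy as jnp


def dot_sketch(a, b):
    # Element-wise product, then reduce over the last (component) axis.
    return jnp.sum(a * b, axis=-1)


# A batch of 4 vectors in 3D dotted with one vector via broadcasting.
a = jnp.arange(12.0).reshape(4, 3)
b = jnp.array([1.0, 0.0, 0.0])
result = dot_sketch(a, b)  # picks out the first column: 0, 3, 6, 9
```

The output drops the last axis, giving shape (…) as documented.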
- jaxdem.utils.linalg.norm2(v: Array) → Array
Squared norm of vectors along the last axis.
\[\|v\|^2 = \vec{v} \cdot \vec{v} = \sum_{i} v_i^2\]
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
The squared norm. Shape (…).
- Return type:
jax.Array
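A common reason to prefer the squared norm is that distance comparisons against a cutoff can be done on squared values, skipping the square root. The sketch below assumes hypothetical particle positions and a hypothetical cutoff; `norm2_sketch` is an illustrative name, not the library's function.

```python
import jax.numpy as jnp


def norm2_sketch(v):
    # Sum of squared components over the last axis; no square root needed.
    return jnp.sum(v * v, axis=-1)


# Illustrative use: which pairs are closer than a cutoff distance?
pos_i = jnp.array([[0.0, 0.0], [0.0, 0.0]])
pos_j = jnp.array([[0.3, 0.4], [3.0, 4.0]])
cutoff = 1.0
# Compare squared distance with squared cutoff instead of taking sqrt.
in_contact = norm2_sketch(pos_j - pos_i) < cutoff**2
```

Here the first pair has squared distance 0.25 and the second 25, so only the first falls inside the cutoff.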
- jaxdem.utils.linalg.norm(v: Array) → Array
Norm (magnitude) of vectors along the last axis.
\[\|v\| = \sqrt{\sum_{i} v_i^2}\]
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
The norm. Shape (…).
- Return type:
jax.Array
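The formula above is the square root of the squared norm. A minimal sketch, using the hypothetical name `norm_sketch`, checks it against `jnp.linalg.norm` with an explicit `axis=-1`:

```python
import jax.numpy as jnp


def norm_sketch(v):
    # Euclidean length over the last axis: sqrt of the sum of squares.
    return jnp.sqrt(jnp.sum(v * v, axis=-1))


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
lengths = norm_sketch(v)  # 5.0 and 0.0
# Same values as the built-in when the reduction axis is given explicitly.
assert jnp.allclose(lengths, jnp.linalg.norm(v, axis=-1))
```

One caveat with this naive form: its gradient under `jax.grad` is NaN at the zero vector, because the derivative of the square root diverges at 0. That matters when a norm feeds into a quantity that can legitimately vanish.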
- jaxdem.utils.linalg.unit(v: Array) → Array
Normalize vectors to unit vectors along the last axis.
If the input vector has zero norm, the result is the zero vector.
\[\hat{v} = \begin{cases} \frac{\vec{v}}{\|v\|} & \text{if } \|v\| > 0 \\ \vec{0} & \text{otherwise} \end{cases}\]
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
Unit vector. Shape (…, D).
- Return type:
jax.Array
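The page does not say how the zero case is guarded. One common JAX approach, shown here as an assumption rather than the library's code, is the "double `where`" pattern: replace zero norms with a safe value *before* dividing, then mask the result. This keeps both the forward value and the gradient finite at the zero vector. `unit_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def unit_sketch(v):
    # Squared norm with a trailing axis of size 1 so it broadcasts against v.
    n2 = jnp.sum(v * v, axis=-1, keepdims=True)
    is_zero = n2 == 0
    # First where: substitute 1 for zero norms so sqrt and division stay finite
    # (this also keeps the backward pass free of inf * 0 = NaN).
    safe_n = jnp.sqrt(jnp.where(is_zero, 1.0, n2))
    # Second where: zero vectors map to the zero vector.
    return jnp.where(is_zero, 0.0, v / safe_n)


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
u = unit_sketch(v)  # rows (0.6, 0.8) and (0, 0)

# The gradient stays finite even when evaluated at the zero vector.
g = jax.grad(lambda x: unit_sketch(x).sum())(jnp.zeros(2))
```

With a single `where` applied only after the division, the forward result would look fine, but the gradient at the zero vector would contain NaN.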
- jaxdem.utils.linalg.unit_and_norm(v: Array) → tuple[Array, Array]
Normalize vectors along the last axis and return both unit vectors and norms.
- Parameters:
v (jax.Array) – Input vector. Shape (…, D).
- Returns:
A tuple of (unit vectors, norms).
- Return type:
Tuple[jax.Array, jax.Array]
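Returning both outputs lets a caller share one sum-of-squares reduction instead of computing the norm twice. The sketch below is hypothetical (`unit_and_norm_sketch` is not the library's function). The page does not state the shape of the returned norms, so this sketch *assumes* a trailing axis of size 1, which makes `u * n` reconstruct `v` directly.

```python
import jax.numpy as jnp


def unit_and_norm_sketch(v):
    # One reduction shared by both outputs.
    n2 = jnp.sum(v * v, axis=-1, keepdims=True)
    is_zero = n2 == 0
    # Guard the division as in a safe normalisation (zero norms -> 1 temporarily).
    n = jnp.sqrt(jnp.where(is_zero, 1.0, n2))
    u = jnp.where(is_zero, 0.0, v / n)
    # Report a norm of 0 for zero vectors; norms keep shape (..., 1) here.
    return u, jnp.where(is_zero, 0.0, n)


v = jnp.array([[3.0, 4.0], [0.0, 0.0]])
u, n = unit_and_norm_sketch(v)
# Direction times magnitude recovers the input, including the zero row.
assert jnp.allclose(u * n, v)
```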
- jaxdem.utils.linalg.cross_3X3D_1X2D(w: Array, r: Array) → Array
Computes the cross product of an angular velocity vector (w) and a position vector (r), often used to find the tangential velocity v = w × r.
For 3D vectors, the standard 3D cross product is used:
\[\vec{v} = \vec{w} \times \vec{r}\]
For 2D vectors, the angular velocity \(w\) is a scalar (its z-component) and the position \(\vec{r}\) is 2D. The result is the in-plane part of \((0, 0, w) \times (r_x, r_y, 0)\):
\[\vec{v} = (-w \cdot r_y, \, w \cdot r_x)\]
- Parameters:
w (jax.Array) – Angular velocity. Shape (…, 3) in 3D, or (…, 1) or (…) in 2D.
r (jax.Array) – Position vector. Shape (…, 3) or (…, 2).
- Returns:
Tangential velocity. Shape matches r.
- Return type:
jax.Array
- Raises:
ValueError – If dimensions are incompatible.
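The shape rules above can be sketched as follows. `tangential_velocity_sketch` is a hypothetical name, and the exact error conditions are assumptions, since the page only says "If dimensions are incompatible". In 2D the sketch accepts `w` with shape (…) or (…, 1) and applies \((-w\,r_y,\; w\,r_x)\).

```python
import jax.numpy as jnp


def tangential_velocity_sketch(w, r):
    """Hypothetical sketch of v = w x r following the documented shapes."""
    w = jnp.asarray(w)
    r = jnp.asarray(r)
    if r.shape[-1] == 3:
        if w.ndim == 0 or w.shape[-1] != 3:
            raise ValueError("a 3D position needs a 3-component angular velocity")
        # Full 3D cross product along the last axis.
        return jnp.cross(w, r)
    if r.shape[-1] == 2:
        # Accept w of shape (...) or (..., 1); reduce it to shape (...).
        if w.ndim > 0 and w.shape[-1] == 1:
            w = w[..., 0]
        # (0, 0, w) x (r_x, r_y, 0) = (-w r_y, w r_x, 0); keep the xy part.
        return jnp.stack([-w * r[..., 1], w * r[..., 0]], axis=-1)
    raise ValueError("r must have last dimension 2 or 3")


# 2D, scalar w: spinning at w = 2 about z, a point at (1, 0) moves along +y.
v2 = tangential_velocity_sketch(jnp.array(2.0), jnp.array([1.0, 0.0]))

# 2D agrees with the 3D cross product of the embedded vectors.
v2b = tangential_velocity_sketch(1.5, jnp.array([0.2, -0.7]))
v3b = jnp.cross(jnp.array([0.0, 0.0, 1.5]), jnp.array([0.2, -0.7, 0.0]))

# Batched 2D with w of shape (N, 1).
vb = tangential_velocity_sketch(jnp.ones((3, 1)),
                                jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
```

In every case the output shape matches `r`, which mirrors the documented return shape.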