jaxdem.utils.environment

Utility functions for stepping environments and computing LIDAR sensor readings.

Functions

cross_lidar_2d(pos_a, pos_b, system, ...)

2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

cross_lidar_3d(pos_a, pos_b, system, ...)

3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

env_step(env, model, key, *[, n])

Advance the environment n steps using actions from model.

env_trajectory_rollout(env, model, key, *, n[, ...])

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

lidar_2d(state, system, lidar_range, n_bins, ...)

2-D LIDAR proximity readings and neighbor IDs.

lidar_3d(state, system, lidar_range, ...[, ...])

3-D LIDAR proximity readings and neighbor IDs.

jaxdem.utils.environment.cross_lidar_2d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_bins: int, max_neighbors: int) → tuple[jax.Array, jax.Array, jax.Array]

2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

Computes all-pairs displacements from pos_a to pos_b, bins by azimuthal angle, and returns per-bin proximity and closest target IDs.

Parameters:
  • pos_a (jax.Array) – Sensor positions, shape (N_A, dim).

  • pos_b (jax.Array) – Target positions, shape (N_B, dim).

  • system (System) – System configuration.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_bins (int) – Number of angular bins spanning \([-\pi, \pi)\).

  • max_neighbors (int) – Unused. Kept for backward compatibility.

Returns:

(proximity, ids, overflow) where proximity and ids have shape (N_A, n_bins) and overflow is always False. Empty bins get ids = -1.

Return type:

Tuple[jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.

Examples

>>> prox, ids, overflow = cross_lidar_2d(agents, obstacles, system,
...                                      lidar_range=5.0, n_bins=36,
...                                      max_neighbors=64)
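
The all-pairs binning described above can be sketched in plain JAX. This is a minimal illustration and not the jaxdem implementation: toy_cross_lidar_2d is a hypothetical name, the sketch assumes a free-space (non-periodic) domain so that displacements are plain differences, and the bin-edge convention is an assumption.

```python
import jax.numpy as jnp


def toy_cross_lidar_2d(pos_a, pos_b, lidar_range, n_bins):
    """Hypothetical all-pairs 2-D LIDAR in free space (illustration only)."""
    # Displacements from every sensor to every target: shape (N_A, N_B, dim).
    disp = pos_b[None, :, :] - pos_a[:, None, :]
    dist = jnp.linalg.norm(disp, axis=-1)
    # Azimuth from arctan2, mapped to a bin index in [0, n_bins).
    theta = jnp.arctan2(disp[..., 1], disp[..., 0])
    bins = jnp.floor((theta + jnp.pi) / (2 * jnp.pi) * n_bins).astype(jnp.int32) % n_bins
    # Keep a distance only in the bin it falls into, and only if it is in range.
    in_bin = bins[:, :, None] == jnp.arange(n_bins)[None, None, :]
    valid = in_bin & (dist[:, :, None] < lidar_range)
    masked = jnp.where(valid, dist[:, :, None], jnp.inf)  # (N_A, N_B, n_bins)
    d_min = masked.min(axis=1)
    hit = jnp.isfinite(d_min)
    ids = jnp.where(hit, masked.argmin(axis=1), -1)       # empty bins get -1
    prox = jnp.where(hit, lidar_range - d_min, 0.0)
    return prox, ids


agents = jnp.array([[0.0, 0.0]])
obstacles = jnp.array([[3.0, 4.0], [-3.0, 4.0], [0.6, 0.8], [30.0, 40.0]])
prox, ids = toy_cross_lidar_2d(agents, obstacles, lidar_range=10.0, n_bins=4)
# ids  -> [[-1, -1, 2, 1]]
# prox -> approximately [[0, 0, 9, 5]]
```

Targets 0 and 2 fall in the same sector, so only the closer one (index 2) is reported; target 3 is out of range and is ignored.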

jaxdem.utils.environment.cross_lidar_3d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int) → tuple[jax.Array, jax.Array, jax.Array]

3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.

Computes all-pairs displacements from pos_a to pos_b, bins on a spherical grid, and returns per-bin proximity and closest target IDs.

Parameters:
  • pos_a (jax.Array) – Sensor positions, shape (N_A, 3).

  • pos_b (jax.Array) – Target positions, shape (N_B, 3).

  • system (System) – System configuration.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_azimuth (int) – Number of azimuthal bins.

  • n_elevation (int) – Number of elevation bins.

  • max_neighbors (int) – Unused. Kept for backward compatibility.

Returns:

(proximity, ids, overflow) where proximity and ids have shape (N_A, n_azimuth * n_elevation) and overflow is always False. Empty bins get ids = -1.

Return type:

Tuple[jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids are indices into pos_b regardless of how pos_a may have been reordered by a cell-list collider.

Examples

>>> prox, ids, overflow = cross_lidar_3d(agents, obstacles, system,
...                                      lidar_range=5.0, n_azimuth=36,
...                                      n_elevation=18, max_neighbors=64)

jaxdem.utils.environment.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, **kw: Any) → tuple[Environment, jax.Array]

Advance the environment n steps using actions from model.

Parameters:
  • env (Environment) – Initial environment pytree (batchable).

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.

  • n (int) – Number of steps to perform.

  • **kw (Any) – Extra keyword arguments forwarded to model.

Returns:

Updated environment and the advanced random key.

Return type:

Tuple[Environment, jax.Array]

Examples

>>> env, key = env_step(env, model, key, n=10, objective=goal)
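
The step loop and key handling described above can be illustrated with a toy stand-in. Everything here is hypothetical: toy_env_step, toy_observe, and toy_apply are made-up names, the environment is a plain dict rather than a jaxdem Environment, and the use of jax.lax.scan is an assumption about how such a loop might be written.

```python
import jax
import jax.numpy as jnp


def toy_observe(env):
    # Hypothetical observation: the position itself.
    return env["pos"]


def toy_apply(env, action):
    # Hypothetical dynamics: the action is a displacement.
    return {"pos": env["pos"] + action}


def toy_env_step(env, model, key, *, n=1, **kw):
    """Advance a toy environment n steps, splitting the key once per step."""
    def body(carry, _):
        env, key = carry
        key, subkey = jax.random.split(key)
        action = model(toy_observe(env), subkey, **kw)
        return (toy_apply(env, action), key), None

    (env, key), _ = jax.lax.scan(body, (env, key), None, length=n)
    return env, key


def model(obs, key, *, objective):
    # Deterministic toy policy: move 10% of the way toward the objective.
    return 0.1 * (objective - obs)


goal = jnp.array([1.0, 2.0])
key0 = jax.random.PRNGKey(0)
env, key = toy_env_step({"pos": jnp.zeros(2)}, model, key0, n=10, objective=goal)
```

The returned key differs from key0, so passing it to the next call avoids reusing random numbers.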

jaxdem.utils.environment.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, **kw: Any) → tuple[Environment, jax.Array, Environment]

Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.

Parameters:
  • env (Environment) – Initial environment pytree.

  • model (Callable) – Callable with signature model(obs, key, **kw) -> action.

  • key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.

  • n (int) – Number of chunks to roll out. Total internal steps = n * stride.

  • stride (int) – Steps per chunk between recorded snapshots.

  • **kw (Any) – Extra keyword arguments passed to model on every step.

Returns:

Final environment, advanced random key, and a stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.

Return type:

Tuple[Environment, jax.Array, Environment]

Examples

>>> env, key, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
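
The relationship between n, stride, and the stacked trajectory can be illustrated with a toy rollout. This is a sketch under assumptions: toy_rollout is a hypothetical name, the environment is a plain dict with a step counter t, and the nested jax.lax.scan is only one possible way to structure such a rollout.

```python
import jax
import jax.numpy as jnp


def toy_rollout(env, model, key, *, n, stride=1, **kw):
    """Record a toy environment after each chunk of `stride` steps."""
    def one_step(carry, _):
        env, key = carry
        key, subkey = jax.random.split(key)
        action = model(env["pos"], subkey, **kw)  # toy observation: the position
        new_env = {"pos": env["pos"] + action, "t": env["t"] + 1}
        return (new_env, key), None

    def one_chunk(carry, _):
        carry, _ = jax.lax.scan(one_step, carry, None, length=stride)
        return carry, carry[0]  # snapshot of the environment after the chunk

    (env, key), traj = jax.lax.scan(one_chunk, (env, key), None, length=n)
    return env, key, traj


def model(obs, key):
    # Toy policy: constant unit displacement; the key is unused.
    return jnp.ones_like(obs)


env0 = {"pos": jnp.zeros(2), "t": jnp.zeros((), dtype=jnp.int32)}
env, key, traj = toy_rollout(env0, model, jax.random.PRNGKey(0), n=4, stride=5)
# traj["t"] -> [5, 10, 15, 20]; traj["pos"].shape -> (4, 2)
```

Every leaf of traj gains a leading axis of length n, and the final environment has advanced n * stride = 20 steps.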

jaxdem.utils.environment.lidar_2d(state: State, system: System, lidar_range: float, n_bins: int, max_neighbors: int, sense_edges: bool = False) → tuple[State, System, jax.Array, jax.Array, jax.Array]

2-D LIDAR proximity readings and neighbor IDs.

For every particle in state the displacement vectors to all other particles are projected onto the \(xy\)-plane and binned by azimuthal angle into n_bins uniform sectors spanning \([-\pi, \pi)\). Each bin stores the proximity value and the index of the closest neighbor in that sector:

\[p_k = \max(0,\; r_{\max} - d_{\min,k})\]

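Here \(r_{\max}\) is lidar_range and \(d_{\min,k}\) is the distance to the closest neighbor in bin \(k\), so a neighbor at or beyond lidar_range contributes zero proximity. The same binning applies to 2-D and 3-D position data: in the 3-D case the \(z\)-component of the displacement is ignored during binning, while the full Euclidean distance is used for proximity.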
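
As a small worked example of the proximity formula for 3-D data, the snippet below evaluates a single displacement by hand. The numbers are made up for illustration, and the snippet does not call jaxdem.

```python
import jax.numpy as jnp

lidar_range = 6.0
disp = jnp.array([3.0, 0.0, 4.0])        # neighbor offset in x and z only

azimuth = jnp.arctan2(disp[1], disp[0])  # binning uses only x and y
dist = jnp.linalg.norm(disp)             # proximity uses all three components
proximity = jnp.maximum(0.0, lidar_range - dist)
# azimuth -> 0.0, dist -> 5.0, proximity -> 1.0
```

The neighbor is binned as if it lay directly along the x-axis, but its proximity reflects the full distance of 5 rather than the in-plane distance of 3.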

Parameters:
  • state (State) – Simulation state (positions, radii, etc.).

  • system (System) – System configuration including domain.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_bins (int) – Number of angular bins (rays) spanning \([-\pi, \pi)\).

  • max_neighbors (int) – Unused. Kept for backward compatibility.

  • sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Only meaningful for bounded domains. Default is False.

Returns:

(state, system, proximity, ids, overflow) where state and system are unchanged, proximity and ids have shape (N, n_bins), and overflow is always False. Bins with no detection have ids set to the particle’s own index.

Return type:

Tuple[State, System, jax.Array, jax.Array, jax.Array]

Notes

This function computes all-pairs displacements directly from state.pos and does not invoke the collider. The returned ids are indices into state.pos in whatever order it has at call time, so results are correct regardless of whether a cell-list collider has reordered the state.

Examples

>>> state, system, prox, ids, overflow = lidar_2d(state, system,
...     lidar_range=5.0, n_bins=36, max_neighbors=64)

jaxdem.utils.environment.lidar_3d(state: State, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int, sense_edges: bool = False) → tuple[State, System, jax.Array, jax.Array, jax.Array]

3-D LIDAR proximity readings and neighbor IDs.

Similar to lidar_2d() but bins neighbors on a spherical grid defined by n_azimuth azimuthal sectors in \([-\pi, \pi)\) and n_elevation elevation bands in \([-\pi/2, \pi/2]\). The returned proximity and ID arrays have shape (N, n_azimuth * n_elevation) with flat indexing az * n_elevation + el.
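
The flat indexing az * n_elevation + el can be sketched as follows. toy_spherical_bin is a hypothetical helper, and the bin-edge handling (uniform bins, clipping at the poles) is an assumption; only the flat-index formula comes from the description above.

```python
import jax.numpy as jnp


def toy_spherical_bin(disp, n_azimuth, n_elevation):
    """Hypothetical flat spherical bin index for displacements of shape (..., 3)."""
    az = jnp.arctan2(disp[..., 1], disp[..., 0])       # azimuth angle
    horiz = jnp.linalg.norm(disp[..., :2], axis=-1)
    el = jnp.arctan2(disp[..., 2], horiz)              # elevation in [-pi/2, pi/2]
    az_idx = jnp.floor((az + jnp.pi) / (2 * jnp.pi) * n_azimuth).astype(jnp.int32) % n_azimuth
    el_idx = jnp.floor((el + jnp.pi / 2) / jnp.pi * n_elevation).astype(jnp.int32)
    # Clip so that el = pi/2 (closed upper end) lands in the last band.
    el_idx = jnp.clip(el_idx, 0, n_elevation - 1)
    return az_idx * n_elevation + el_idx


disp = jnp.array([[1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [-1.0, -1.0, -1.0]])
flat = toy_spherical_bin(disp, n_azimuth=4, n_elevation=2)
# flat -> [5, 3, 0]

# Recover the two indices from the flat index.
az_idx, el_idx = flat // 2, flat % 2
```

With this layout, a proximity array of shape (N, n_azimuth * n_elevation) can be reshaped to (N, n_azimuth, n_elevation) to index by azimuth and elevation separately.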

Parameters:
  • state (State) – Simulation state.

  • system (System) – System configuration including domain.

  • lidar_range (float) – Maximum detection range and reference distance for proximity.

  • n_azimuth (int) – Number of azimuthal bins.

  • n_elevation (int) – Number of elevation bins.

  • max_neighbors (int) – Unused. Kept for backward compatibility.

  • sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Default is False.

Returns:

(state, system, proximity, ids, overflow) where state and system are unchanged, proximity and ids have shape (N, n_azimuth * n_elevation), and overflow is always False.

Return type:

Tuple[State, System, jax.Array, jax.Array, jax.Array]

Notes

Uses an all-pairs approach and does not invoke the collider. Returned ids index into state.pos in its current order, so results are correct regardless of collider-induced reordering.

Examples

>>> state, system, prox, ids, overflow = lidar_3d(state, system,
...     lidar_range=5.0, n_azimuth=36, n_elevation=18, max_neighbors=64)