jaxdem.utils.environment
Utility functions to handle environments and LIDAR sensor.
Functions
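| Function | Description |
|---|---|
| cross_lidar_2d() | 2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b. |
| cross_lidar_3d() | 3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b. |
| env_step() | Advance the environment n steps using actions from model. |
| env_trajectory_rollout() | Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk. |
| lidar_2d() | 2-D LIDAR proximity readings and neighbor IDs. |
| lidar_3d() | 3-D LIDAR proximity readings and neighbor IDs. |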
- jaxdem.utils.environment.cross_lidar_2d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_bins: int, max_neighbors: int) -> tuple[jax.Array, jax.Array, jax.Array]
2-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.
Computes all-pairs displacements from pos_a to pos_b, bins them by azimuthal angle, and returns the per-bin proximity and the ID of the closest target in each bin.
- Parameters:
pos_a (jax.Array) – Sensor positions, shape (N_A, dim).
pos_b (jax.Array) – Target positions, shape (N_B, dim).
system (System) – System configuration.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_bins (int) – Number of angular bins spanning \([-\pi, \pi)\).
max_neighbors (int) – Unused. Kept for backward compatibility.
- Returns:
(proximity, ids, overflow), where proximity and ids have shape (N_A, n_bins) and overflow is always False. Empty bins get ids = -1.
- Return type:
Tuple[jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. The returned ids are indices into pos_b, regardless of how pos_a may have been reordered by a cell-list collider.
Examples
>>> prox, ids, overflow = cross_lidar_2d(agents, obstacles, system,
...                                      lidar_range=5.0, n_bins=36,
...                                      max_neighbors=64)
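The all-pairs azimuthal binning described above can be sketched in plain JAX. This is a minimal illustration, not jaxdem's implementation: the name cross_lidar_2d_sketch is hypothetical, it assumes free (non-periodic) space so no System is needed, and it uses center-to-center distances. Bin k is assumed to cover the sector starting at -pi + 2*pi*k/n_bins.

```python
import jax.numpy as jnp


def cross_lidar_2d_sketch(pos_a, pos_b, lidar_range, n_bins):
    """Hypothetical all-pairs 2-D LIDAR: per-bin proximity and closest-target IDs.

    Assumes free (non-periodic) space; the real function also takes `system`.
    """
    # Displacements from every sensor to every target: (N_A, N_B, dim).
    disp = pos_b[None, :, :] - pos_a[:, None, :]
    dist = jnp.linalg.norm(disp, axis=-1)  # (N_A, N_B)
    # Azimuth of the xy components, mapped to n_bins uniform sectors of [-pi, pi).
    theta = jnp.arctan2(disp[..., 1], disp[..., 0])
    bins = jnp.floor((theta + jnp.pi) / (2 * jnp.pi) * n_bins).astype(jnp.int32) % n_bins
    # Keep only in-range targets, one column per bin: (N_A, N_B, n_bins).
    hit = (bins[..., None] == jnp.arange(n_bins)) & (dist < lidar_range)[..., None]
    masked = jnp.where(hit, dist[..., None], jnp.inf)
    d_min = masked.min(axis=1)  # closest distance per (sensor, bin)
    proximity = jnp.maximum(0.0, lidar_range - d_min)  # inf -> 0 for empty bins
    ids = jnp.where(jnp.isfinite(d_min), masked.argmin(axis=1), -1)  # -1 if empty
    overflow = jnp.array(False)  # mirrors the documented always-False flag
    return proximity, ids, overflow
```

The dense (N_A, N_B, n_bins) mask keeps the reduction a single vectorized min/argmin, at the cost of memory that grows with all three sizes.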
- jaxdem.utils.environment.cross_lidar_3d(pos_a: jax.Array, pos_b: jax.Array, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int) -> tuple[jax.Array, jax.Array, jax.Array]
3-D LIDAR proximity and IDs from pos_a sensing targets in pos_b.
Computes all-pairs displacements from pos_a to pos_b, bins them on a spherical grid, and returns the per-bin proximity and the ID of the closest target in each bin.
- Parameters:
pos_a (jax.Array) – Sensor positions, shape (N_A, 3).
pos_b (jax.Array) – Target positions, shape (N_B, 3).
system (System) – System configuration.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_azimuth (int) – Number of azimuthal bins.
n_elevation (int) – Number of elevation bins.
max_neighbors (int) – Unused. Kept for backward compatibility.
- Returns:
(proximity, ids, overflow), where proximity and ids have shape (N_A, n_azimuth * n_elevation) and overflow is always False. Empty bins get ids = -1.
- Return type:
Tuple[jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. The returned ids are indices into pos_b, regardless of how pos_a may have been reordered by a cell-list collider.
Examples
>>> prox, ids, overflow = cross_lidar_3d(agents, obstacles, system,
...                                      lidar_range=5.0, n_azimuth=36,
...                                      n_elevation=18, max_neighbors=64)
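A corresponding sketch for the spherical grid follows. The names spherical_bins and cross_lidar_3d_sketch are hypothetical and the code assumes free space. Two further assumptions: elevation is measured from the xy-plane as arctan2(z, hypot(x, y)), which avoids a NaN at zero distance, and the flat bin index follows the az * n_elevation + el convention documented for lidar_3d().

```python
import jax.numpy as jnp


def spherical_bins(disp, n_azimuth, n_elevation):
    """Hypothetical flat spherical-grid bin index for displacements (..., 3)."""
    x, y, z = disp[..., 0], disp[..., 1], disp[..., 2]
    az = jnp.arctan2(y, x)                # in [-pi, pi]
    el = jnp.arctan2(z, jnp.hypot(x, y))  # in [-pi/2, pi/2], no NaN at the origin
    az_bin = jnp.floor((az + jnp.pi) / (2 * jnp.pi) * n_azimuth).astype(jnp.int32) % n_azimuth
    el_bin = jnp.floor((el + jnp.pi / 2) / jnp.pi * n_elevation).astype(jnp.int32)
    el_bin = jnp.clip(el_bin, 0, n_elevation - 1)  # el = +pi/2 lands in the top band
    return az_bin * n_elevation + el_bin           # assumed flat layout


def cross_lidar_3d_sketch(pos_a, pos_b, lidar_range, n_azimuth, n_elevation):
    n_total = n_azimuth * n_elevation
    disp = pos_b[None, :, :] - pos_a[:, None, :]         # (N_A, N_B, 3)
    dist = jnp.linalg.norm(disp, axis=-1)                 # (N_A, N_B)
    flat = spherical_bins(disp, n_azimuth, n_elevation)   # (N_A, N_B)
    hit = (flat[..., None] == jnp.arange(n_total)) & (dist < lidar_range)[..., None]
    masked = jnp.where(hit, dist[..., None], jnp.inf)    # (N_A, N_B, n_total)
    d_min = masked.min(axis=1)
    proximity = jnp.maximum(0.0, lidar_range - d_min)
    ids = jnp.where(jnp.isfinite(d_min), masked.argmin(axis=1), -1)
    return proximity, ids, jnp.array(False)
```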
- jaxdem.utils.environment.env_step(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int = 1, **kw: Any) -> tuple[Environment, jax.Array]
Advance the environment n steps using actions from model.
- Parameters:
env (Environment) – Initial environment pytree (batchable).
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.
n (int) – Number of steps to perform.
**kw (Any) – Extra keyword arguments forwarded to model.
- Returns:
Updated environment and the advanced random key.
- Return type:
Tuple[Environment, jax.Array]
Examples
>>> env, key = env_step(env, model, key, n=10, objective=goal)
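One plausible structure for env_step() is a loop that splits the key once per step, queries model, and applies the action. Because the Environment API is not documented on this page, the sketch below uses a hypothetical point-mass pytree (ToyEnv) with hypothetical toy_observation and toy_step helpers; it illustrates the key-threading pattern, not jaxdem's actual update.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnv(NamedTuple):
    """Hypothetical stand-in for jaxdem's Environment (not its real API)."""
    pos: jax.Array
    vel: jax.Array


def toy_observation(env: ToyEnv) -> jax.Array:
    return jnp.concatenate([env.pos, env.vel])


def toy_step(env: ToyEnv, action: jax.Array, dt: float = 0.1) -> ToyEnv:
    vel = env.vel + dt * action  # treat the action as an acceleration
    return ToyEnv(pos=env.pos + dt * vel, vel=vel)


def env_step_sketch(env, model, key, *, n=1, **kw):
    def body(_, carry):
        env, key = carry
        key, subkey = jax.random.split(key)  # fresh subkey for every step
        action = model(toy_observation(env), subkey, **kw)
        return toy_step(env, action), key

    # Returns (final_env, advanced_key), like the documented return value.
    return jax.lax.fori_loop(0, n, body, (env, key))
```

Threading the key through the loop carry is what makes the returned key safe to pass to the next call: in this sketch, two consecutive calls with n=5 consume the same key sequence as one call with n=10.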
- jaxdem.utils.environment.env_trajectory_rollout(env: Environment, model: Callable[..., Any], key: jax.Array, *, n: int, stride: int = 1, **kw: Any) -> tuple[Environment, jax.Array, Environment]
Roll out a trajectory by applying model in chunks of stride steps and collecting the environment after each chunk.
- Parameters:
env (Environment) – Initial environment pytree.
model (Callable) – Callable with signature model(obs, key, **kw) -> action.
key (jax.Array) – JAX random key. The returned key is the advanced version that should be used for subsequent calls.
n (int) – Number of chunks to roll out. Total internal steps = n * stride.
stride (int) – Steps per chunk between recorded snapshots.
**kw (Any) – Extra keyword arguments passed to model on every step.
- Returns:
Final environment, advanced random key, and a stacked pytree of environments with length n, each snapshot taken after a chunk of stride steps.
- Return type:
Tuple[Environment, jax.Array, Environment]
Examples
>>> env, key, traj = env_trajectory_rollout(env, model, key, n=100, stride=5, objective=goal)
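The chunked rollout can be pictured as an outer jax.lax.scan over n chunks wrapped around an inner loop of stride steps, with the scan's stacked outputs forming the trajectory. The sketch below again uses a hypothetical ToyEnv stand-in (here a single integrator whose action is a velocity); rollout_sketch is not jaxdem's code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnv(NamedTuple):
    """Hypothetical stand-in for jaxdem's Environment (not its real API)."""
    pos: jax.Array


def toy_step(env, action, dt=0.1):
    return ToyEnv(pos=env.pos + dt * action)  # action interpreted as a velocity


def rollout_sketch(env, model, key, *, n, stride=1, **kw):
    def one_step(_, carry):
        env, key = carry
        key, subkey = jax.random.split(key)
        return toy_step(env, model(env.pos, subkey, **kw)), key

    def chunk(carry, _):
        carry = jax.lax.fori_loop(0, stride, one_step, carry)  # stride raw steps
        return carry, carry[0]  # emit the env snapshot after the chunk

    # scan stacks the emitted snapshots along a new leading axis of length n.
    (env, key), traj = jax.lax.scan(chunk, (env, key), xs=None, length=n)
    return env, key, traj
```

Recording only once per chunk keeps the stacked trajectory at n snapshots while the simulation still advances n * stride steps.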
- jaxdem.utils.environment.lidar_2d(state: State, system: System, lidar_range: float, n_bins: int, max_neighbors: int, sense_edges: bool = False) -> tuple[State, System, jax.Array, jax.Array, jax.Array]
2-D LIDAR proximity readings and neighbor IDs.
For every particle in state, the displacement vectors to all other particles are projected onto the \(xy\)-plane and binned by azimuthal angle into n_bins uniform sectors spanning \([-\pi, \pi)\). Each bin stores the proximity value and the index of the closest neighbor in that sector:
\[p_k = \max(0,\; r_{\max} - d_{\min,k})\]
where \(r_{\max}\) is lidar_range and \(d_{\min,k}\) is the distance to the closest neighbor in sector \(k\). This works identically for 2-D and 3-D position data: in the 3-D case, the \(z\)-component of the displacement is ignored during binning, while the full Euclidean distance is used for proximity.
- Parameters:
state (State) – Simulation state (positions, radii, etc.).
system (System) – System configuration including domain.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_bins (int) – Number of angular bins (rays) spanning \([-\pi, \pi)\).
max_neighbors (int) – Unused. Kept for backward compatibility.
sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Only meaningful for bounded domains. Default is False.
- Returns:
(state, system, proximity, ids, overflow), where state and system are unchanged, proximity and ids have shape (N, n_bins), and overflow is always False. Bins with no detection have ids set to the particle's own index.
- Return type:
Tuple[State, System, jax.Array, jax.Array, jax.Array]
Notes
This function computes all-pairs displacements directly from state.pos and does not invoke the collider. The returned ids are indices into state.pos in whatever order it has at call time, so results are correct regardless of whether a cell-list collider has reordered the state.
Examples
>>> state, system, prox, ids, overflow = lidar_2d(state, system,
...     lidar_range=5.0, n_bins=36, max_neighbors=64)
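The proximity formula and the self-sensing rules above (a particle never detects itself, empty bins report the particle's own index, binning uses only the xy components while proximity uses the full distance) can be sketched as follows. This is an illustration under assumptions: free space, sense_edges not modeled, center-to-center distances, and a plain position array instead of a State; lidar_2d_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def lidar_2d_sketch(pos, lidar_range, n_bins):
    """Hypothetical self-sensing 2-D LIDAR over all particles in pos (N, 2 or 3)."""
    n = pos.shape[0]
    disp = pos[None, :, :] - pos[:, None, :]          # disp[i, j] = pos[j] - pos[i]
    dist = jnp.linalg.norm(disp, axis=-1)              # full Euclidean distance
    theta = jnp.arctan2(disp[..., 1], disp[..., 0])    # azimuth from x and y only
    bins = jnp.floor((theta + jnp.pi) / (2 * jnp.pi) * n_bins).astype(jnp.int32) % n_bins
    not_self = ~jnp.eye(n, dtype=bool)                 # exclude i == j pairs
    valid = (dist < lidar_range) & not_self
    hit = (bins[..., None] == jnp.arange(n_bins)) & valid[..., None]
    masked = jnp.where(hit, dist[..., None], jnp.inf)  # (N, N, n_bins)
    d_min = masked.min(axis=1)
    proximity = jnp.maximum(0.0, lidar_range - d_min)  # p_k = max(0, r_max - d_min,k)
    own = jnp.arange(n)[:, None]                       # fallback ID: own index
    ids = jnp.where(jnp.isfinite(d_min), masked.argmin(axis=1), own)
    return proximity, ids
```

In the test data below, particle 2 sits 2 units above the xy-plane of particle 0, so its proximity from particle 0 uses the 3-D distance sqrt(5.25) even though its bin is chosen from the xy projection alone.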
- jaxdem.utils.environment.lidar_3d(state: State, system: System, lidar_range: float, n_azimuth: int, n_elevation: int, max_neighbors: int, sense_edges: bool = False) -> tuple[State, System, jax.Array, jax.Array, jax.Array]
3-D LIDAR proximity readings and neighbor IDs.
Similar to lidar_2d(), but bins neighbors on a spherical grid defined by n_azimuth azimuthal sectors in \([-\pi, \pi)\) and n_elevation elevation bands in \([-\pi/2, \pi/2]\). The returned proximity and ID arrays have shape (N, n_azimuth * n_elevation) with flat indexing az * n_elevation + el.
- Parameters:
state (State) – Simulation state.
system (System) – System configuration including domain.
lidar_range (float) – Maximum detection range and reference distance for proximity.
n_azimuth (int) – Number of azimuthal bins.
n_elevation (int) – Number of elevation bins.
max_neighbors (int) – Unused. Kept for backward compatibility.
sense_edges (bool, optional) – If True, domain boundaries are included as proximity sources. Wall detections receive an ID of -1. Default is False.
- Returns:
(state, system, proximity, ids, overflow), where state and system are unchanged, proximity and ids have shape (N, n_azimuth * n_elevation), and overflow is always False.
- Return type:
Tuple[State, System, jax.Array, jax.Array, jax.Array]
Notes
Uses an all-pairs approach and does not invoke the collider. The returned ids index into state.pos in its current order, so results are correct regardless of collider-induced reordering.
Examples
>>> state, system, prox, ids, overflow = lidar_3d(state, system,
...     lidar_range=5.0, n_azimuth=36, n_elevation=18, max_neighbors=64)
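Because the readings use the flat index az * n_elevation + el, a row-major reshape recovers an (N, n_azimuth, n_elevation) grid. The sketch below also builds a unit direction through the center of each bin in the same flat order, assuming uniform angular bins over the documented ranges; both helpers (unflatten_bins, bin_center_directions) are hypothetical and not part of jaxdem.utils.environment.

```python
import jax.numpy as jnp


def unflatten_bins(readings, n_azimuth, n_elevation):
    """Reshape (N, n_azimuth * n_elevation) -> (N, n_azimuth, n_elevation)."""
    return readings.reshape(readings.shape[0], n_azimuth, n_elevation)


def bin_center_directions(n_azimuth, n_elevation):
    """Unit vectors through each bin center, assuming uniform angular bins."""
    az = -jnp.pi + (jnp.arange(n_azimuth) + 0.5) * (2 * jnp.pi / n_azimuth)
    el = -jnp.pi / 2 + (jnp.arange(n_elevation) + 0.5) * (jnp.pi / n_elevation)
    az_g, el_g = jnp.meshgrid(az, el, indexing="ij")  # (n_azimuth, n_elevation)
    dirs = jnp.stack([jnp.cos(el_g) * jnp.cos(az_g),
                      jnp.cos(el_g) * jnp.sin(az_g),
                      jnp.sin(el_g)], axis=-1)
    # Row-major flattening puts bin (az, el) at row az * n_elevation + el.
    return dirs.reshape(n_azimuth * n_elevation, 3)
```

Assuming lidar_3d() uses the same proximity definition as lidar_2d(), an approximate hit point for a detected bin, relative to the sensor, is (lidar_range - proximity) times that bin's center direction.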