# Source code for zyra.visualization.vector_particles_manager

# SPDX-License-Identifier: Apache-2.0
"""Render particle advection frames over a vector field (U/V).

Supports NetCDF (time, lat, lon) variables or 3D NumPy stacks. Particles are
seeded on a grid, at random, or from a CSV and advected using Euler or RK2.
"""

from __future__ import annotations

from dataclasses import asdict, dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Optional, Sequence

from zyra.utils.geo_utils import detect_crs_from_path, warn_if_mismatch

from .base import Renderer
from .basemap import add_basemap_cartopy
from .styles import DEFAULT_EXTENT, FIGURE_DPI, MAP_STYLES, apply_matplotlib_style


@dataclass
class ParticleFrame:
    """Record describing one rendered particle frame on disk."""

    # Zero-based frame number; render() uses the time-slice index here.
    index: int
    # Path (as a string) of the image file written for this frame.
    path: str
class VectorParticlesManager(Renderer):
    """Render particle-advection frames over a U/V vector field.

    Velocity stacks are indexed ``[time, y, x]`` and are mapped onto
    ``self.extent`` (``[west, east, south, north]``), which is used both to
    seed particles and to convert particle positions into grid indices.
    ``render`` writes one image per time slice and stores a manifest that
    ``save`` can serialize to JSON.
    """

    def __init__(
        self,
        *,
        basemap: Optional[str] = None,
        extent: Optional[Sequence[float]] = None,
        color: str = "#333333",
        size: float = 0.5,
        method: str = "euler",
    ) -> None:
        """Store default rendering options.

        Args:
            basemap: Optional image path passed to ``add_basemap_cartopy``.
            extent: ``[west, east, south, north]``; ``DEFAULT_EXTENT`` if omitted.
            color: Matplotlib color for the particle markers.
            size: Marker size passed to ``Axes.scatter`` as ``s``.
            method: Integrator name; ``"rk2"``/``"midpoint"`` select the
                midpoint method, any other value uses forward Euler.
        """
        self.basemap = basemap
        self.extent = list(extent) if extent is not None else list(DEFAULT_EXTENT)
        self.color = color
        self.size = size
        self.method = method
        # Filled in by render(); stays empty until a render completes.
        self._manifest: dict[str, Any] = {}

    def configure(self, **kwargs: Any) -> None:
        """Update default options in place.

        Recognized keys are ``basemap``, ``extent``, ``color``, ``size`` and
        ``method``; absent keys keep their current value. ``extent=None``
        restores ``DEFAULT_EXTENT``, mirroring ``__init__``.
        """
        self.basemap = kwargs.get("basemap", self.basemap)
        extent = kwargs.get("extent", self.extent)
        self.extent = list(extent) if extent is not None else list(DEFAULT_EXTENT)
        self.color = kwargs.get("color", self.color)
        self.size = kwargs.get("size", self.size)
        self.method = kwargs.get("method", self.method)

    # Data handling
    @staticmethod
    def _as_path(value: Any) -> Any:
        """Return ``value`` as a str/bytes path if it is path-like, else ``None``.

        Accepts ``str``, ``bytes`` and ``os.PathLike`` (e.g. ``pathlib.Path``);
        anything else (arrays, ``None``) yields ``None``.
        """
        import os

        if isinstance(value, (str, bytes, os.PathLike)):
            return os.fspath(value)
        return None

    def _load_stacks(
        self,
        *,
        input_path: Optional[str] = None,
        uvar: Optional[str] = None,
        vvar: Optional[str] = None,
        u_path: Optional[str] = None,
        v_path: Optional[str] = None,
    ) -> tuple[Any, Any]:
        """Load U/V arrays from ``.npy`` files or from a NetCDF dataset.

        ``u_path``/``v_path`` take precedence. Otherwise ``input_path`` must end
        in ``.nc``/``.nc4`` and ``uvar``/``vvar`` name the variables to read.

        Raises:
            ValueError: if neither source is fully specified.
        """
        import numpy as np

        if u_path and v_path:
            return np.load(u_path), np.load(v_path)
        # str() lets pathlib.Path inputs through the suffix check.
        if input_path and str(input_path).lower().endswith((".nc", ".nc4")):
            import xarray as xr

            if not uvar or not vvar:
                raise ValueError("NetCDF inputs require --uvar and --vvar")
            # ``.values`` materializes the arrays before the context manager
            # closes the dataset, so they remain usable afterwards.
            with xr.open_dataset(input_path) as ds:
                return ds[uvar].values, ds[vvar].values
        raise ValueError(
            "Provide either --u/--v .npy stacks or --input .nc with --uvar/--vvar"
        )

    def _resolve_stacks(
        self,
        *,
        input_path: Any,
        uvar: Optional[str],
        vvar: Optional[str],
        u_kw: Any,
        v_kw: Any,
    ) -> tuple[Any, Any]:
        """Return U/V as validated 3D arrays ``[time, y, x]``.

        ``u_kw``/``v_kw`` may be array-likes (anything with ``ndim``) or paths
        to ``.npy`` stacks; otherwise the NetCDF ``input_path`` is used. A
        pair of 2D fields is promoted to a single-frame stack.

        Raises:
            ValueError: if the inputs are missing, or U/V are not matching 3D stacks.
        """
        import numpy as np

        if hasattr(u_kw, "ndim") and hasattr(v_kw, "ndim"):
            U, V = np.asarray(u_kw), np.asarray(v_kw)
        else:
            U, V = self._load_stacks(
                input_path=input_path,
                uvar=uvar,
                vvar=vvar,
                u_path=self._as_path(u_kw),
                v_path=self._as_path(v_kw),
            )
            U, V = np.asarray(U), np.asarray(V)
        if U.ndim == 2 and V.ndim == 2:
            U, V = U[None, ...], V[None, ...]
        if U.shape != V.shape or U.ndim != 3:
            raise ValueError("U/V must be 3D stacks [time, y, x] with matching shapes")
        return U, V

    def _warn_crs_mismatch(
        self, kwargs: dict[str, Any], *, input_path: Any, u_kw: Any, v_kw: Any
    ) -> None:
        """Best-effort CRS check; any failure is logged at debug level, never raised."""
        import logging

        try:
            user_crs = kwargs.get("crs")
            reproject = bool(kwargs.get("reproject", False))
            in_path = (
                self._as_path(input_path)
                or self._as_path(u_kw)
                or self._as_path(v_kw)
            )
            in_crs = user_crs or (detect_crs_from_path(in_path) if in_path else None)
            warn_if_mismatch(in_crs, reproject=reproject, context="particles")
        except Exception:  # noqa: BLE001 - the CRS check must never block rendering
            logging.getLogger(__name__).debug(
                "particles: CRS detection skipped", exc_info=True
            )

    def _check_extent(self) -> None:
        """Reject extents whose width or height is zero.

        ``_sample_uv`` divides by ``east - west`` / ``north - south`` and
        ``_wrap_clamp`` takes a modulo by ``east - west``; a zero span would
        silently turn particle positions into inf/NaN.
        """
        west, east, south, north = self.extent
        if east == west or north == south:
            raise ValueError(
                f"extent must have non-zero width and height, got {self.extent!r}"
            )

    # Seeding helpers
    def _seed_particles(self, seed: str, particles: int) -> tuple[Any, Any]:
        """Seed particle positions inside ``self.extent``.

        ``"grid"`` places a square ``n x n`` lattice with ``n ~= sqrt(particles)``
        (at least 2x2); ``"random"`` draws ``particles`` uniform positions
        from NumPy's global random state.

        Raises:
            ValueError: for ``"custom"`` (needs a CSV via ``_seed_custom``),
                an unknown mode, or a negative particle count.
        """
        import numpy as np

        if seed == "custom":
            raise ValueError(
                "custom seeding requires render(..., custom_seed=path_to_csv)"
            )
        if seed not in ("grid", "random"):
            raise ValueError(
                f"unknown seed mode {seed!r}; expected 'grid', 'random' or 'custom'"
            )
        if particles < 0:
            raise ValueError(f"particles must be non-negative, got {particles}")
        west, east, south, north = self.extent
        if seed == "grid":
            # Square grid close to the requested count.
            nx = int(max(2, round(particles**0.5)))
            ny = nx
            xs = np.linspace(west, east, nx)
            ys = np.linspace(south, north, ny)
            X, Y = np.meshgrid(xs, ys)
            return X.ravel(), Y.ravel()
        X = np.random.uniform(west, east, size=particles)
        Y = np.random.uniform(south, north, size=particles)
        return X, Y

    def _seed_custom(self, csv_path: str) -> tuple[Any, Any]:
        """Read seed positions from a CSV that has ``lon`` and ``lat`` columns.

        Raises:
            ValueError: if either column is missing.
        """
        import pandas as pd

        df = pd.read_csv(csv_path)
        if not {"lon", "lat"}.issubset(df.columns):
            raise ValueError("custom seed CSV must have columns 'lon' and 'lat'")
        return df["lon"].to_numpy(), df["lat"].to_numpy()

    # Velocity sampling (nearest neighbor for simplicity)
    def _sample_uv(self, U: Any, V: Any, lon: Any, lat: Any) -> tuple[Any, Any]:
        """Sample U/V at particle positions by nearest-neighbour lookup.

        ``render`` passes single 2D time slices ``[y, x]``. Positions are mapped
        linearly from ``self.extent`` onto grid indices and clipped to the grid,
        so particles outside the extent read the nearest edge cell.
        NOTE(review): assumes row 0 is the southern edge and column 0 the
        western edge; grids stored north-to-south would be sampled
        upside-down - confirm the ordering of the input data.
        """
        import numpy as np

        ny, nx = U.shape[-2], U.shape[-1]
        west, east, south, north = self.extent
        # Fractional grid indices for each particle.
        fx = (lon - west) / (east - west) * (nx - 1)
        fy = (lat - south) / (north - south) * (ny - 1)
        ix = np.clip(np.round(fx).astype(int), 0, nx - 1)
        iy = np.clip(np.round(fy).astype(int), 0, ny - 1)
        return U[iy, ix], V[iy, ix]

    def _step_euler(
        self, U: Any, V: Any, lon: Any, lat: Any, dt: float
    ) -> tuple[Any, Any]:
        """Advance positions by one forward-Euler step of size ``dt``."""
        u, v = self._sample_uv(U, V, lon, lat)
        return lon + u * dt, lat + v * dt

    def _step_rk2(
        self, U: Any, V: Any, lon: Any, lat: Any, dt: float
    ) -> tuple[Any, Any]:
        """Advance positions by one midpoint (RK2) step of size ``dt``."""
        u1, v1 = self._sample_uv(U, V, lon, lat)
        lon_mid = lon + 0.5 * dt * u1
        lat_mid = lat + 0.5 * dt * v1
        u2, v2 = self._sample_uv(U, V, lon_mid, lat_mid)
        return lon + dt * u2, lat + dt * v2

    def _wrap_clamp(self, lon: Any, lat: Any) -> tuple[Any, Any]:
        """Wrap longitudes periodically into ``[west, east)`` and clamp latitudes.

        The periodic wrap is exact for a full ``[-180, 180]`` extent; for
        regional extents a particle leaving one side re-enters on the other.
        Latitudes are clipped to ``[south, north]``.
        """
        import numpy as np

        west, east, south, north = self.extent
        span = east - west
        lon = ((lon - west) % span) + west
        lat = np.clip(lat, south, north)
        return lon, lat

    def render(self, data: Any = None, **kwargs: Any):
        """Advect particles through every time slice and write one image per slice.

        ``data`` is ignored; all inputs come from keyword arguments:

        - ``u``/``v``: arrays (``[time, y, x]`` or ``[y, x]``) or paths
          (str/bytes/``os.PathLike``) to ``.npy`` stacks.
        - ``input_path`` with ``uvar``/``vvar``: NetCDF source, used when
          ``u``/``v`` are not arrays or paths.
        - ``seed`` (``"grid"``/``"random"``/``"custom"``), ``particles``,
          ``custom_seed`` (CSV path, required for ``"custom"``).
        - ``dt``, ``steps_per_frame`` (at least one substep), ``method``.
        - ``width``, ``height``, ``dpi``, ``color``, ``size``.
        - ``output_dir``, ``filename_template`` (formatted with ``index``).
        - ``crs``, ``reproject``: forwarded to a best-effort CRS mismatch warning.

        Returns:
            The manifest dict, which is also stored for ``save``.

        Raises:
            ValueError: on missing or mismatched inputs, an unknown seed mode,
                or a zero-width/height extent.
        """
        # Inputs
        input_path = kwargs.get("input_path")
        uvar = kwargs.get("uvar")
        vvar = kwargs.get("vvar")
        u_kw = kwargs.get("u")
        v_kw = kwargs.get("v")
        seed = kwargs.get("seed", "grid")
        particles = int(kwargs.get("particles", 200))
        custom_seed = kwargs.get("custom_seed")
        dt = float(kwargs.get("dt", 0.01))
        steps_per_frame = int(kwargs.get("steps_per_frame", 1))
        method = kwargs.get("method", self.method)
        width = int(kwargs.get("width", 1024))
        height = int(kwargs.get("height", 512))
        dpi = int(kwargs.get("dpi", FIGURE_DPI))
        color = kwargs.get("color", self.color)
        size = float(kwargs.get("size", self.size))
        output_dir = Path(kwargs.get("output_dir", "."))
        filename_template = kwargs.get("filename_template", "frame_{index:04d}.png")

        # Imported here rather than at module level; done before any data
        # work so a missing plotting dependency fails fast.
        import cartopy.crs as ccrs
        import matplotlib.pyplot as plt

        self._warn_crs_mismatch(kwargs, input_path=input_path, u_kw=u_kw, v_kw=v_kw)
        U, V = self._resolve_stacks(
            input_path=input_path, uvar=uvar, vvar=vvar, u_kw=u_kw, v_kw=v_kw
        )
        self._check_extent()

        # Seed particles
        if seed == "custom":
            if not custom_seed:
                raise ValueError("Provide custom_seed path for seed='custom'")
            lon, lat = self._seed_custom(custom_seed)
        else:
            lon, lat = self._seed_particles(seed, particles)

        output_dir.mkdir(parents=True, exist_ok=True)
        step_fn = (
            self._step_rk2
            if str(method).lower() in ("rk2", "midpoint")
            else self._step_euler
        )
        n_substeps = max(1, steps_per_frame)
        frames: list[ParticleFrame] = []

        for i in range(U.shape[0]):
            u_slice, v_slice = U[i], V[i]
            # Integrate substeps within the frame using the current time slice;
            # wrapping after every substep keeps the next lookup in-domain.
            for _ in range(n_substeps):
                lon, lat = step_fn(u_slice, v_slice, lon, lat, dt)
                lon, lat = self._wrap_clamp(lon, lat)

            fpath = output_dir / filename_template.format(index=i)
            apply_matplotlib_style()
            fig, ax = plt.subplots(
                figsize=(width / dpi, height / dpi),
                dpi=dpi,
                subplot_kw={"projection": ccrs.PlateCarree()},
            )
            try:
                add_basemap_cartopy(
                    ax,
                    self.extent,
                    image_path=self.basemap,
                    features=MAP_STYLES.get("features"),
                )
                ax.scatter(
                    lon,
                    lat,
                    s=size,
                    c=color,
                    transform=ccrs.PlateCarree(),
                    alpha=0.9,
                    linewidths=0,
                )
                # NOTE(review): set_global() shows the whole globe even when
                # self.extent is regional - confirm this is intended.
                ax.set_global()
                ax.axis("off")
                fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
                fig.savefig(fpath, bbox_inches="tight", pad_inches=0)
            finally:
                # Release the figure even if drawing/saving fails, so errors
                # do not leak open matplotlib figures.
                plt.close(fig)
            frames.append(ParticleFrame(index=i, path=str(fpath)))

        self._manifest = {
            "mode": "particles",
            "count": len(frames),
            "frames": [asdict(f) for f in frames],
            "params": {
                "seed": seed,
                "particles": particles,
                "dt": dt,
                "steps_per_frame": steps_per_frame,
                "method": method,
                "color": color,
                "size": size,
            },
        }
        return self._manifest

    def save(self, output_path: Optional[str] = None, *, as_buffer: bool = False):
        """Serialize the last render manifest as indented JSON.

        Returns:
            ``None`` if nothing has been rendered. With ``as_buffer=True``, a
            ``BytesIO`` holding UTF-8 JSON, positioned at the start. Otherwise
            the path written to: ``output_path``, or ``manifest.json`` next to
            the first frame (current directory when there are no frames).
        """
        import json

        if not self._manifest:
            return None
        payload = json.dumps(self._manifest, indent=2)
        if as_buffer:
            # A BytesIO built from initial bytes starts at position 0.
            return BytesIO(payload.encode("utf-8"))
        if output_path is None:
            # ``or [{}]`` also covers an empty frame list (zero time slices),
            # which previously raised IndexError.
            frames = self._manifest.get("frames") or [{}]
            first = frames[0].get("path")
            base = Path(first).parent if first else Path(".")  # noqa: PTH201
            output_path = str(base / "manifest.json")
        Path(output_path).write_text(payload, encoding="utf-8")
        return output_path