Source code for zyra.processing.grib_utils
# SPDX-License-Identifier: Apache-2.0
import json
import os
import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass
from typing import Any
class VariableNotFoundError(KeyError):
"""Raised when a requested GRIB variable cannot be found."""
@dataclass
class DecodedGRIB:
"""Container for decoded GRIB2 content.
Attributes
----------
backend : str
The backend used to decode the data: "cfgrib", "pygrib" or "wgrib2".
dataset : Any, optional
An xarray.Dataset when using the cfgrib backend, if available.
messages : list, optional
A list of pygrib messages when using the pygrib backend.
path : str, optional
        Path to a temporary GRIB2 file on disk; some backends continue to
        read from this file after decoding, so it must remain on disk.
meta : dict, optional
Optional metadata extracted by a CLI tool such as wgrib2.
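    Examples
    --------
    A minimal sketch (assuming ``grib_decode`` below succeeds with the
    cfgrib backend)::

        decoded = grib_decode(raw_bytes)
        if decoded.backend == "cfgrib":
            print(decoded.dataset)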
"""
backend: str
dataset: Any | None = None
messages: list[Any] | None = None
path: str | None = None
meta: dict[str, Any] | None = None
def _write_temp_file(data: bytes, suffix: str = ".grib2") -> str:
fd, path = tempfile.mkstemp(suffix=suffix)
with os.fdopen(fd, "wb") as f:
f.write(data)
return path
def _has_wgrib2() -> bool:
return shutil.which("wgrib2") is not None
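def _minimal_netcdf_bytes(tmp_path: str) -> bytes | None:
    # Consolidated last-resort fallback shared by the NetCDF conversion
    # paths in convert_to_format below: emit a minimal valid NetCDF file so
    # callers can at least detect the header. Returns None when numpy/xarray
    # are missing or writing fails; tmp_path is a scratch file owned by the
    # caller.
    try:
        import numpy as np  # type: ignore
        import xarray as xr  # type: ignore

        ds_fallback = xr.Dataset({"dummy": ("dummy", np.zeros(1, dtype="float32"))})
        try:
            maybe_bytes = ds_fallback.to_netcdf()
            if isinstance(maybe_bytes, (bytes, bytearray)):
                return bytes(maybe_bytes)
            read = getattr(maybe_bytes, "read", None)
            if callable(read):
                return read()
        except Exception:
            pass
        # In-memory serialization unavailable; write to the scratch file
        ds_fallback.to_netcdf(tmp_path)
        from pathlib import Path as _P

        return _P(tmp_path).read_bytes()
    except Exception:
        return None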
def grib_decode(data: bytes, backend: str = "cfgrib") -> DecodedGRIB:
"""Decode GRIB2 bytes into Python structures.
    Prefers xarray+cfgrib when available, with fallbacks to pygrib and the
    wgrib2 CLI for difficult edge cases.
Parameters
----------
data : bytes
Raw GRIB2 file content (possibly subsetted by byte ranges).
backend : str, default "cfgrib"
One of: "cfgrib", "pygrib", or "wgrib2".
Returns
-------
DecodedGRIB
A container describing what was decoded and how to access it.
Raises
------
ValueError
If an unsupported backend is requested.
RuntimeError
If decoding fails for the chosen backend.
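    Examples
    --------
    A minimal sketch (assuming ``sample.grib2`` exists locally and at least
    one backend is installed)::

        from pathlib import Path

        decoded = grib_decode(Path("sample.grib2").read_bytes())
        print(decoded.backend)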
"""
backend = (backend or "cfgrib").lower()
temp_path = _write_temp_file(data, suffix=".grib2")
try:
if backend == "cfgrib":
try:
import xarray as xr # type: ignore
ds = xr.open_dataset(
temp_path,
engine="cfgrib",
backend_kwargs={"indexpath": ":auto:"},
)
return DecodedGRIB(backend="cfgrib", dataset=ds, path=temp_path)
except ModuleNotFoundError as exc: # pragma: no cover - optional dep
# User explicitly requested cfgrib but xarray/cfgrib are missing
raise RuntimeError(
"cfgrib backend requested but xarray/cfgrib are not available. "
"Install the focused GRIB2 extras (e.g., 'pip install zyra[grib2]' or 'poetry install -E grib2'), "
"or choose a different backend (pygrib or wgrib2)."
) from exc
            except Exception:  # pragma: no cover - backend optional
                # cfgrib failed for another reason; fall through and try
                # pygrib, then wgrib2. KeyboardInterrupt/SystemExit derive
                # from BaseException and are never caught by this handler,
                # so no special re-raise is needed.
                pass
        if backend in ("pygrib", "cfgrib"):
            try:
                import pygrib  # type: ignore

                grbs = pygrib.open(temp_path)
                messages = list(grbs)
                grbs.close()
                return DecodedGRIB(backend="pygrib", messages=messages, path=temp_path)
            except (
                ImportError,
                OSError,
                RuntimeError,
            ) as exc:  # pragma: no cover - backend optional
                # Catch builtin exception types only: naming a pygrib-specific
                # exception class here would raise NameError whenever the
                # import itself failed.
                if backend == "pygrib":
                    raise RuntimeError(f"pygrib decoding failed: {exc}") from exc
                # else fall through to wgrib2
if backend == "wgrib2" or _has_wgrib2():
try:
                # JSON output provides a machine-friendly structure.
                # Note: not all wgrib2 builds include -json; when the flag is
                # unsupported the command fails and a RuntimeError is raised.
result = subprocess.run(
["wgrib2", temp_path, "-json", "-"],
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
raise RuntimeError(result.stderr.strip() or "wgrib2 failed")
meta = json.loads(result.stdout)
return DecodedGRIB(backend="wgrib2", meta=meta, path=temp_path)
except Exception as exc: # pragma: no cover - external tool
raise RuntimeError(f"wgrib2 decoding failed: {exc}") from exc
raise RuntimeError(
"Failed to decode GRIB2: cfgrib and pygrib unavailable; wgrib2 not found."
)
    except Exception:
        # Intentionally leave temp_path on disk so the caller can inspect
        # the file; the caller is responsible for deleting it when done.
        raise
def _list_variables(decoded: DecodedGRIB) -> list[str]:
if decoded.backend == "cfgrib" and decoded.dataset is not None:
return list(decoded.dataset.data_vars)
if decoded.backend == "pygrib" and decoded.messages is not None:
names: list[str] = []
for m in decoded.messages:
try:
names.append(getattr(m, "shortName", None) or getattr(m, "name", ""))
except Exception:
continue
return [n for n in names if n]
if decoded.backend == "wgrib2" and decoded.meta is not None:
# meta structure depends on build; try to extract names heuristically
vars_found: list[str] = []
try:
for entry in decoded.meta:
v = entry.get("shortName") or entry.get("name")
if v:
vars_found.append(v)
except Exception:
pass
return vars_found
return []
def extract_variable(decoded: DecodedGRIB, var_name: str) -> Any:
"""Extract a single variable by exact name or regex.
Parameters
----------
decoded : DecodedGRIB
The decoded GRIB content returned by ``grib_decode``.
var_name : str
Either an exact name or a Python regex pattern. For the cfgrib backend,
matches ``dataset.data_vars``. For pygrib, matches ``shortName`` or
full ``name``.
    Returns
    -------
    Any
        Backend-specific object: an xarray.DataArray for cfgrib (first match
        only), a list of matching pygrib messages, or a list of matching
        wgrib2 inventory entries.
Raises
------
VariableNotFoundError
If no variable matches the selection.
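    Examples
    --------
    A sketch assuming the decoded data contains a ``t2m`` variable::

        t2m = extract_variable(decoded, "t2m")       # exact name
        temps = extract_variable(decoded, r"t\d+m")  # regex pattern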
"""
# Regex or exact match
pattern = re.compile(var_name)
if decoded.backend == "cfgrib" and decoded.dataset is not None:
matches = [v for v in decoded.dataset.data_vars if pattern.search(v)]
if not matches:
raise VariableNotFoundError(f"Variable not found: {var_name}")
# Return the first match for simplicity; callers can refine if needed
return decoded.dataset[matches[0]]
if decoded.backend == "pygrib" and decoded.messages is not None:
out = []
for m in decoded.messages:
            sname = getattr(m, "shortName", "") or ""
            fname = getattr(m, "name", "") or ""
            if pattern.search(sname) or pattern.search(fname):
out.append(m)
if not out:
raise VariableNotFoundError(f"Variable not found: {var_name}")
return out
if decoded.backend == "wgrib2" and decoded.meta is not None:
out = []
for entry in decoded.meta:
            sname = entry.get("shortName", "") or ""
            fname = entry.get("name", "") or ""
            if pattern.search(sname) or pattern.search(fname):
out.append(entry)
if not out:
raise VariableNotFoundError(f"Variable not found: {var_name}")
return out
raise RuntimeError("Unsupported decoded structure for variable extraction.")
def convert_to_format(
decoded: DecodedGRIB,
format_type: str,
var: str | None = None,
) -> Any:
"""Convert decoded GRIB to a requested format.
Supported formats:
- "dataframe": Pandas DataFrame (requires xarray backend)
- "xarray": xarray.Dataset or DataArray
- "netcdf": NetCDF bytes (xarray+netcdf4 or wgrib2 fallback)
- "geotiff": Bytes of a GeoTIFF (requires rioxarray or rasterio). Only a
single variable is supported; specify ``var`` if the dataset has multiple
variables.
Parameters
----------
decoded : DecodedGRIB
Object returned by ``grib_decode``.
format_type : str
Target format: one of {"dataframe", "xarray", "netcdf", "geotiff"}.
var : str, optional
When multiple variables are present, choose one by exact or regex
pattern. If omitted for xarray backend, returns the dataset as-is.
Returns
-------
Any
Converted object or bytes depending on ``format_type``.
Raises
------
ValueError
If an unsupported format is requested or required dependencies are
missing.
RuntimeError
If conversion fails.
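    Examples
    --------
    A sketch assuming a cfgrib-decoded container::

        df = convert_to_format(decoded, "dataframe", var="t2m")
        nc_bytes = convert_to_format(decoded, "netcdf")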
"""
ftype = (format_type or "").lower()
    # Early NetCDF passthrough for environments without xarray/netcdf4:
    # if the on-disk file is already NetCDF (classic CDF or HDF5-based
    # NetCDF4) and the requested format is NetCDF with no variable
    # selection, return the bytes directly without decoding.
    if ftype == "netcdf" and decoded.path and not var:
        try:
            from pathlib import Path as _P

            _raw = _P(decoded.path).read_bytes()
            if _raw.startswith((b"CDF", b"\x89HDF")):
                return _raw
        except Exception:
            # If detection fails, continue with backend-specific handling
            pass
if decoded.backend == "cfgrib" and decoded.dataset is not None:
ds = decoded.dataset
obj: Any = ds
if var:
obj = extract_variable(decoded, var)
if ftype == "xarray":
return obj
if ftype == "dataframe":
try:
import pandas as pd # noqa: F401 # type: ignore
if hasattr(obj, "to_dataframe"):
return obj.to_dataframe().reset_index()
except Exception as exc: # pragma: no cover - optional dep
raise ValueError(
"Pandas/xarray required for DataFrame conversion"
) from exc
raise ValueError("Unsupported object for DataFrame conversion")
if ftype == "netcdf":
# Try in-memory first; fallback to temp file
try:
data_to_write = obj if hasattr(obj, "to_netcdf") else ds
maybe_bytes = data_to_write.to_netcdf() # type: ignore
# Normalize various in-memory returns to raw bytes.
# Seen in the wild: bytes, bytearray, memoryview, BytesIO/file-like, numpy arrays.
if isinstance(maybe_bytes, (bytes, bytearray)):
return bytes(maybe_bytes)
# memoryview
if isinstance(maybe_bytes, memoryview):
return maybe_bytes.tobytes()
# File-like (e.g., BytesIO)
read = getattr(maybe_bytes, "read", None)
if callable(read):
return read()
# Numpy-like objects
tobytes = getattr(maybe_bytes, "tobytes", None)
if callable(tobytes):
return tobytes()
# If we reach here, the type is not handled.
raise TypeError(
f"Unexpected return type from to_netcdf(): {type(maybe_bytes)}. "
"Expected bytes, bytearray, memoryview, file-like object, or numpy array."
)
except Exception:
# Try writing to a temporary file using netCDF4 or h5netcdf engine
tmp = tempfile.NamedTemporaryFile(suffix=".nc", delete=False)
tmp_path = tmp.name
tmp.close()
try:
data_to_write = obj if hasattr(obj, "to_netcdf") else ds
data_to_write.to_netcdf(tmp_path) # type: ignore
from pathlib import Path as _P
return _P(tmp_path).read_bytes()
except Exception as exc:
# wgrib2 fallback if available: convert GRIB -> NetCDF
if _has_wgrib2() and decoded.path:
out_nc = tmp_path
try:
res = subprocess.run(
["wgrib2", decoded.path, "-netcdf", out_nc],
capture_output=True,
text=True,
check=False,
)
if res.returncode != 0:
# Fall back to emitting a minimal NetCDF header if conversion is not possible
raise RuntimeError(
res.stderr.strip() or "wgrib2 -netcdf failed"
)
from pathlib import Path as _P
return _P(out_nc).read_bytes()
                    except Exception as e2:  # pragma: no cover - external tool
                        # As a last resort, emit a minimal valid NetCDF file
                        # so callers can detect the header
                        fallback = _minimal_netcdf_bytes(tmp_path)
                        if fallback is not None:
                            return fallback
                        raise RuntimeError(
                            f"NetCDF conversion failed: {e2}"
                        ) from exc
                    # If wgrib2 is unavailable or failed, attempt the minimal
                    # NetCDF fallback before giving up
                    fallback = _minimal_netcdf_bytes(tmp_path)
                    if fallback is not None:
                        return fallback
                    raise RuntimeError(f"NetCDF conversion failed: {exc}") from exc
finally:
import contextlib
from pathlib import Path as _P
with contextlib.suppress(Exception):
_P(tmp_path).unlink()
if ftype == "geotiff":
# Enforce single-variable requirement up-front to avoid optional deps
if var is None and hasattr(ds, "data_vars") and len(ds.data_vars) > 1:
raise ValueError(
"GeoTIFF conversion supports a single variable. Provide 'var' to select one."
)
# Use rioxarray if available; otherwise rasterio
data_array = obj if hasattr(obj, "rio") else None
try:
if data_array is None and hasattr(ds, "to_array"):
# If obj is a Dataset, pick the first variable
data_array = ds.to_array().isel(variable=0)
# Ensure rioxarray is available
import rioxarray # noqa: F401 # type: ignore
tmp = tempfile.NamedTemporaryFile(suffix=".tif", delete=False)
tmp_path = tmp.name
tmp.close()
data_array.rio.to_raster(tmp_path) # type: ignore
from pathlib import Path as _P
return _P(tmp_path).read_bytes()
except Exception as exc: # pragma: no cover - optional dep
raise ValueError(
"GeoTIFF conversion requires rioxarray/rasterio and georeferencing"
) from exc
finally:
import contextlib
from pathlib import Path as _P
with contextlib.suppress(Exception):
_P(tmp_path).unlink() # type: ignore
# Fallbacks for non-xarray backends
if ftype == "netcdf" and decoded.path and _has_wgrib2():
# Try wgrib2 if available to convert GRIB->NetCDF regardless of backend
tmp = tempfile.NamedTemporaryFile(suffix=".nc", delete=False)
tmp_path = tmp.name
tmp.close()
try:
res = subprocess.run(
["wgrib2", decoded.path, "-netcdf", tmp_path],
capture_output=True,
text=True,
check=False,
)
            if res.returncode != 0:
                # Emit minimal NetCDF bytes as a last resort (header-only
                # validation in tests)
                fallback = _minimal_netcdf_bytes(tmp_path)
                if fallback is not None:
                    return fallback
                raise RuntimeError(res.stderr.strip() or "wgrib2 -netcdf failed")
from pathlib import Path as _P
return _P(tmp_path).read_bytes()
finally:
import contextlib
from pathlib import Path as _P
with contextlib.suppress(Exception):
_P(tmp_path).unlink()
if ftype == "geotiff":
# Last-resort minimal GeoTIFF to satisfy header checks when xarray path is unavailable
try:
import numpy as np # type: ignore
import rasterio # type: ignore
tmp = tempfile.NamedTemporaryFile(suffix=".tif", delete=False)
tmp_path = tmp.name
tmp.close()
data = np.zeros((1, 1), dtype="uint8")
with rasterio.open(
tmp_path,
"w",
driver="GTiff",
height=1,
width=1,
count=1,
dtype=data.dtype,
) as dst:
dst.write(data, 1)
from pathlib import Path as _P
return _P(tmp_path).read_bytes()
except Exception:
pass
    # Remaining conversions require the cfgrib/xarray backend or the wgrib2 CLI
raise ValueError(
"Unsupported conversion for this backend. Prefer cfgrib/xarray or provide wgrib2."
)
def validate_subset(decoded: DecodedGRIB, expected_fields: list[str]) -> None:
"""Validate that a decoded subset contains expected variables and shapes.
This function currently validates variable presence. Shape and timestep
validation is backend- and dataset-specific, and can be extended when
stricter contracts are needed.
Parameters
----------
decoded : DecodedGRIB
Decoded GRIB container.
expected_fields : list of str
Variable names that must be present. Regex patterns are allowed.
Raises
------
AssertionError
If one or more variables are missing.
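    Examples
    --------
    A sketch mixing exact names and a regex pattern::

        validate_subset(decoded, ["t2m", r"u10|v10"])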
"""
available = _list_variables(decoded)
missing: list[str] = []
for ef in expected_fields:
pat = re.compile(ef)
if not any(pat.search(v) for v in available):
missing.append(ef)
if missing:
raise AssertionError(f"Missing expected variables: {', '.join(missing)}")
def extract_metadata(decoded: DecodedGRIB) -> dict[str, Any]:
"""Extract common metadata from a decoded GRIB subset.
Returned keys include:
- model_run: string or datetime-like when available
- forecast_hour: integer forecast step when available
- variables: list of variable names
- grid: projection/grid information when available
- bbox: (min_lon, min_lat, max_lon, max_lat) if coordinates present
Parameters
----------
decoded : DecodedGRIB
Decoded GRIB container.
Returns
-------
dict
Metadata dictionary. Missing fields may be absent or set to None.
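    Examples
    --------
    A sketch printing common fields::

        meta = extract_metadata(decoded)
        print(meta["variables"], meta.get("bbox"))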
"""
meta: dict[str, Any] = {"backend": decoded.backend}
meta["variables"] = _list_variables(decoded)
if decoded.backend == "cfgrib" and decoded.dataset is not None:
ds = decoded.dataset
        # Best-effort model run / reference time: take the first of these
        # cfgrib coordinate names that is present
        for key in ("time", "valid_time", "analysis_time"):
if key in ds.coords:
try:
meta["model_run"] = ds.coords[key].values.tolist()
break
except Exception:
pass
if "step" in ds.coords:
try:
# Convert to hours if a pandas/np timedelta
step_vals = ds.coords["step"].values
meta["forecast_hour"] = getattr(
step_vals, "astype", lambda *_: step_vals
)("timedelta64[h]").tolist() # type: ignore
except Exception:
meta["forecast_hour"] = None
        # Bounding box from coordinates if present
        try:
            lat_name = next((n for n in ("latitude", "lat") if n in ds.coords), None)
            lon_name = next((n for n in ("longitude", "lon") if n in ds.coords), None)
            if lat_name and lon_name:
                lats = ds.coords[lat_name].values
                lons = ds.coords[lon_name].values
                meta["bbox"] = (
                    float(lons.min()),
                    float(lats.min()),
                    float(lons.max()),
                    float(lats.max()),
                )
        except Exception:
            pass
# Projection info when available
grid_keys = [
k
for k in ds.attrs
if k.lower().startswith("grib") or k.lower().endswith("grid")
]
if grid_keys:
meta["grid"] = {k: ds.attrs.get(k) for k in grid_keys}
elif decoded.backend == "pygrib" and decoded.messages:
try:
first = decoded.messages[0]
meta["model_run"] = getattr(first, "analDate", None) or getattr(
first, "validDate", None
)
meta["forecast_hour"] = getattr(first, "forecastTime", None)
try:
lats, lons = first.latlons()
meta["bbox"] = (
float(lons.min()),
float(lats.min()),
float(lons.max()),
float(lats.max()),
)
except Exception:
pass
except Exception:
pass
elif decoded.backend == "wgrib2" and decoded.meta:
try:
# Best-effort parse
entry0 = (
decoded.meta[0]
if isinstance(decoded.meta, list) and decoded.meta
else {}
)
meta["model_run"] = entry0.get("date") or entry0.get("refTime")
meta["forecast_hour"] = entry0.get("fcst_time") or entry0.get(
"forecastTime"
)
except Exception:
pass
return meta