# Source code for zyra.utils.grib

# SPDX-License-Identifier: Apache-2.0
"""GRIB utilities used by connectors and managers.

This module centralizes protocol-agnostic helpers for working with GRIB2
index files (.idx), calculating byte ranges, and performing parallel
multi-range downloads.

Notes
-----
- The `.idx` file path is assumed to be the GRIB file path with a `.idx`
  suffix appended, unless a path already ending in `.idx` is provided.
- Pattern filtering uses regular expressions via :func:`re.search`.
"""

from __future__ import annotations

import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Callable, Iterable


def ensure_idx_path(path: str) -> str:
    """Return the `.idx` path for a GRIB file or pass through an explicit idx path.

    Parameters
    ----------
    path : str
        The GRIB file path or `.idx` path.

    Returns
    -------
    str
        ``path + '.idx'`` if ``path`` does not already end with ``.idx``,
        otherwise returns ``path`` unchanged.
    """
    # Explicit idx paths pass through untouched; everything else gets the suffix.
    if path.endswith(".idx"):
        return path
    return f"{path}.idx"
[docs] def parse_idx_lines(idx_bytes_or_text: bytes | str) -> list[str]: """Parse a GRIB index payload into non-empty lines. Parameters ---------- idx_bytes_or_text : bytes or str Raw `.idx` file content. Returns ------- list of str The non-empty, newline-split lines of the index. """ if isinstance(idx_bytes_or_text, (bytes, bytearray)): text = idx_bytes_or_text.decode() else: text = idx_bytes_or_text lines = [ln for ln in text.splitlines() if ln] return lines
def idx_to_byteranges(lines: list[str], search_regex: str) -> dict[str, str]:
    """Convert `.idx` lines plus a variable regex into HTTP Range headers.

    Parameters
    ----------
    lines : list of str
        Lines from a GRIB `.idx` file.
    search_regex : str
        Regular expression to select desired GRIB lines (e.g., "PRES:surface").

    Returns
    -------
    dict
        Mapping of ``{"bytes=start-end": matching_idx_line}`` suitable for
        use as Range headers.
    """
    pattern = re.compile(search_regex)
    ranges: dict[str, str] = {}
    total = len(lines)
    for i, record in enumerate(lines):
        if not pattern.search(record):
            continue
        fields = record.split(":")
        # A well-formed idx record has the byte offset as its second field.
        if len(fields) < 2:
            continue
        start = fields[1]
        # The record ends one byte before the next record's offset; the final
        # record yields an open-ended range ("bytes=start-").
        end = ""
        if i + 1 < total:
            follower = lines[i + 1].split(":")
            if len(follower) > 1:
                try:
                    end = str(int(follower[1]) - 1)
                except ValueError:
                    # Non-numeric offset in the next record: keep it verbatim.
                    end = follower[1]
        ranges[f"bytes={start}-{end}"] = record
    return ranges
def compute_chunks(total_size: int, chunk_size: int = 500 * 1024 * 1024) -> list[str]:
    """Compute contiguous byte ranges that partition a file.

    The final range uses the file size as the inclusive end byte (matching
    the behavior used by ``nodd_fetch.py``).

    Parameters
    ----------
    total_size : int
        Size of the file in bytes.
    chunk_size : int, default 500MB
        Upper bound for each chunk.

    Returns
    -------
    list of str
        Range header strings, e.g., ``["bytes=0-1048575", ...]``.
    """
    if total_size <= 0:
        return []
    # Interior boundaries at chunk, 2*chunk, ... strictly below total_size.
    boundaries = range(chunk_size, total_size, chunk_size)
    out: list[str] = []
    cursor = 0
    for boundary in boundaries:
        out.append(f"bytes={cursor}-{boundary - 1}")
        cursor = boundary
    # Final chunk deliberately uses total_size itself as the end byte.
    out.append(f"bytes={cursor}-{total_size}")
    return out
def parallel_download_byteranges(
    download_func: Callable[[str, str], bytes],
    key_or_url: str,
    byte_ranges: Iterable[str],
    *,
    max_workers: int = 10,
) -> bytes:
    """Download multiple byte ranges in parallel and concatenate in input order.

    Parameters
    ----------
    download_func : Callable
        Function accepting ``(key_or_url, range_header)`` and returning bytes.
    key_or_url : str
        The resource identifier for the remote object.
    byte_ranges : Iterable[str]
        Iterable of Range header strings (e.g., "bytes=0-99"). Order matters
        and is preserved in the output concatenation.
    max_workers : int, default=10
        Maximum number of worker threads.

    Returns
    -------
    bytes
        The concatenated payload of all requested ranges in the input order.
    """
    # Snapshot the iterable with positions so results can be reassembled
    # in the caller's order regardless of completion order.
    ordered = list(enumerate(byte_ranges))
    if not ordered:
        return b""
    chunks: dict[int, bytes] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(download_func, key_or_url, header): position
            for position, header in ordered
        }
        for done in as_completed(pending):
            # Coerce falsy results (e.g., None) to empty bytes.
            chunks[pending[done]] = done.result() or b""
    return b"".join(chunks.get(position, b"") for position, _ in ordered)