Source code for homodyne.utils.async_io
"""Async I/O utilities for pipeline overlap.
Thread-based prefetching and background writing to hide I/O latency.
GIL-safe since HDF5 and numpy release the GIL during I/O.
"""
from __future__ import annotations

import json
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor, wait
from pathlib import Path
from threading import Lock, Thread
from typing import Any, TypeVar

import numpy as np

from homodyne.utils.logging import get_logger
# Module-level logger; AsyncWriter uses it to report failed, slow, and
# finished background writes.
logger = get_logger(__name__)
# T: type of the raw items drawn from the source; R: type produced by load_fn.
T = TypeVar("T")
R = TypeVar("R")


class PrefetchLoader(Iterator[R]):
    """Thread-based prefetch iterator.

    Loads the next item in a background thread while the current
    item is being processed. At most one item is in flight at a time.

    Parameters
    ----------
    source : Iterable[T]
        Source items to load. Any iterable is accepted; it is passed
        through ``iter()`` (an iterator is used as-is).
    load_fn : callable
        Transform applied to each item in background thread.
    timeout : float or None, optional
        Seconds ``__next__`` waits for the background load before raising
        ``RuntimeError``. ``None`` waits indefinitely. Default 120.0.

    Notes
    -----
    Iteration finishes for good once the source is exhausted, a load
    raises, or a load times out: the error (if any) is raised once, and
    every later ``next()`` raises ``StopIteration``.
    """

    def __init__(
        self,
        source: Iterable[T],
        load_fn: Callable[[T], R],
        *,
        timeout: float | None = 120.0,
    ) -> None:
        # iter() is a no-op for iterators and also admits lists, tuples, ...
        self._source = iter(source)
        self._load_fn = load_fn
        self._timeout = timeout
        # Set once no further loads will be started.
        self._exhausted = False
        # The in-flight background load and the private slot it reports into.
        self._thread: Thread | None = None
        self._slot: dict[str, Any] = {}
        # Start loading the first item right away.
        self._start_prefetch()

    def _start_prefetch(self) -> None:
        """Start loading the next source item in a background thread."""
        if self._exhausted:
            return
        # Each load writes only into its own slot, never into ``self``: a
        # thread abandoned after a timeout can then finish later without
        # corrupting the state of this (already finished) iterator.
        slot: dict[str, Any] = {}

        def _load() -> None:
            try:
                item = next(self._source)
                slot["value"] = self._load_fn(item)
            except StopIteration:
                slot["done"] = True
            except Exception as e:
                slot["error"] = e

        self._slot = slot
        # daemon=True: prefetch is read-only; safe to abandon on exit
        self._thread = Thread(target=_load, daemon=True)
        self._thread.start()

    def __iter__(self) -> PrefetchLoader[R]:
        return self

    def __next__(self) -> R:
        """Return the prefetched item and start loading the following one.

        Raises
        ------
        StopIteration
            When the source is exhausted or iteration already finished.
        RuntimeError
            If the background load exceeds ``timeout`` or its thread exits
            without reporting a result.
        Exception
            Any exception raised by the source or ``load_fn`` (raised once).
        """
        thread, slot = self._thread, self._slot
        if thread is None:
            # Nothing in flight: iteration has already finished.
            raise StopIteration
        thread.join(timeout=self._timeout)
        # Whatever happens below, this load has now been consumed.
        self._thread = None
        self._slot = {}
        if thread.is_alive():
            # The daemon thread is abandoned; its slot was just dropped, so
            # a late result cannot surface from a later next() call.
            self._exhausted = True
            raise RuntimeError(
                f"Prefetch thread did not complete within {self._timeout:g}s timeout"
            )
        if "error" in slot:
            self._exhausted = True
            raise slot["error"]
        if "value" in slot:
            self._start_prefetch()
            return slot["value"]
        self._exhausted = True
        if slot.get("done"):
            raise StopIteration
        # _load catches only Exception, so an empty slot means the thread
        # was ended by a BaseException such as SystemExit.
        raise RuntimeError("Prefetch thread exited without producing a result")
[docs]
class AsyncWriter:
    """Background thread pool for result serialization.

    Writes are queued with :meth:`submit_npz`, :meth:`submit_json` or
    :meth:`submit_task` and executed on a ``ThreadPoolExecutor``.
    :meth:`wait_all` collects failures; :meth:`shutdown` (also run on
    leaving a ``with`` block) drains the pool.

    Parameters
    ----------
    max_workers : int
        Maximum concurrent write threads.

    Notes
    -----
    Submitted data is not copied. Do not mutate arrays or dicts passed to
    ``submit_*`` until the write has completed.
    """

    def __init__(self, max_workers: int = 2) -> None:
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        # Futures whose outcome has not been reported by wait_all() yet.
        # Guarded by _lock so submits can race with wait_all().
        self._futures: list[Future[None]] = []
        self._lock = Lock()
        self._shutdown = False

    def _track(self, future: Future[None]) -> None:
        """Register *future* so wait_all()/shutdown() report its outcome."""
        with self._lock:
            self._futures.append(future)

    def submit_npz(self, path: Path, data: dict[str, np.ndarray]) -> None:
        """Write NPZ file in background.

        Missing parent directories are created. A failure is logged here
        and returned again by :meth:`wait_all`.
        """

        def _write() -> None:
            try:
                self._write_npz(path, data)
            except Exception:
                logger.error("Failed to write NPZ: %s", path)
                raise

        self._track(self._executor.submit(_write))

    def submit_json(self, path: Path, data: dict[str, Any]) -> None:
        """Write JSON file in background.

        Missing parent directories are created. A failure is logged here
        and returned again by :meth:`wait_all`.
        """

        def _write() -> None:
            try:
                self._write_json(path, data)
            except Exception:
                logger.error("Failed to write JSON: %s", path)
                raise

        self._track(self._executor.submit(_write))

    def submit_task(self, fn: Callable[..., None], *args: Any, **kwargs: Any) -> None:
        """Submit an arbitrary callable for background execution."""
        self._track(self._executor.submit(fn, *args, **kwargs))

    def wait_all(self, timeout: float | None = 60.0) -> list[Exception]:
        """Wait for all pending writes. Returns list of errors.

        ``timeout`` bounds the *total* wait (``None`` waits indefinitely).
        A write still running when it expires is not treated as a failure:
        it stays tracked, and its outcome is reported by a later
        ``wait_all()`` call or by ``shutdown()``.

        Returns
        -------
        list[Exception]
            Exceptions raised by writes that finished, in submission order.
        """
        with self._lock:
            pending = list(self._futures)
        # One overall deadline instead of ``timeout`` per future, so N slow
        # writes cannot block the caller for N * timeout. Waiting via
        # concurrent.futures.wait also avoids catching TimeoutError, which
        # on Python < 3.11 is not the exception futures raise.
        done, _still_running = wait(pending, timeout=timeout)
        errors: list[Exception] = []
        for future in pending:
            if future not in done:
                logger.info(
                    "Background write still in progress after %.0fs "
                    "(will complete during shutdown)",
                    timeout,
                )
                continue
            try:
                future.result()  # already finished: returns or raises at once
            except Exception as e:
                logger.warning("Background write failed (%s): %s", type(e).__name__, e)
                logger.debug("Background write traceback:", exc_info=True)
                errors.append(e)
        # Forget only futures whose outcome was just reported; unfinished
        # writes and concurrent submits stay tracked.
        with self._lock:
            self._futures = [f for f in self._futures if f not in done]
        return errors

    def shutdown(self) -> None:
        """Wait for pending writes and shut down. Idempotent.

        Failures are logged, not raised. Writes that outlast the bounded
        first wait are still awaited and then reported.
        """
        if self._shutdown:
            return
        self._shutdown = True
        self.wait_all()
        # Blocks until every submitted task has finished, including those
        # that outlived the bounded wait above ...
        self._executor.shutdown(wait=True)
        # ... so this final pass reports their outcome without waiting.
        self.wait_all()

    @staticmethod
    def _write_npz(path: Path, data: dict[str, np.ndarray]) -> None:
        """Write *data* as a compressed NPZ archive at *path*.

        NumPy appends ``.npz`` if *path* does not already end with it.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(str(path), **data)  # type: ignore[arg-type]

    @staticmethod
    def _write_json(path: Path, data: dict[str, Any]) -> None:
        """Write *data* as indented UTF-8 JSON at *path*.

        ``default=str`` makes values that ``json`` cannot encode natively
        (e.g. numpy arrays, ``Path`` objects) fall back to their ``str()``.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)

    def __enter__(self) -> AsyncWriter:
        return self

    def __exit__(self, *exc: object) -> None:
        self.shutdown()