Source code for homodyne.optimization.checkpoint_manager

"""Checkpoint management for streaming optimization.

This module provides checkpoint save/load functionality for fault-tolerant
streaming optimization. Checkpoints are stored in HDF5 format with compression
and checksum validation.

Key Features:
- HDF5-based checkpoint storage with compression
- Checksum validation for integrity
- Automatic cleanup of old checkpoints
- Version compatibility checking
- Fast save time (< 2 seconds target)

The CheckpointManager complements NLSQ's built-in checkpointing by storing
homodyne-specific state (batch statistics, recovery actions, best parameters).
"""

import hashlib
import os
import time
from pathlib import Path

import h5py
import numpy as np

from homodyne._version import __version__
from homodyne.optimization.exceptions import NLSQCheckpointError
from homodyne.utils.logging import get_logger

# T059: Module-level logger for checkpoint management.
# Created once when the module is imported (via the project's get_logger
# helper, keyed on this module's name) and shared by every CheckpointManager
# method below for debug/info/warning messages.
logger = get_logger(__name__)


[docs] class CheckpointManager: """Manage checkpoint save/load for streaming optimization. This class provides checkpoint management for homodyne-specific state during streaming optimization. It complements NLSQ's built-in checkpoint functionality by storing additional metadata, batch statistics, and recovery action history. Features: - HDF5-based checkpoint storage with compression - Checksum validation for integrity - Automatic cleanup of old checkpoints - Version compatibility checking Attributes ---------- checkpoint_dir : Path Directory for checkpoint files checkpoint_frequency : int Save checkpoint every N batches keep_last_n : int Keep last N checkpoints (default: 3) enable_compression : bool Use HDF5 compression (default: True) Examples -------- >>> manager = CheckpointManager("./checkpoints", checkpoint_frequency=10) >>> # Save checkpoint >>> path = manager.save_checkpoint( ... batch_idx=10, ... parameters=params, ... optimizer_state={'iteration': 42}, ... loss=0.123, ... ) >>> # Load checkpoint >>> data = manager.load_checkpoint(path) >>> params = data['parameters'] >>> batch_idx = data['batch_idx'] """
[docs] def __init__( self, checkpoint_dir: str | Path, checkpoint_frequency: int = 10, keep_last_n: int = 3, enable_compression: bool = True, ): """Initialize checkpoint manager. Parameters ---------- checkpoint_dir : str or Path Directory for checkpoint files checkpoint_frequency : int, optional Save checkpoint every N batches, by default 10 keep_last_n : int, optional Keep last N checkpoints, by default 3 enable_compression : bool, optional Use HDF5 compression, by default True """ self.checkpoint_dir = Path(checkpoint_dir) self.checkpoint_frequency = checkpoint_frequency self.keep_last_n = keep_last_n self.enable_compression = enable_compression # Create checkpoint directory if it doesn't exist try: self.checkpoint_dir.mkdir(parents=True, exist_ok=True) except OSError as e: raise NLSQCheckpointError( f"Cannot create checkpoint directory '{self.checkpoint_dir}': {e}", error_context={"checkpoint_dir": str(self.checkpoint_dir)}, ) from e # T059: Log checkpoint manager initialization logger.debug( f"CheckpointManager initialized: dir={self.checkpoint_dir}, " f"frequency={checkpoint_frequency}, keep_last={keep_last_n}" )
[docs] def save_checkpoint( self, batch_idx: int, parameters: np.ndarray, optimizer_state: dict, loss: float, metadata: dict | None = None, ) -> Path: """Save checkpoint to HDF5 file. Saves checkpoint with compression, checksum validation, and version information. Target save time is < 2 seconds for typical parameter sets. Parameters ---------- batch_idx : int Current batch index parameters : np.ndarray Current parameter values optimizer_state : dict Optimizer internal state loss : float Current loss value metadata : dict, optional Additional metadata (batch statistics, recovery actions, etc.) Returns ------- Path Path to saved checkpoint file Raises ------ NLSQCheckpointError If checkpoint save fails Notes ----- Checkpoint file naming: `homodyne_state_batch_{batch_idx:04d}.h5` """ start_time = time.time() # Generate checkpoint filename checkpoint_path = ( self.checkpoint_dir / f"homodyne_state_batch_{batch_idx:04d}.h5" ) try: # Serialize optimizer state to bytes for checksum (trusted checkpoint) import pickle # nosec B403 optimizer_bytes = pickle.dumps(optimizer_state) # Compute checksum checksum = self._compute_checksum(optimizer_bytes) # Save to HDF5 with compression with h5py.File(checkpoint_path, "w") as f: # Save parameters if self.enable_compression: f.create_dataset( "parameters", data=parameters, compression="gzip", compression_opts=4, ) else: f.create_dataset("parameters", data=parameters) # Save optimizer state as pickled bytes f.create_dataset("optimizer_state", data=np.void(optimizer_bytes)) # Save scalar attributes f.attrs["batch_idx"] = batch_idx f.attrs["loss"] = loss f.attrs["checksum"] = checksum f.attrs["version"] = __version__ f.attrs["timestamp"] = time.time() # Save metadata if provided if metadata is not None: metadata_group = f.create_group("metadata") for key, value in metadata.items(): if isinstance(value, (int, float, str, bool)): metadata_group.attrs[key] = value elif isinstance(value, (list, dict)): # Store complex types as JSON strings 
import json metadata_group.attrs[key] = json.dumps(value) elif isinstance(value, np.ndarray): metadata_group.create_dataset(key, data=value) os.chmod(checkpoint_path, 0o600) elapsed = time.time() - start_time # T059: Log checkpoint save completion file_size_kb = checkpoint_path.stat().st_size / 1024 logger.info( f"Checkpoint saved: batch={batch_idx}, loss={loss:.6g}, " f"file={checkpoint_path.name} ({file_size_kb:.1f} KB), " f"time={elapsed:.2f}s" ) # Check if save time exceeds target if elapsed > 2.0: logger.warning( f"Checkpoint save took {elapsed:.2f}s (target: < 2s). " f"Consider disabling compression or reducing checkpoint frequency." ) return checkpoint_path except (OSError, ValueError, TypeError, RuntimeError) as e: raise NLSQCheckpointError( f"Failed to save checkpoint at batch {batch_idx}: {e}", error_context={ "batch_idx": batch_idx, "checkpoint_path": str(checkpoint_path), }, ) from e
[docs] def load_checkpoint(self, checkpoint_path: Path) -> dict: """Load and validate checkpoint. Loads checkpoint from HDF5 file and validates checksum integrity. Security: Uses pickle.loads() for optimizer state deserialization. This is safe because checkpoint files are created exclusively by save_checkpoint() with checksum validation, stored in application-controlled output directories, and the serialized bytes are embedded within HDF5 containers created by this class. Parameters ---------- checkpoint_path : Path Path to checkpoint file Returns ------- dict Checkpoint data with keys: - batch_idx: int - Batch index when checkpoint was saved - parameters: np.ndarray - Parameter values - optimizer_state: dict - Optimizer internal state - loss: float - Loss value at checkpoint - metadata: dict - Additional metadata (if available) - version: str - Homodyne version - timestamp: float - Unix timestamp Raises ------ NLSQCheckpointError If checkpoint is corrupted, invalid, or missing """ if not checkpoint_path.exists(): raise NLSQCheckpointError( f"Checkpoint file not found: {checkpoint_path}", error_context={"checkpoint_path": str(checkpoint_path)}, ) try: with h5py.File(checkpoint_path, "r") as f: # Load required data parameters = f["parameters"][:] optimizer_bytes = bytes(f["optimizer_state"][()]) # Load attributes batch_idx = f.attrs["batch_idx"] loss = f.attrs["loss"] stored_checksum = f.attrs["checksum"] version = f.attrs.get("version", "unknown") timestamp = f.attrs.get("timestamp", 0.0) # Version compatibility check if version != __version__: logger.warning( f"Checkpoint version {version} != current {__version__}. " "Optimizer state may be incompatible." ) # Validate checksum computed_checksum = self._compute_checksum(optimizer_bytes) if computed_checksum != stored_checksum: raise NLSQCheckpointError( "Checkpoint checksum mismatch. 
File may be corrupted.", error_context={ "checkpoint_path": str(checkpoint_path), "stored_checksum": stored_checksum, "computed_checksum": computed_checksum, }, ) # Deserialize optimizer state import pickle # nosec B403: internal checkpoint serialization optimizer_state = pickle.loads(optimizer_bytes) # nosec B301 # Load metadata if available metadata = {} if "metadata" in f: metadata_group = f["metadata"] for key in metadata_group.attrs: value = metadata_group.attrs[key] # Try to parse JSON strings if isinstance(value, str): try: import json metadata[key] = json.loads(value) except (json.JSONDecodeError, TypeError): metadata[key] = value else: metadata[key] = value # Load array datasets for key in metadata_group.keys(): metadata[key] = metadata_group[key][:] # T059: Log checkpoint load completion logger.info( f"Checkpoint loaded: batch={batch_idx}, loss={loss:.6g}, " f"version={version}, file={checkpoint_path.name}" ) return { "batch_idx": int(batch_idx), "parameters": parameters, "optimizer_state": optimizer_state, "loss": float(loss), "metadata": metadata, "version": version, "timestamp": float(timestamp), } except (OSError, KeyError, ValueError) as e: raise NLSQCheckpointError( f"Failed to load checkpoint: {e}", error_context={"checkpoint_path": str(checkpoint_path)}, ) from e
[docs] def find_latest_checkpoint(self) -> Path | None: """Find most recent valid checkpoint. Searches checkpoint directory for valid checkpoint files and returns the one with the highest batch index. Returns ------- Path or None Path to latest checkpoint, or None if none exist Notes ----- Only returns checkpoints that pass validation. """ # Find all checkpoint files checkpoint_files = list(self.checkpoint_dir.glob("homodyne_state_batch_*.h5")) if not checkpoint_files: return None # Sort by batch index (extracted from filename) def get_batch_idx(path: Path) -> int: try: # Extract batch index from filename: homodyne_state_batch_0010.h5 return int(path.stem.split("_")[-1]) except (ValueError, IndexError): return -1 checkpoint_files.sort(key=get_batch_idx, reverse=True) # Find first valid checkpoint for checkpoint_path in checkpoint_files: if self.validate_checkpoint(checkpoint_path): # T060: Log recovery point found logger.info(f"Found valid recovery checkpoint: {checkpoint_path.name}") return checkpoint_path logger.debug("No valid checkpoint found for recovery") return None
[docs] def cleanup_old_checkpoints(self) -> list[Path]: """Remove old checkpoints, keeping last N. Keeps the most recent N checkpoints based on batch index and removes older ones to manage disk space. Returns ------- list of Path Paths of deleted checkpoints Notes ----- Only deletes checkpoints, never removes the keep_last_n most recent ones. """ # Find all checkpoint files checkpoint_files = list(self.checkpoint_dir.glob("homodyne_state_batch_*.h5")) if len(checkpoint_files) <= self.keep_last_n: return [] # Nothing to clean up # Sort by batch index def get_batch_idx(path: Path) -> int: try: return int(path.stem.split("_")[-1]) except (ValueError, IndexError): return -1 checkpoint_files.sort(key=get_batch_idx, reverse=True) # Keep last N, delete the rest to_delete = checkpoint_files[self.keep_last_n :] deleted = [] for checkpoint_path in to_delete: try: checkpoint_path.unlink() deleted.append(checkpoint_path) except OSError: # Log warning but continue logger.warning(f"Failed to delete checkpoint: {checkpoint_path}") # T059: Log cleanup results if deleted: logger.debug( f"Checkpoint cleanup: deleted {len(deleted)} old checkpoints, " f"kept {self.keep_last_n}" ) return deleted
[docs] def validate_checkpoint(self, checkpoint_path: Path) -> bool: """Validate checkpoint integrity. Checks that checkpoint file exists, can be opened, has required fields, and passes checksum validation. Parameters ---------- checkpoint_path : Path Path to checkpoint file Returns ------- bool True if valid, False otherwise """ if not checkpoint_path.exists(): return False try: with h5py.File(checkpoint_path, "r") as f: # Check required fields required_fields = ["parameters", "optimizer_state"] required_attrs = ["batch_idx", "loss", "checksum"] for field in required_fields: if field not in f: return False for attr in required_attrs: if attr not in f.attrs: return False # Validate checksum optimizer_bytes = bytes(f["optimizer_state"][()]) stored_checksum = f.attrs["checksum"] computed_checksum = self._compute_checksum(optimizer_bytes) if computed_checksum != stored_checksum: return False return True except (OSError, KeyError, ValueError): return False
def _compute_checksum(self, data: bytes) -> str: """Compute SHA256 checksum of data. Parameters ---------- data : bytes Data to checksum Returns ------- str Hexadecimal checksum string """ return hashlib.sha256(data).hexdigest()