Source code for homodyne.optimization.nlsq.gradient_monitor

"""Gradient Collapse Monitor for Anti-Degeneracy Defense.

This module provides runtime detection of gradient collapse (physical params
losing gradient signal) with automatic response actions.

Part of Anti-Degeneracy Defense System v2.9.0.
See: docs/specs/anti-degeneracy-defense-v2.9.0.md

Detection Mechanism::

    Monitor the ratio:
        ratio = norm(grad_physical) / norm(grad_per_angle)

    If ratio < threshold for N consecutive iterations:
        - Gradient collapse detected
        - Physical params are losing signal to per-angle params

Response Actions
----------------
- "warn": Log warning only
- "hierarchical": Switch to hierarchical optimization mode
- "reset": Reset per-angle params to mean values
- "abort": Abort optimization and return best params so far
"""

from __future__ import annotations

import collections
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Literal, cast

import numpy as np

from homodyne.optimization.nlsq.config import safe_float, safe_int
from homodyne.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class GradientMonitorConfig:
    """Configuration for gradient collapse detection.

    Attributes
    ----------
    enable : bool
        Whether to enable gradient monitoring. Default True.
    ratio_threshold : float
        Ratio of norm(grad_physical) / norm(grad_per_angle) below this
        triggers detection. Default 0.01 (physical gradient is 1% of
        per-angle gradient).
    consecutive_triggers : int
        Must trigger N consecutive times to confirm collapse. Default 5.
    response_mode : str
        Response action on collapse detection:

        - "warn": Log warning only
        - "hierarchical": Switch to hierarchical optimization
        - "reset": Reset per-angle params to mean
        - "abort": Abort and return best params
    reset_per_angle_to_mean : bool
        When resetting, reset per-angle to mean values. Default True.
    lambda_multiplier_on_collapse : float
        Multiply regularization λ by this on collapse. Default 10.0.
    check_interval : int
        Check every N iterations. Default 1 (every iteration).
    watch_parameters : list of int or None
        Parameter indices whose individual gradient magnitude is watched.
        Default None (no watched parameters).
    watch_threshold : float
        Gradient magnitude below this counts as a watched-parameter trigger.
        Default 1e-8.
    watch_consecutive_triggers : int
        A watched parameter must trigger N consecutive times. Default 3.
    watch_min_iteration : int
        Watched-parameter checks are skipped before this iteration
        (warmup grace period). Default 5.
    """

    enable: bool = True
    ratio_threshold: float = 0.01
    consecutive_triggers: int = 5
    response_mode: Literal["warn", "hierarchical", "reset", "abort"] = "hierarchical"
    reset_per_angle_to_mean: bool = True
    lambda_multiplier_on_collapse: float = 10.0
    check_interval: int = 1
    # NEW (Dec 2025): Watch specific parameter indices for gradient collapse
    # For laminar_flow: index 2*n_phi + 3 is gamma_dot_t0
    watch_parameters: list[int] | None = None
    watch_threshold: float = 1e-8  # Gradient magnitude below this triggers warning
    watch_consecutive_triggers: int = 3  # Must trigger N consecutive times
    watch_min_iteration: int = 5  # Skip checks before this iteration (warmup)

    @classmethod
    def from_dict(cls, config_dict: dict) -> GradientMonitorConfig:
        """Create config from dictionary with safe type conversion.

        Parameters
        ----------
        config_dict : dict
            Raw configuration mapping. Keys mirror the dataclass field names,
            except that the response action is read from ``"response"``;
            ``"response_mode"`` is accepted as a fallback so a dict keyed by
            field name (e.g. ``dataclasses.asdict`` output) keeps its setting.

        Returns
        -------
        GradientMonitorConfig
            Config with missing entries replaced by the field defaults.

        Notes
        -----
        ``watch_parameters`` may be a list/tuple of ints or a single int; any
        other non-None value is ignored with a warning. An unrecognised
        response mode falls back to ``"hierarchical"`` with a warning, because
        ``typing.cast`` performs no validation at runtime.
        """
        # Parse watch_parameters: sequence of ints, or one bare int.
        watch_params_raw = config_dict.get("watch_parameters")
        watch_parameters: list[int] | None = None
        if isinstance(watch_params_raw, (list, tuple)):
            watch_parameters = [int(x) for x in watch_params_raw]
        elif isinstance(watch_params_raw, int):
            watch_parameters = [int(watch_params_raw)]
        elif watch_params_raw is not None:
            logger.warning(
                "Ignoring watch_parameters of unsupported type "
                f"{type(watch_params_raw).__name__}; expected list of int or int"
            )

        # Resolve the response action. "response" wins for backward
        # compatibility; "response_mode" is the field-name alias.
        valid_modes = ("warn", "hierarchical", "reset", "abort")
        default_mode = "hierarchical"
        mode_raw = config_dict.get("response")
        if mode_raw is None:
            mode_raw = config_dict.get("response_mode")
        if mode_raw is None:
            mode = default_mode
        elif mode_raw in valid_modes:
            mode = mode_raw
        else:
            logger.warning(
                f"Unknown gradient monitor response mode {mode_raw!r}; "
                f"falling back to {default_mode!r}"
            )
            mode = default_mode

        return cls(
            enable=bool(config_dict.get("enable", True)),
            ratio_threshold=safe_float(config_dict.get("ratio_threshold"), 0.01),
            consecutive_triggers=safe_int(config_dict.get("consecutive_triggers"), 5),
            response_mode=cast(
                Literal["warn", "hierarchical", "reset", "abort"],
                mode,
            ),
            reset_per_angle_to_mean=bool(
                config_dict.get("reset_per_angle_to_mean", True)
            ),
            lambda_multiplier_on_collapse=safe_float(
                config_dict.get("lambda_multiplier_on_collapse"), 10.0
            ),
            check_interval=safe_int(config_dict.get("check_interval"), 1),
            watch_parameters=watch_parameters,
            watch_threshold=safe_float(config_dict.get("watch_threshold"), 1e-8),
            watch_consecutive_triggers=safe_int(
                config_dict.get("watch_consecutive_triggers"), 3
            ),
            watch_min_iteration=safe_int(config_dict.get("watch_min_iteration"), 5),
        )
@dataclass
class CollapseEvent:
    """Snapshot of a single detected gradient collapse.

    Attributes
    ----------
    iteration : int
        Iteration at which the collapse was detected.
    ratio : float
        Gradient ratio at the moment of detection.
    physical_grad_norm : float
        Norm of the physical-parameter gradient.
    per_angle_grad_norm : float
        Norm of the per-angle-parameter gradient.
    response_mode : str
        Response action associated with the event.
    """

    iteration: int
    ratio: float
    physical_grad_norm: float
    per_angle_grad_norm: float
    response_mode: str
class GradientCollapseMonitor:
    """Monitor for detecting and responding to gradient collapse.

    This monitor tracks the ratio of physical to per-angle gradient norms
    during optimization. When the ratio drops below a threshold for
    consecutive iterations, it indicates that physical parameters are losing
    gradient signal (being absorbed by per-angle parameters).

    Parameters
    ----------
    config : GradientMonitorConfig
        Monitor configuration.
    physical_indices : list of int
        Indices of physical parameters in the full parameter vector.
    per_angle_indices : list of int
        Indices of per-angle parameters in the full parameter vector.

    Attributes
    ----------
    collapse_detected : bool
        Whether gradient collapse is currently flagged. Cleared again as soon
        as the ratio recovers above the threshold, so later collapses are
        reported too.
    consecutive_count : int
        Current count of consecutive low-ratio iterations.

    Notes
    -----
    History is capped at MAX_HISTORY_SIZE to prevent memory leaks during
    long-running optimizations. Older entries are discarded when the limit
    is reached.

    Examples
    --------
    >>> config = GradientMonitorConfig(ratio_threshold=0.01, consecutive_triggers=5)
    >>> monitor = GradientCollapseMonitor(config, physical_indices=[6,7,8,9,10,11,12],
    ...                                   per_angle_indices=list(range(6)))
    >>> for iteration in range(100):
    ...     gradients = compute_gradients(params)
    ...     status = monitor.check(gradients, iteration)
    ...     if status == "COLLAPSE_DETECTED":
    ...         response = monitor.get_response()
    ...         # Take action based on response
    """

    # Maximum history entries to prevent memory leaks
    # At ~100 bytes per entry, 1000 entries = ~100 KB max
    MAX_HISTORY_SIZE: int = 1000

    def __init__(
        self,
        config: GradientMonitorConfig,
        physical_indices: Sequence[int] | np.ndarray,
        per_angle_indices: Sequence[int] | np.ndarray,
    ):
        """Initialize gradient collapse monitor.

        Parameters
        ----------
        config : GradientMonitorConfig
            Monitor configuration.
        physical_indices : Sequence[int] or np.ndarray
            Indices of physical parameters. Converted to numpy array
            internally to support both NumPy and JAX array indexing.
        per_angle_indices : Sequence[int] or np.ndarray
            Indices of per-angle parameters (or Fourier coefficients when
            Fourier reparameterization is active). Converted to numpy array
            internally.

        Notes
        -----
        When Fourier reparameterization is active, per_angle_indices should
        correspond to Fourier coefficient indices (typically 10 for order=2),
        not independent per-angle indices (2 * n_phi).
        """
        self.config = config
        # Use numpy arrays for indices to support both NumPy and JAX array indexing
        # JAX arrays don't support Python list indexing (non-tuple sequence error)
        self.physical_indices: np.ndarray = np.asarray(physical_indices, dtype=np.intp)
        self.per_angle_indices: np.ndarray = np.asarray(
            per_angle_indices, dtype=np.intp
        )

        # Use a deque with bounded maxlen so that appending automatically
        # drops the oldest entry — O(1) on both ends vs O(n) list.pop(0).
        self.history: collections.deque[dict] = collections.deque(
            maxlen=self.MAX_HISTORY_SIZE
        )
        self.consecutive_count: int = 0
        self.collapse_detected: bool = False
        self.collapse_events: list[CollapseEvent] = []

        # Track best params for recovery
        self.best_params: np.ndarray | None = None
        self.best_loss: float = float("inf")

        # Track consecutive triggers for watched parameters
        self._watch_consecutive_counts: dict[int, int] = {}
        self._watch_collapse_detected: dict[int, bool] = {}
        self._reset_watch_state()

    def _reset_watch_state(self) -> None:
        """Rebuild the per-watched-parameter counters and flags from config."""
        watched = self.config.watch_parameters or ()
        self._watch_consecutive_counts = dict.fromkeys(watched, 0)
        self._watch_collapse_detected = dict.fromkeys(watched, False)

    def check(
        self,
        gradients: np.ndarray,
        iteration: int,
        params: np.ndarray | None = None,
        loss: float | None = None,
    ) -> str:
        """Check for gradient collapse.

        Parameters
        ----------
        gradients : np.ndarray
            Full gradient vector.
        iteration : int
            Current iteration number.
        params : np.ndarray, optional
            Current parameters (for response actions and tracking).
        loss : float, optional
            Current loss value (for tracking best params).

        Returns
        -------
        str
            Status: "OK", "WARNING", "COLLAPSE_DETECTED".
            "COLLAPSE_DETECTED" is returned only on the iteration where a
            collapse is first confirmed; while the ratio stays low afterwards
            the status is "WARNING".
        """
        cfg = self.config
        if not cfg.enable:
            return "OK"

        # Skip if not on check interval. A non-positive interval would make
        # the modulo raise ZeroDivisionError (0) or never match (negative),
        # so it is treated as "check every iteration".
        if iteration % max(1, cfg.check_interval) != 0:
            return "OK"

        # Track best params
        if params is not None and loss is not None and loss < self.best_loss:
            self.best_loss = loss
            self.best_params = params.copy()

        # Compute gradient norms
        physical_grad_norm = np.linalg.norm(gradients[self.physical_indices])
        per_angle_grad_norm = np.linalg.norm(gradients[self.per_angle_indices])

        # Compute ratio (avoid division by zero)
        ratio = physical_grad_norm / (per_angle_grad_norm + 1e-12)

        # Record history. deque(maxlen=MAX_HISTORY_SIZE) drops the oldest
        # entry automatically on append — no manual pop loop needed.
        self.history.append(
            {
                "iteration": iteration,
                "physical_grad_norm": float(physical_grad_norm),
                "per_angle_grad_norm": float(per_angle_grad_norm),
                "ratio": float(ratio),
            }
        )

        # Update the consecutive low-ratio counter
        if ratio < cfg.ratio_threshold:
            self.consecutive_count += 1
        else:
            self.consecutive_count = 0
            # Re-arm detection after recovery so future collapses are tracked
            self.collapse_detected = False

        # Trigger collapse detection (re-arms after recovery)
        newly_detected = (
            self.consecutive_count >= cfg.consecutive_triggers
            and not self.collapse_detected
        )
        if newly_detected:
            self.collapse_detected = True
            self.collapse_events.append(
                CollapseEvent(
                    iteration=iteration,
                    ratio=float(ratio),
                    physical_grad_norm=float(physical_grad_norm),
                    per_angle_grad_norm=float(per_angle_grad_norm),
                    response_mode=cfg.response_mode,
                )
            )
            logger.warning(
                f"GRADIENT COLLAPSE DETECTED at iteration {iteration}! "
                f"ratio={ratio:.6f} < threshold={cfg.ratio_threshold}"
            )
            logger.warning(f" Physical gradient norm: {physical_grad_norm:.6e}")
            logger.warning(f" Per-angle gradient norm: {per_angle_grad_norm:.6e}")
            logger.warning(f" Response mode: {cfg.response_mode}")

        # Watched-parameter bookkeeping runs on every checked iteration,
        # including the one that confirms a ratio collapse; otherwise a
        # recovery on that iteration would not reset the consecutive counter.
        self._check_watched_parameters(gradients, iteration)

        if newly_detected:
            return "COLLAPSE_DETECTED"
        if self.consecutive_count > 0:
            return "WARNING"
        return "OK"

    def _check_watched_parameters(self, gradients: np.ndarray, iteration: int) -> None:
        """Update consecutive low-gradient counters for watched parameters.

        NEW (Dec 2025): specifically monitors parameters like gamma_dot_t0
        that can collapse to zero during L-BFGS warmup when data is
        angle-sequential. Uses a consecutive trigger mechanism to avoid false
        positives during warmup. Must be called after the current iteration's
        entry has been appended to ``self.history``.
        """
        cfg = self.config
        if cfg.watch_parameters is None:
            return
        # Skip checks before minimum iteration (warmup grace period)
        if iteration < cfg.watch_min_iteration:
            return

        entry = self.history[-1]
        for param_idx in cfg.watch_parameters:
            if param_idx >= len(gradients):
                continue

            grad_mag = abs(float(gradients[param_idx]))
            # Store in history for diagnostics
            entry[f"watched_param_{param_idx}_grad"] = grad_mag

            is_low = grad_mag < cfg.watch_threshold
            if is_low:
                # .get tolerates indices added to the config after __init__
                count = self._watch_consecutive_counts.get(param_idx, 0) + 1
            else:
                # Reset consecutive count when gradient recovers
                count = 0
                self._watch_collapse_detected[param_idx] = False
            self._watch_consecutive_counts[param_idx] = count

            # Check for collapse (consecutive triggers threshold)
            already_flagged = self._watch_collapse_detected.get(param_idx, False)
            if count >= cfg.watch_consecutive_triggers and not already_flagged:
                self._watch_collapse_detected[param_idx] = True
                logger.warning(
                    f"WATCHED PARAMETER GRADIENT COLLAPSE CONFIRMED at iteration {iteration}! "
                    f"param[{param_idx}] gradient={grad_mag:.2e} < "
                    f"threshold={cfg.watch_threshold:.2e} "
                    f"for {count} consecutive iterations"
                )
            elif is_low and count == 1:
                # Log debug info on first trigger (not yet confirmed)
                logger.debug(
                    f"Watched parameter gradient low at iteration {iteration}: "
                    f"param[{param_idx}] gradient={grad_mag:.2e}"
                )

    def get_response(self) -> dict | None:
        """Get response action after collapse detection.

        Returns
        -------
        dict or None
            Response action dictionary, or None if no collapse is currently
            flagged. The dictionary carries the configured mode, reset flag
            and lambda multiplier, the best params/loss seen so far, the last
            10 history entries and the list of recorded collapse events.
        """
        if not self.collapse_detected:
            return None

        return {
            "mode": self.config.response_mode,
            "reset_per_angle": self.config.reset_per_angle_to_mean,
            "lambda_multiplier": self.config.lambda_multiplier_on_collapse,
            "best_params": self.best_params,
            "best_loss": self.best_loss,
            "history": list(self.history)[-10:],  # Last 10 entries
            "collapse_events": self.collapse_events,
        }

    def compute_reset_params(self, params: np.ndarray, n_phi: int) -> np.ndarray:
        """Compute parameters with per-angle values reset to mean.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector. Not modified.
        n_phi : int
            Number of phi angles.

        Returns
        -------
        np.ndarray
            Copy of ``params`` with the contrast block and the offset block
            each replaced by its NaN-ignoring mean. If the monitor's
            per-angle indices cannot hold ``2 * n_phi`` entries (or ``n_phi``
            is not positive) the copy is returned unchanged.
        """
        reset_params = params.copy()
        n_per_angle = len(self.per_angle_indices)

        # Nothing to reset when there are no (or a single) per-angle entries.
        if n_per_angle < 2:
            return reset_params

        # Assuming per-angle layout: [contrast_0..n_phi, offset_0..n_phi, physical...]
        # Slicing with fewer than 2 * n_phi indices would average the wrong
        # entries (or an empty slice, yielding NaN), so refuse to reset.
        if n_phi <= 0 or n_per_angle < 2 * n_phi:
            logger.warning(
                f"Cannot reset per-angle params: n_phi={n_phi} requires "
                f"{2 * n_phi} per-angle indices but monitor has {n_per_angle}"
            )
            return reset_params

        # Reset contrast to mean
        contrast_indices = self.per_angle_indices[:n_phi]
        contrast_mean = np.nanmean(params[contrast_indices])
        reset_params[contrast_indices] = contrast_mean

        # Reset offset to mean
        offset_indices = self.per_angle_indices[n_phi : 2 * n_phi]
        offset_mean = np.nanmean(params[offset_indices])
        reset_params[offset_indices] = offset_mean

        logger.info(
            f"Reset per-angle params: contrast={contrast_mean:.4f}, "
            f"offset={offset_mean:.4f}"
        )

        return reset_params

    def reset(self) -> None:
        """Reset monitor state for new optimization run."""
        self.history = collections.deque(maxlen=self.MAX_HISTORY_SIZE)
        self.consecutive_count = 0
        self.collapse_detected = False
        self.collapse_events = []
        self.best_params = None
        self.best_loss = float("inf")
        # Reset watched parameter tracking
        self._reset_watch_state()

    def get_diagnostics(self) -> dict:
        """Get monitoring diagnostics for logging.

        Returns
        -------
        dict
            Diagnostic information. With no recorded checks only ``enabled``
            and ``n_checks`` are present; otherwise ratio and physical-norm
            statistics over the retained history, the collapse state and,
            when parameters are watched, their counters and flags.
        """
        if not self.history:
            return {
                "enabled": self.config.enable,
                "n_checks": 0,
            }

        ratios = [h["ratio"] for h in self.history]
        physical_norms = [h["physical_grad_norm"] for h in self.history]

        diag = {
            "enabled": self.config.enable,
            "n_checks": len(self.history),
            "min_ratio": min(ratios),
            "max_ratio": max(ratios),
            "mean_ratio": float(np.nanmean(ratios)),
            "final_ratio": ratios[-1],
            "min_physical_grad": min(physical_norms),
            "max_physical_grad": max(physical_norms),
            "mean_physical_grad": float(np.nanmean(physical_norms)),
            "collapse_detected": self.collapse_detected,
            "consecutive_triggers": self.consecutive_count,
            "n_collapse_events": len(self.collapse_events),
            "response_mode": self.config.response_mode,
            "threshold": self.config.ratio_threshold,
        }

        # Add watched parameter diagnostics
        if self.config.watch_parameters:
            diag["watch_parameters"] = self.config.watch_parameters
            diag["watch_consecutive_counts"] = dict(self._watch_consecutive_counts)
            diag["watch_collapse_detected"] = dict(self._watch_collapse_detected)

        return diag

    def log_summary(self) -> None:
        """Log monitoring summary."""
        diag = self.get_diagnostics()

        if not diag["enabled"]:
            logger.info("Gradient monitoring: DISABLED")
            return

        if diag["n_checks"] == 0:
            logger.info("Gradient monitoring: No checks performed")
            return

        logger.info("Gradient Collapse Monitor Summary:")
        logger.info(f" Checks performed: {diag['n_checks']}")
        logger.info(
            f" Gradient ratio: min={diag['min_ratio']:.6f}, "
            f"max={diag['max_ratio']:.6f}, mean={diag['mean_ratio']:.6f}"
        )
        logger.info(f" Threshold: {diag['threshold']}")

        if diag["collapse_detected"]:
            logger.warning(f" COLLAPSE DETECTED: {diag['n_collapse_events']} events")
            logger.warning(f" Response mode: {diag['response_mode']}")
        else:
            logger.info(" Status: No collapse detected")
def create_gradient_function_with_monitoring(
    grad_fn: Callable[[np.ndarray], np.ndarray],
    monitor: GradientCollapseMonitor,
) -> Callable[[np.ndarray], np.ndarray]:
    """Return a gradient callable that also reports each result to a monitor.

    Parameters
    ----------
    grad_fn : Callable[[np.ndarray], np.ndarray]
        Gradient function to wrap.
    monitor : GradientCollapseMonitor
        Monitor whose ``check`` method receives every computed gradient.

    Returns
    -------
    Callable[[np.ndarray], np.ndarray]
        Function with the same call signature as ``grad_fn``. Each call
        evaluates ``grad_fn``, passes the gradient, a zero-based call index
        and the parameters to ``monitor.check``, and returns the gradient
        unchanged.
    """
    calls_seen = 0  # index handed to the monitor as the iteration number

    def monitored_grad_fn(params: np.ndarray) -> np.ndarray:
        nonlocal calls_seen
        computed = grad_fn(params)
        monitor.check(computed, calls_seen, params=params)
        # Advance only after a successful check, as the index is per
        # completed call.
        calls_seen += 1
        return computed

    return monitored_grad_fn