Source code for homodyne.optimization.nlsq.adaptive_regularization

"""Adaptive Relative Regularization for Anti-Degeneracy Defense.

This module implements CV-based (Coefficient of Variation) regularization
that scales properly with data, replacing the ineffective absolute variance
regularization.

Part of Anti-Degeneracy Defense System v2.9.0.
See: docs/specs/anti-degeneracy-defense-v2.9.0.md

Mathematical Formulation::

    Current (ineffective):
        L_reg = lambda * Var(params) * n_points

    Proposed (CV-based):
        CV = std(params) / abs(mean(params))
        L_reg = lambda * CV^2 * MSE * n_points

    Auto-tuned lambda:
        lambda = target_contribution / target_cv^2

        Example: Allow 10% variation (CV=0.1), contribute 10% to loss
        lambda = 0.1 / 0.01 = 10
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, cast

import jax.numpy as jnp
import numpy as np

from homodyne.optimization.nlsq.config import safe_float
from homodyne.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class AdaptiveRegularizationConfig:
    """Settings that control adaptive relative regularization.

    Attributes
    ----------
    enable : bool
        Turns regularization on or off. Default True.
    mode : str
        One of "absolute", "relative", or "auto".

        - "absolute": original variance-based term (L_reg = λ × Var × n)
        - "relative": CV-based term (L_reg = λ × CV² × MSE × n)
        - "auto": relative when n_phi > 5, absolute otherwise
    lambda_base : float
        Base regularization strength. Default 1.0 (100× stronger than v2.8).
    target_cv : float
        Target coefficient of variation. Default 0.10 (10% variation allowed).
    target_contribution : float
        Target fraction of MSE to contribute. Default 0.10 (10% of loss).
    auto_tune_lambda : bool
        When True, λ is derived from target_cv and target_contribution.
    max_cv : float
        CV above which a hard-constraint warning applies. Default 0.20.
    group_indices : list of tuple, optional
        Parameter group bounds [(start, end), ...]. Auto-computed if None.
    """

    enable: bool = True
    mode: Literal["absolute", "relative", "auto"] = "relative"

    # Settings for the relative (CV-based) mode
    lambda_base: float = 1.0  # 100× stronger than the v2.8 default of 0.01
    target_cv: float = 0.10  # 10% variation target
    target_contribution: float = 0.10  # 10% of MSE contribution

    # Derive λ from target_cv and target_contribution
    auto_tune_lambda: bool = True

    # CV ceiling before the hard constraint kicks in
    max_cv: float = 0.20  # 20% max variation

    # Group bounds (auto-computed when None)
    group_indices: list[tuple[int, int]] | None = None

    @classmethod
    def from_dict(cls, config_dict: dict) -> AdaptiveRegularizationConfig:
        """Build a config from a plain dictionary with safe type conversion.

        Missing keys fall back to the dataclass defaults. The regularization
        strength is read from the ``"lambda"`` key and stored as
        ``lambda_base``.
        """
        lookup = config_dict.get
        selected_mode = cast(
            Literal["absolute", "relative", "auto"],
            lookup("mode", "relative"),
        )
        return cls(
            enable=bool(lookup("enable", True)),
            mode=selected_mode,
            lambda_base=safe_float(lookup("lambda"), 1.0),
            target_cv=safe_float(lookup("target_cv"), 0.10),
            target_contribution=safe_float(lookup("target_contribution"), 0.10),
            auto_tune_lambda=bool(lookup("auto_tune_lambda", True)),
            max_cv=safe_float(lookup("max_cv"), 0.20),
            group_indices=lookup("group_indices"),
        )
class AdaptiveRegularizer:
    """CV-based adaptive regularization for per-angle parameters.

    This regularizer addresses the fundamental problem where absolute variance
    regularization (λ=0.01) contributed only ~0.05% to total loss, providing
    no effective constraint on per-angle parameter variation.

    The CV-based approach ensures regularization scales properly:

    - CV is dimensionless (ratio of std to mean)
    - Auto-tuned λ makes regularization ~10% of MSE
    - Prevents per-angle parameters from absorbing physical signals

    Parameters
    ----------
    config : AdaptiveRegularizationConfig
        Regularization configuration.
    n_phi : int
        Number of unique phi angles.
    n_params : int, optional
        Actual parameter vector length (used to detect auto_averaged mode).

    Attributes
    ----------
    lambda_value : float
        Effective regularization strength (auto-tuned or from config).
    group_indices : list of tuple
        Parameter groups to regularize.

    Examples
    --------
    >>> config = AdaptiveRegularizationConfig(target_cv=0.10, target_contribution=0.10)
    >>> regularizer = AdaptiveRegularizer(config, n_phi=23)
    >>> reg_term = regularizer.compute_regularization(
    ...     params, mse=0.04, n_points=23_000_000
    ... )
    """

    # A group whose |mean| is at or below this is treated as having no usable
    # scale; CV then falls back to the absolute std.
    _MEAN_EPS = 1e-10

    def __init__(
        self,
        config: AdaptiveRegularizationConfig,
        n_phi: int,
        n_params: int | None = None,
    ):
        """Initialize adaptive regularizer.

        Parameters
        ----------
        config : AdaptiveRegularizationConfig
            Regularization configuration.
        n_phi : int
            Number of unique phi angles.
        n_params : int, optional
            Actual parameter vector length. When provided and less than
            2 * n_phi, auto_averaged mode is assumed (2 scaling params
            instead of 2 * n_phi).
        """
        self.config = config
        self.n_phi = n_phi

        # Auto-compute group indices if not provided
        if config.group_indices is None:
            # For auto_averaged mode, the parameter vector has only 2 scaling
            # params (1 contrast_avg + 1 offset_avg) instead of 2*n_phi.
            # Detect this by checking if n_params < 2*n_phi.
            if n_params is not None and n_params < 2 * n_phi:
                # auto_averaged: indices [0,1) for contrast_avg, [1,2) for offset_avg
                self.group_indices = [
                    (0, 1),  # contrast_avg group
                    (1, 2),  # offset_avg group
                ]
            else:
                self.group_indices = [
                    (0, n_phi),  # contrast group
                    (n_phi, 2 * n_phi),  # offset group
                ]
        else:
            self.group_indices = list(config.group_indices)

        # Auto-tune λ if enabled
        if config.auto_tune_lambda and config.target_cv > 0:
            self.lambda_value = config.target_contribution / (config.target_cv**2)
            logger.debug(
                f"Auto-tuned lambda = {self.lambda_value:.2f} "
                f"(target_cv={config.target_cv}, "
                f"contribution={config.target_contribution})"
            )
        else:
            self.lambda_value = config.lambda_base

        # Diagnostics
        self._last_cv_values: dict[int, float] = {}
        self._last_reg_contribution: float = 0.0

    def _uses_relative_mode(self) -> bool:
        """Return True when the CV-based (relative) term applies."""
        mode = self.config.mode
        return mode == "relative" or (mode == "auto" and self.n_phi > 5)

    def _usable_groups(
        self, n_total: int, *, warn: bool = False
    ) -> list[tuple[int, int, int]]:
        """Return ``(group_idx, start, end)`` for every group that can be used.

        A group is dropped when it reaches past ``n_total`` or holds fewer
        than two elements. With ``warn=True`` each dropped group is logged.
        """
        usable = []
        for group_idx, (start, end) in enumerate(self.group_indices):
            if start >= n_total or end > n_total:
                if warn:
                    logger.warning(
                        f"Group indices ({start}, {end}) out of bounds "
                        f"for params length {n_total}"
                    )
                continue
            # Security: a group needs at least 2 elements for a meaningful spread
            if end - start < 2:
                if warn:
                    logger.warning(
                        f"Group ({start}, {end}) has fewer than 2 elements, skipping regularization"
                    )
                continue
            usable.append((group_idx, start, end))
        return usable

    def compute_regularization(
        self, params: np.ndarray, mse: float, n_points: int
    ) -> float:
        """Compute regularization term to add to loss.

        Also refreshes the diagnostics returned by :meth:`get_diagnostics`.

        Parameters
        ----------
        params : np.ndarray
            Full parameter vector.
        mse : float
            Current mean squared error.
        n_points : int
            Number of data points.

        Returns
        -------
        float
            Regularization term to add to loss (SSE scale).
        """
        if not self.config.enable:
            return 0.0

        relative = self._uses_relative_mode()
        total_reg = 0.0
        self._last_cv_values = {}

        for group_idx, start, end in self._usable_groups(len(params), warn=True):
            group_params = params[start:end]
            # NaN-safe: degenerate optimizer steps can yield NaN params
            mean_val = np.nanmean(group_params)
            std_val = np.nanstd(group_params)
            has_scale = abs(mean_val) > self._MEAN_EPS

            if relative:
                # Fall back to the absolute std when the mean is ~0
                cv = std_val / abs(mean_val) if has_scale else std_val
                self._last_cv_values[group_idx] = cv
                # L_reg = lambda x CV^2 x MSE x n_points
                group_reg = self.lambda_value * (cv**2) * mse * n_points
            else:
                # Original absolute variance
                group_reg = self.lambda_value * np.nanvar(group_params) * n_points
                # Still record CV for diagnostics when it is well defined
                if has_scale:
                    self._last_cv_values[group_idx] = std_val / abs(mean_val)

            total_reg += group_reg

        self._last_reg_contribution = total_reg
        return total_reg

    def compute_regularization_jax(
        self, params: jnp.ndarray, mse: jnp.ndarray, n_points: int
    ) -> jnp.ndarray:
        """Compute regularization term using JAX for autodiff compatibility.

        This method uses JAX operations (jnp) instead of NumPy, making it
        compatible with JAX's JIT compilation and autodiff (jax.grad).
        Use this method when the regularization needs to be part of a
        differentiable loss function.

        CV² is evaluated as ``Var / mean²`` rather than ``(std / |mean|)²``.
        The value is the same, but no square root is taken, so the gradient
        stays finite for a group whose entries are all equal. The denominator
        is also guarded so the unused ``jnp.where`` branch cannot inject NaN
        into the gradient when the mean is zero.

        Parameters
        ----------
        params : jnp.ndarray
            Full parameter vector (JAX array, possibly traced).
        mse : jnp.ndarray
            Current mean squared error (JAX scalar, possibly traced).
        n_points : int
            Number of data points.

        Returns
        -------
        jnp.ndarray
            Regularization term to add to loss (SSE scale, JAX scalar).
        """
        if not self.config.enable:
            return jnp.array(0.0)

        relative = self._uses_relative_mode()
        total_reg = jnp.array(0.0)

        for _group_idx, start, end in self._usable_groups(len(params)):
            group_params = params[start:end]
            var_val = jnp.var(group_params)

            if relative:
                mean_val = jnp.mean(group_params)
                has_scale = jnp.abs(mean_val) > self._MEAN_EPS
                # Substitute a harmless denominator where the mean is ~0 so
                # that the division below is finite on both where-branches.
                safe_mean_sq = jnp.where(has_scale, mean_val**2, 1.0)
                cv_sq = jnp.where(
                    has_scale,
                    var_val / safe_mean_sq,
                    var_val,  # Fallback to absolute std (squared)
                )
                # L_reg = λ × CV² × MSE × n_points
                group_reg = self.lambda_value * cv_sq * mse * n_points
            else:
                # Original absolute variance using JAX operations
                group_reg = self.lambda_value * var_val * n_points

            total_reg = total_reg + group_reg

        return total_reg

    def compute_regularization_gradient(
        self, params: np.ndarray, mse: float, n_points: int
    ) -> np.ndarray:
        """Compute gradient of regularization term.

        The gradient matches :meth:`compute_regularization` branch by branch,
        including the fallback used when a group's mean is ~0 (where the loss
        uses the absolute std in place of CV).

        Parameters
        ----------
        params : np.ndarray
            Full parameter vector.
        mse : float
            Current mean squared error.
        n_points : int
            Number of data points.

        Returns
        -------
        np.ndarray
            Gradient w.r.t. all parameters (zeros for non-regularized params).
        """
        grad = np.zeros_like(params, dtype=np.float64)

        if not self.config.enable:
            return grad

        relative = self._uses_relative_mode()

        for _group_idx, start, end in self._usable_groups(len(params)):
            group_params = np.asarray(params[start:end], dtype=np.float64)
            n_group = end - start
            mean_val = np.nanmean(group_params)

            # ∂Var/∂p_i = 2(p_i - mean) / n
            d_var = 2.0 * (group_params - mean_val) / n_group

            if not relative:
                # Absolute variance gradient
                grad[start:end] = self.lambda_value * d_var * n_points
                continue

            if abs(mean_val) > self._MEAN_EPS:
                # CV² = Var / mean², hence
                # ∂CV²/∂p_i = (∂Var/∂p_i) / mean² - 2 Var / (n × mean³)
                # This equals 2CV × ∂CV/∂p_i but never divides by std.
                var_val = np.nanvar(group_params)
                d_cv_sq = d_var / mean_val**2 - 2.0 * var_val / (
                    n_group * mean_val**3
                )
            else:
                # Loss falls back to CV = std here, so CV² = Var
                d_cv_sq = d_var

            grad[start:end] = self.lambda_value * d_cv_sq * mse * n_points

        return grad

    def check_constraint_violation(self, params: np.ndarray) -> dict[str, dict]:
        """Check if CV exceeds max_cv threshold.

        Parameters
        ----------
        params : np.ndarray
            Full parameter vector.

        Returns
        -------
        dict
            Dictionary of violations keyed by group, empty if none.
        """
        violations = {}

        for group_idx, (start, end) in enumerate(self.group_indices):
            if start >= len(params) or end > len(params):
                continue

            group_params = params[start:end]
            mean_val = np.nanmean(group_params)
            std_val = np.nanstd(group_params)
            # The small offset keeps the ratio finite for a zero mean
            cv = std_val / (abs(mean_val) + 1e-10)

            if cv > self.config.max_cv:
                group_name = "contrast" if group_idx == 0 else "offset"
                violations[f"group_{group_idx}_{group_name}"] = {
                    "cv": float(cv),
                    "max_cv": self.config.max_cv,
                    "mean": float(mean_val),
                    "std": float(std_val),
                    "min": float(np.nanmin(group_params)),
                    "max": float(np.nanmax(group_params)),
                }

        return violations

    def get_diagnostics(self) -> dict:
        """Get regularization diagnostics for logging.

        Returns
        -------
        dict
            Diagnostic information including CV values and contribution
            recorded by the most recent :meth:`compute_regularization` call.
        """
        return {
            "enabled": self.config.enable,
            "mode": self.config.mode,
            "lambda": self.lambda_value,
            "n_phi": self.n_phi,
            "group_indices": self.group_indices,
            "last_cv_values": self._last_cv_values,
            "last_reg_contribution": self._last_reg_contribution,
            "target_cv": self.config.target_cv,
            "max_cv": self.config.max_cv,
        }

    def log_summary(self, params: np.ndarray, mse: float, n_points: int) -> None:
        """Log regularization summary.

        Parameters
        ----------
        params : np.ndarray
            Full parameter vector.
        mse : float
            Current mean squared error.
        n_points : int
            Number of data points.
        """
        if not self.config.enable:
            logger.info("Adaptive regularization: DISABLED")
            return

        # Compute regularization (this also refreshes the cached CV values)
        reg_term = self.compute_regularization(params, mse, n_points)
        sse = mse * n_points
        contribution_pct = 100 * reg_term / (sse + reg_term) if sse > 0 else 0

        logger.info("Adaptive Regularization Summary:")
        logger.info(f" Mode: {self.config.mode}")
        logger.info(f" Lambda: {self.lambda_value:.2f}")
        logger.info(f" Regularization term: {reg_term:.2e}")
        logger.info(f" SSE: {sse:.2e}")
        logger.info(f" Contribution: {contribution_pct:.2f}%")

        for group_idx, cv in self._last_cv_values.items():
            group_name = "contrast" if group_idx == 0 else "offset"
            status = "OK" if cv <= self.config.max_cv else "VIOLATION"
            logger.info(f" Group {group_idx} ({group_name}): CV={cv:.4f} [{status}]")

        violations = self.check_constraint_violation(params)
        if violations:
            logger.warning(f" CV violations detected: {list(violations.keys())}")