# Source code for homodyne.optimization.numerical_validation

"""Numerical validation for optimization at critical points.

This module provides validation functions to detect numerical issues
(NaN, Inf, bounds violations) at three critical points during optimization:
1. After gradient computation
2. After parameter update
3. After loss calculation

These validations help catch numerical instabilities early and enable
targeted recovery strategies.
"""

from typing import Any

import jax.numpy as jnp
import numpy as np

from homodyne.optimization.exceptions import NLSQNumericalError


[docs] class NumericalValidator: """Validator for numerical stability at critical optimization points. This class provides methods to validate numerical values at three critical points: gradients, parameters, and loss values. Detection of NaN/Inf enables targeted recovery strategies. Attributes ---------- enable_validation : bool Whether to perform validation (can disable for speed) bounds : tuple of np.ndarray or None Parameter bounds (lower, upper) for bounds checking Examples -------- >>> validator = NumericalValidator(enable_validation=True) >>> try: ... validator.validate_gradients(gradients) ... validator.validate_parameters(params, bounds) ... validator.validate_loss(loss_value) ... except NLSQNumericalError as e: ... print(f"Numerical error at {e.detection_point}") """
[docs] def __init__( self, enable_validation: bool = True, bounds: tuple[np.ndarray, np.ndarray] | None = None, ): """Initialize numerical validator. Parameters ---------- enable_validation : bool, optional Whether to perform validation, by default True bounds : tuple of np.ndarray, optional Parameter bounds (lower, upper), by default None """ self.enable_validation = enable_validation self.bounds = bounds
[docs] def validate_gradients(self, gradients: Any) -> None: """Validate gradients for NaN/Inf after Jacobian computation. This is validation point 1: Gradients can become non-finite due to overflow in the model function or ill-conditioned Jacobian. Parameters ---------- gradients : array-like Gradient values to validate Raises ------ NLSQNumericalError If gradients contain NaN or Inf values """ if not self.enable_validation: return # Convert to JAX array for consistent finite check grad_array = jnp.asarray(gradients) if not jnp.isfinite(grad_array).all(): # Find indices of invalid values invalid_mask = ~jnp.isfinite(grad_array) invalid_indices = jnp.where(invalid_mask)[0] invalid_values = [ f"grad[{int(idx)}]={float(grad_array[idx])}" for idx in invalid_indices[:5] # Report first 5 ] raise NLSQNumericalError( f"Non-finite gradients detected at {len(invalid_indices)} locations. " f"This typically indicates:\n" f" 1. Overflow in model function evaluation\n" f" 2. Ill-conditioned Jacobian matrix\n" f" 3. Learning rate too large\n" f"First few invalid values: {invalid_values}", detection_point="gradient", invalid_values=invalid_values, error_context={"n_invalid": int(len(invalid_indices))}, )
[docs] def validate_parameters( self, parameters: Any, bounds: tuple[np.ndarray, np.ndarray] | None = None, ) -> None: """Validate parameters for NaN/Inf and bounds violations after update. This is validation point 2: Parameters can become non-finite after update steps, especially with aggressive step sizes. Parameters ---------- parameters : array-like Parameter values to validate bounds : tuple of np.ndarray, optional Parameter bounds (lower, upper), overrides instance bounds Raises ------ NLSQNumericalError If parameters contain NaN or Inf values """ if not self.enable_validation: return # Convert to JAX array param_array = jnp.asarray(parameters) # Check for NaN/Inf if not jnp.isfinite(param_array).all(): invalid_mask = ~jnp.isfinite(param_array) invalid_indices = jnp.where(invalid_mask)[0] invalid_values = [ f"param[{int(idx)}]={float(param_array[idx])}" for idx in invalid_indices[:5] ] raise NLSQNumericalError( f"Non-finite parameters detected at {len(invalid_indices)} locations. " f"This typically indicates:\n" f" 1. Step size too large\n" f" 2. Unbounded parameter growth\n" f" 3. Numerical overflow in parameter update\n" f"First few invalid values: {invalid_values}", detection_point="parameter", invalid_values=invalid_values, error_context={"n_invalid": int(len(invalid_indices))}, ) # Check bounds violations if bounds provided bounds_to_check = bounds or self.bounds if bounds_to_check is not None: lower, upper = bounds_to_check lower = jnp.asarray(lower) upper = jnp.asarray(upper) violations_lower = param_array < lower violations_upper = param_array > upper if jnp.any(violations_lower) or jnp.any(violations_upper): n_violations = int( jnp.sum(violations_lower) + jnp.sum(violations_upper) ) # Collect violation indices via vectorized jnp.where, then do a # single device-to-host transfer rather than one per element. 
lower_indices = np.asarray(jnp.where(violations_lower)[0]) upper_indices = np.asarray(jnp.where(violations_upper)[0]) param_np = np.asarray(param_array) lower_np = np.asarray(lower) upper_np = np.asarray(upper) violation_info = [ f"param[{i}]={param_np[i]:.6g} < lower={lower_np[i]:.6g}" for i in lower_indices[:5] ] + [ f"param[{i}]={param_np[i]:.6g} > upper={upper_np[i]:.6g}" for i in upper_indices[: max(0, 5 - len(lower_indices))] ] # Note: Bounds violations are often clipped automatically by optimizer # This is more of a warning condition than a hard error # Could log instead of raising, but raising allows recovery strategies raise NLSQNumericalError( f"Parameter bounds violations detected at {n_violations} locations.\n" f"Violations: {violation_info[:5]}", detection_point="parameter_bounds", invalid_values=violation_info[:10], error_context={"n_violations": n_violations}, )
[docs] def validate_loss(self, loss_value: Any) -> None: """Validate loss value for NaN/Inf after loss computation. This is validation point 3: Loss can become non-finite due to overflow in residual computation or invalid parameter values. Parameters ---------- loss_value : scalar Loss value to validate Raises ------ NLSQNumericalError If loss is NaN or Inf """ if not self.enable_validation: return # Convert to scalar loss_scalar = float(loss_value) if not np.isfinite(loss_scalar): raise NLSQNumericalError( f"Non-finite loss value detected: {loss_scalar}. " f"This typically indicates:\n" f" 1. Overflow in residual computation\n" f" 2. Invalid parameter values\n" f" 3. Model function returning NaN/Inf\n" f"Consider:\n" f" - Reducing step size\n" f" - Tightening parameter bounds\n" f" - Normalizing data\n" f" - Checking model function for numerical issues", detection_point="loss", invalid_values=[f"loss={loss_scalar}"], error_context={"loss_value": loss_scalar}, )
[docs] def set_bounds(self, bounds: tuple[np.ndarray, np.ndarray]) -> None: """Update parameter bounds for validation. Parameters ---------- bounds : tuple of np.ndarray New parameter bounds (lower, upper) """ self.bounds = bounds
[docs] def disable(self) -> None: """Disable validation for performance-critical sections.""" self.enable_validation = False
[docs] def enable(self) -> None: """Re-enable validation after disabling.""" self.enable_validation = True