# Source code for homodyne.optimization.nlsq.shear_weighting

"""Shear-Sensitivity Weighting for Anti-Degeneracy Defense.

This module implements angle-dependent loss weighting to prevent gradient
cancellation in the shear term during optimization.

Part of Anti-Degeneracy Defense System v2.9.1.

The Problem
-----------
The shear term gradient is:
    d(g1_shear)/d(gamma_dot_t0) ~ cos(phi0 - phi)

When summed uniformly over all angles:
- Angles near phi0: cos(phi0 - phi) ~ +1 (positive contribution)
- Angles near phi0 +/- 90deg: cos ~ 0 (negligible)
- Angles near phi0 +/- 180deg: cos ~ -1 (negative contribution)

With uniformly distributed angles, positive and negative contributions
CANCEL, leading to near-zero net gradient for gamma_dot_t0. This causes
the shear parameter to collapse to its lower bound.

The Solution
------------
Use angle-dependent loss weighting:

    L = sum_phi w(phi) * sum_tau (g2_model - g2_exp)^2

where w(phi) emphasizes shear-sensitive angles:

    w(phi) = w_min + (1 - w_min) * abs(cos(phi0_current - phi))^alpha

This converts gradient cancellation into a weighted sum where shear-sensitive
angles (parallel/antiparallel to flow) contribute more than perpendicular
angles. All angles still contribute to prevent information loss.

Configuration
-------------
shear_weighting:
    enable: true                    # Enable shear-sensitivity weighting
    min_weight: 0.3                 # Minimum weight (0-1)
    alpha: 1.0                      # Shear sensitivity exponent (1 = linear)
    update_frequency: 1             # Update weights every N outer iterations
    initial_phi0: null              # Initial phi0 guess in degrees (null = 0.0)
    normalize: true                 # Rescale weights so their mean is 1
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp
import numpy as np

from homodyne.optimization.nlsq.config import safe_float
from homodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Mapping

    from jax import Array

# Module-level logger, obtained from the project's logging helper.
logger = get_logger(__name__)


# Performance Optimization (Spec 001 - FR-001, T014): JIT-compiled weight computation
# static_argnums=(4,): `normalize` is a config bool — never traced, prevents spurious
# retrace when the bool's concrete value changes between calls.
@partial(jax.jit, static_argnums=(4,))
def _compute_weights_jax(
    phi_angles: jnp.ndarray,
    phi0: float,
    min_weight: float,
    alpha: float,
    normalize: bool,
) -> jnp.ndarray:
    """JIT-compiled shear weight computation for optimal performance.

    Computes angle-dependent weights that emphasize shear-sensitive angles
    (parallel/antiparallel to flow direction).

    Parameters
    ----------
    phi_angles : jnp.ndarray
        Phi angles in degrees.
    phi0 : float
        Current phi0 estimate in degrees.
    min_weight : float
        Minimum weight for perpendicular angles (0-1).
    alpha : float
        Shear sensitivity exponent.
    normalize : bool
        Whether to normalize weights so mean = 1.

    Returns
    -------
    jnp.ndarray
        Weight array of shape (n_phi,).
    """
    # Convert to radians
    phi0_rad = jnp.radians(phi0)
    phi_rad = jnp.radians(phi_angles)

    # Compute shear sensitivity: |cos(phi0 - phi)|
    # Underflow protection: use jnp.where (gradient-safe) instead of jnp.maximum.
    # phi0 is a traced parameter; jnp.maximum zeros its gradient when cos_factor ≤ 1e-10.
    _cos_abs = jnp.abs(jnp.cos(phi0_rad - phi_rad))
    cos_factor = jnp.where(_cos_abs > 1e-10, _cos_abs, 1e-10)

    # Apply exponent and scale
    # w(phi) = w_min + (1 - w_min) * |cos(phi0 - phi)|^alpha
    weights = min_weight + (1.0 - min_weight) * (cos_factor**alpha)

    # Normalize if enabled using jax.lax.cond for JIT compatibility
    return jax.lax.cond(
        normalize,
        lambda w: w / jnp.mean(w),
        lambda w: w,
        weights,
    )


@dataclass
class ShearWeightingConfig:
    """Configuration for shear-sensitivity weighting.

    Attributes
    ----------
    enable : bool
        Enable shear-sensitivity weighting. Default True.
    min_weight : float
        Minimum weight for perpendicular angles. Range [0, 1]. Default 0.3.
    alpha : float
        Shear sensitivity exponent. Higher = more aggressive weighting.
        Default 1.0 (linear).
    update_frequency : int
        Update weights every N outer iterations. Default 1.
    initial_phi0 : float or None
        Initial phi0 guess in degrees. None = use config or 0.0.
    normalize : bool
        Normalize weights so sum = n_phi (i.e. mean weight = 1). Default True.
    """

    enable: bool = True
    min_weight: float = 0.3
    alpha: float = 1.0
    update_frequency: int = 1
    initial_phi0: float | None = None
    normalize: bool = True

    @classmethod
    def from_config(cls, config: Mapping) -> ShearWeightingConfig:
        """Create from configuration dictionary.

        Reads the ``"shear_weighting"`` sub-mapping of ``config``. A missing
        section, or one that is present but null, yields all defaults. An
        explicit null for ``update_frequency``, ``normalize`` or
        ``initial_phi0`` is treated like a missing key.

        Parameters
        ----------
        config : Mapping
            Configuration dictionary.

        Returns
        -------
        ShearWeightingConfig
            Configuration object.
        """
        # A YAML key with no body (``shear_weighting:``) parses as None. Treat it
        # like an absent section instead of failing on ``None.get(...)``.
        sw_config = config.get("shear_weighting") or {}

        raw_frequency = sw_config.get("update_frequency")
        raw_phi0 = sw_config.get("initial_phi0")
        raw_normalize = sw_config.get("normalize")

        return cls(
            enable=sw_config.get("enable", True),
            min_weight=safe_float(sw_config.get("min_weight"), 0.3),
            alpha=safe_float(sw_config.get("alpha"), 1.0),
            # Explicit null falls back to the default, as a missing key does;
            # int(None) would otherwise raise TypeError.
            update_frequency=1 if raw_frequency is None else int(raw_frequency),
            initial_phi0=None if raw_phi0 is None else safe_float(raw_phi0, 0.0),
            normalize=True if raw_normalize is None else raw_normalize,
        )
class ShearSensitivityWeighting:
    """Shear-sensitivity weighted loss for anti-degeneracy defense.

    This class manages angle-dependent weights that emphasize shear-sensitive
    angles during optimization, preventing gradient cancellation.

    Parameters
    ----------
    phi_angles : np.ndarray
        Array of phi angles in degrees.
    n_physical : int
        Number of physical parameters.
    phi0_index : int
        Index of phi0 in physical parameters (typically 6 for laminar_flow).
    config : ShearWeightingConfig
        Weighting configuration.

    Examples
    --------
    >>> phi_angles = np.array([-30, 0, 30, 60, 90, 120])
    >>> weighter = ShearSensitivityWeighting(phi_angles, n_physical=7, phi0_index=6)
    >>> weights = weighter.get_weights(phi0_current=-5.0)
    >>> # Angles near -5 deg and 175 deg get higher weight
    """

    def __init__(
        self,
        phi_angles: np.ndarray,
        n_physical: int,
        phi0_index: int,
        config: ShearWeightingConfig | None = None,
    ):
        self.phi_angles = np.asarray(phi_angles, dtype=np.float64)
        self.n_phi = len(self.phi_angles)
        self.n_physical = n_physical
        self.phi0_index = phi0_index
        self.config = config or ShearWeightingConfig()

        # Current phi0 estimate (degrees); None in the config means 0.0.
        self._phi0_current = self.config.initial_phi0 or 0.0

        # Precomputed weight lookup (one weight per phi index), kept as both a
        # NumPy array and a JAX array.
        self._weights = self._compute_weights(self._phi0_current)
        self._weights_jax = jnp.asarray(self._weights)

        # Number of times the weights were recomputed by update_phi0.
        self._update_count = 0

        if self.config.enable:
            logger.info(
                f"ShearSensitivityWeighting initialized: "
                f"n_phi={self.n_phi}, min_weight={self.config.min_weight:.2f}, "
                f"alpha={self.config.alpha:.1f}, initial_phi0={self._phi0_current:.1f} deg"
            )

    def _compute_weights(self, phi0: float) -> np.ndarray:
        """Compute angle weights for given phi0.

        Performance Optimization (Spec 001 - FR-001, T015):
        Uses JIT-compiled computation for 2-3x speedup on repeated calls.

        Parameters
        ----------
        phi0 : float
            Current phi0 estimate in degrees.

        Returns
        -------
        np.ndarray
            Weight array of shape (n_phi,).
        """
        # Performance Optimization (Spec 001 - FR-001, T015): Use JIT-compiled version
        result = _compute_weights_jax(
            jnp.asarray(self.phi_angles),
            phi0,
            self.config.min_weight,
            self.config.alpha,
            self.config.normalize,
        )
        return np.asarray(result)

    def update_phi0(self, params: np.ndarray, iteration: int = 0) -> None:
        """Update phi0 estimate from current parameters.

        Weights are recomputed only when phi0 moved by more than 0.1 degree
        from the stored estimate.

        Parameters
        ----------
        params : np.ndarray
            Current parameter vector. Physical parameters should be at the end.
        iteration : int
            Current iteration number. Updates happen only on iterations that
            are a multiple of ``config.update_frequency``. A frequency of 0 is
            treated as 1 (update every iteration).
        """
        if not self.config.enable:
            return

        # Guard against update_frequency == 0, which would otherwise raise
        # ZeroDivisionError in the modulo below.
        frequency = self.config.update_frequency or 1
        if iteration % frequency != 0:
            return

        # Parameter layout: [per_angle_params, physical_params]
        # phi0 sits at offset `phi0_index` from the START of the trailing
        # physical block.
        n_per_angle = len(params) - self.n_physical
        phi0_idx = n_per_angle + self.phi0_index
        new_phi0 = float(params[phi0_idx])

        # Check if phi0 has changed significantly
        if abs(new_phi0 - self._phi0_current) > 0.1:  # 0.1 degree threshold
            self._phi0_current = new_phi0
            self._weights = self._compute_weights(new_phi0)
            self._weights_jax = jnp.asarray(self._weights)
            self._update_count += 1

            logger.debug(
                f"ShearSensitivityWeighting updated: "
                f"phi0={new_phi0:.2f} deg, weights range=[{self._weights.min():.3f}, "
                f"{self._weights.max():.3f}]"
            )

    def get_weights(self, phi0_current: float | None = None) -> np.ndarray:
        """Get current angle weights.

        Parameters
        ----------
        phi0_current : float, optional
            Override phi0 for weight computation. If None, uses stored value.
            An override does not change the stored state.

        Returns
        -------
        np.ndarray
            Weight array of shape (n_phi,).
        """
        if phi0_current is not None and phi0_current != self._phi0_current:
            return self._compute_weights(phi0_current)
        return self._weights

    def get_weights_jax(self) -> Array:
        """Get current angle weights as JAX array.

        Returns
        -------
        jax.Array
            Weight array of shape (n_phi,).
        """
        return self._weights_jax

    def apply_weights_to_loss(self, residuals: Array, phi_indices: Array) -> Array:
        """Apply angle weights to residuals for loss computation.

        Computes the weighted SUM of squared residuals. It is not divided by
        the weight total; see ``compute_weighted_mse`` for the normalized form:

            L = sum_i w[phi_idx[i]] * residuals[i]^2

        When weighting is disabled, returns the unweighted sum of squares
        ``sum_i residuals[i]^2``.

        Parameters
        ----------
        residuals : jax.Array
            Residuals array of shape (n_data,).
        phi_indices : jax.Array
            Phi index for each data point, shape (n_data,).

        Returns
        -------
        jax.Array
            Weighted sum of squared residuals (scalar).
        """
        if not self.config.enable:
            # Plain sum of squares: equal to mean(r^2) * n for non-empty 1-D
            # input, and 0.0 (instead of NaN) for empty input.
            return jnp.sum(residuals**2)

        # Lookup weights for each data point
        weights = self._weights_jax[phi_indices.astype(jnp.int32)]

        # Weighted sum of squared residuals
        return jnp.sum(weights * residuals**2)

    def compute_weighted_mse(self, residuals: Array, phi_indices: Array) -> Array:
        """Compute weighted MSE (for gradient computation).

        Parameters
        ----------
        residuals : jax.Array
            Residuals array of shape (n_data,).
        phi_indices : jax.Array
            Phi index for each data point, shape (n_data,).

        Returns
        -------
        jax.Array
            Weighted MSE (scalar): sum(w * r^2) / sum(w).
        """
        if not self.config.enable:
            return jnp.mean(residuals**2)

        # Lookup weights for each data point
        weights = self._weights_jax[phi_indices.astype(jnp.int32)]

        # Weighted mean: sum(w * r^2) / sum(w)
        weighted_mse = jnp.sum(weights * residuals**2) / jnp.sum(weights)
        return weighted_mse

    def get_diagnostics(self) -> dict:
        """Get weighting diagnostics.

        Returns
        -------
        dict
            Diagnostic information.
        """
        return {
            "enabled": self.config.enable,
            "min_weight": self.config.min_weight,
            "alpha": self.config.alpha,
            "current_phi0": self._phi0_current,
            "update_count": self._update_count,
            "weights_range": [float(self._weights.min()), float(self._weights.max())],
            "weights_mean": float(self._weights.mean()),
            "weights_std": float(self._weights.std()),
        }

    @property
    def phi0_current(self) -> float:
        """Current phi0 estimate in degrees."""
        return self._phi0_current
def create_shear_weighting(
    phi_angles: np.ndarray,
    n_physical: int,
    config: Mapping | None = None,
    physical_param_names: list[str] | None = None,
) -> ShearSensitivityWeighting | None:
    """Factory function to create shear weighting if enabled.

    Parameters
    ----------
    phi_angles : np.ndarray
        Phi angles in degrees.
    n_physical : int
        Number of physical parameters.
    config : Mapping, optional
        Configuration dictionary. None disables weighting.
    physical_param_names : list of str, optional
        Names of the physical parameters, in order. When given, phi0's index
        is looked up by name, and weighting is disabled if "phi0" is absent.
        When omitted, the laminar_flow layout is assumed (phi0 at index 6).

    Returns
    -------
    ShearSensitivityWeighting or None
        Weighting object if enabled and phi0 lies within the physical
        parameter block, None otherwise.
    """
    if config is None:
        return None

    sw_config = ShearWeightingConfig.from_config(config)

    if not sw_config.enable:
        logger.debug("Shear-sensitivity weighting disabled by config")
        return None

    # phi0 is typically the last of the 7 physical parameters in laminar_flow
    # Physical params: [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
    if physical_param_names is None:
        phi0_index = 6
    elif "phi0" in physical_param_names:
        phi0_index = physical_param_names.index("phi0")
    else:
        logger.debug("phi0 not in physical params -- shear weighting disabled")
        return None

    # The index must fall inside the physical block. Otherwise update_phi0 would
    # index past the end of the parameter vector (e.g. the default index 6
    # with fewer than 7 physical parameters).
    if not 0 <= phi0_index < n_physical:
        logger.debug(
            f"phi0 index {phi0_index} outside {n_physical} physical params "
            f"-- shear weighting disabled"
        )
        return None

    return ShearSensitivityWeighting(
        phi_angles=phi_angles,
        n_physical=n_physical,
        phi0_index=phi0_index,
        config=sw_config,
    )