Source code for homodyne.optimization.nlsq.jacobian

"""Jacobian computation utilities for NLSQ optimization.

This module extracts Jacobian-related functions from nlsq_wrapper.py
to reduce file size and improve maintainability.

Extracted from nlsq_wrapper.py as part of technical debt remediation (Dec 2025).
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


def compute_jacobian_stats(
    residual_fn: Callable[..., Any],
    x_subset: np.ndarray,
    params: np.ndarray,
    scaling_factor: float,
) -> tuple[np.ndarray | None, np.ndarray | None]:
    """Compute Jacobian statistics for convergence diagnostics.

    Differentiates the residual function with respect to the parameters and
    derives:

    - JTJ (Jacobian transpose times Jacobian), multiplied by
      ``scaling_factor``, as a Hessian approximation
    - Column norms, multiplied by ``sqrt(scaling_factor)``, for parameter
      sensitivity analysis

    Parameters
    ----------
    residual_fn : Callable
        Residual function to differentiate. If it exposes a ``jax_residual``
        attribute, that is called with the parameter vector only; otherwise
        ``residual_fn(x_subset, *params)`` is called.
    x_subset : np.ndarray
        Subset of x data for Jacobian computation (unused when
        ``residual_fn.jax_residual`` exists).
    params : np.ndarray
        Current parameter values. Integer or boolean values are promoted to
        float before differentiation.
    scaling_factor : float
        Scaling factor for JTJ computation.

    Returns
    -------
    tuple[np.ndarray | None, np.ndarray | None]
        (JTJ matrix, column norms) or (None, None) on failure, including the
        case where the residual function cannot be traced by JAX.
    """
    try:
        # jax.jacfwd only accepts real/complex floating inputs; promote
        # integer or boolean parameters instead of letting it raise.
        params_arr = np.asarray(params)
        if not np.issubdtype(params_arr.dtype, np.inexact):
            params_arr = params_arr.astype(float)
        params_jnp = jnp.asarray(params_arr)

        if hasattr(residual_fn, "jax_residual"):

            def residual_vector(p):
                return jnp.asarray(
                    residual_fn.jax_residual(jnp.asarray(p))
                ).reshape(-1)

        else:

            def residual_vector(p):
                return jnp.asarray(residual_fn(x_subset, *tuple(p))).reshape(-1)

        # Forward mode (jacfwd) does one JVP per parameter, which suits the
        # tall Jacobians expected here (many residuals, few parameters).
        jac = jax.jacfwd(residual_vector)(params_jnp)
        jac_np = np.asarray(jac)

        # Pick the J^T J computation method from the condition number.
        try:
            cond_number = np.linalg.cond(jac_np)
        except np.linalg.LinAlgError:
            cond_number = np.inf

        if cond_number > 1e6:
            # Ill-conditioned: with J = Q @ R and Q orthonormal,
            # J^T J = R^T R. Only R is needed, so skip forming Q.
            r_factor = np.linalg.qr(jac_np, mode="r")
            jtj = r_factor.T @ r_factor * scaling_factor
        else:
            # Standard computation for well-conditioned Jacobians.
            jtj = jac_np.T @ jac_np * scaling_factor

        col_norms = np.linalg.norm(jac_np, axis=0) * np.sqrt(scaling_factor)
        return jtj, col_norms
    except (TypeError, ValueError, RuntimeError, np.linalg.LinAlgError):
        # TypeError covers JAX tracer-conversion/concretization errors raised
        # when the residual function is not JAX-traceable.
        return None, None
def compute_jacobian_condition_number(
    residual_fn: Callable[..., Any],
    x_subset: np.ndarray,
    params: np.ndarray,
) -> float | None:
    """Compute condition number of Jacobian matrix.

    The condition number indicates how sensitive the optimization is to
    parameter perturbations. High values (>1e6) suggest ill-conditioning.

    Parameters
    ----------
    residual_fn : Callable
        Residual function to differentiate. If it exposes a ``jax_residual``
        attribute, that is called with the parameter vector only; otherwise
        ``residual_fn(x_subset, *params)`` is called.
    x_subset : np.ndarray
        Subset of x data for Jacobian computation (unused when
        ``residual_fn.jax_residual`` exists).
    params : np.ndarray
        Current parameter values. Integer or boolean values are promoted to
        float before differentiation.

    Returns
    -------
    float | None
        Condition number, or None on failure (including the case where the
        residual function cannot be traced by JAX).
    """
    try:
        # jax.jacfwd only accepts real/complex floating inputs; promote
        # integer or boolean parameters instead of letting it raise.
        params_arr = np.asarray(params)
        if not np.issubdtype(params_arr.dtype, np.inexact):
            params_arr = params_arr.astype(float)
        params_jnp = jnp.asarray(params_arr)

        if hasattr(residual_fn, "jax_residual"):

            def residual_vector(p):
                return jnp.asarray(
                    residual_fn.jax_residual(jnp.asarray(p))
                ).reshape(-1)

        else:

            def residual_vector(p):
                return jnp.asarray(residual_fn(x_subset, *tuple(p))).reshape(-1)

        # Forward mode (jacfwd) does one JVP per parameter, which suits the
        # tall Jacobians expected here (many residuals, few parameters).
        jac = jax.jacfwd(residual_vector)(params_jnp)
        jac_np = np.asarray(jac)

        return float(np.linalg.cond(jac_np))
    except (TypeError, ValueError, RuntimeError, np.linalg.LinAlgError):
        # TypeError covers JAX tracer-conversion/concretization errors raised
        # when the residual function is not JAX-traceable.
        return None
def analyze_parameter_sensitivity(
    residual_fn: Callable[..., Any],
    x_subset: np.ndarray,
    params: np.ndarray,
    param_names: list[str],
) -> dict[str, float]:
    """Rank parameters by the size of their Jacobian columns.

    The column norms are obtained from ``compute_jacobian_stats`` with a
    scaling factor of 1.0 and divided by the largest norm, so the most
    influential parameter scores 1.0.

    Parameters
    ----------
    residual_fn : Callable
        Residual function to differentiate.
    x_subset : np.ndarray
        Subset of x data for Jacobian computation.
    params : np.ndarray
        Current parameter values.
    param_names : list[str]
        Labels paired positionally with the Jacobian columns; pairing stops
        at the shorter of the two sequences.

    Returns
    -------
    dict[str, float]
        Mapping from parameter name to sensitivity (normalized 0-1). Empty
        when the Jacobian statistics could not be computed; all zeros when
        the largest column norm is not positive.
    """
    stats = compute_jacobian_stats(residual_fn, x_subset, params, 1.0)
    column_norms = stats[1]
    if column_norms is None:
        return {}

    # Scale by the largest norm; fall back to zeros when nothing is sensitive.
    peak = np.max(column_norms)
    scaled = column_norms / peak if peak > 0 else np.zeros_like(column_norms)

    sensitivities: dict[str, float] = {}
    for label, value in zip(param_names, scaled, strict=False):
        sensitivities[label] = float(value)
    return sensitivities
def estimate_gradient_noise(
    residual_fn: Callable[..., Any],
    x_subset: np.ndarray,
    params: np.ndarray,
    n_samples: int = 5,
    perturbation: float = 1e-6,
    seed: int = 42,
) -> float | None:
    """Estimate gradient noise from multiple Jacobian computations.

    Evaluates the Jacobian at several randomly perturbed copies of ``params``
    and reports the median, over all Jacobian entries, of the per-entry
    coefficient of variation (std / |mean|) across the samples. Entries whose
    mean magnitude is at most 1e-10 contribute 0.

    Parameters
    ----------
    residual_fn : Callable
        Residual function to differentiate. If it exposes a ``jax_residual``
        attribute, that is called with the parameter vector only; otherwise
        ``residual_fn(x_subset, *params)`` is called.
    x_subset : np.ndarray
        Subset of x data for Jacobian computation (unused when
        ``residual_fn.jax_residual`` exists).
    params : np.ndarray
        Current parameter values (converted to float).
    n_samples : int
        Number of perturbed samples.
    perturbation : float
        Relative perturbation size; each parameter is shifted by a standard
        normal draw times ``perturbation * |param|``.
    seed : int
        Seed for the NumPy random generator that draws the perturbations,
        making the estimate reproducible.

    Returns
    -------
    float | None
        Estimated gradient noise (coefficient of variation) or None on
        failure, including the case where the residual function cannot be
        traced by JAX or ``n_samples`` yields no Jacobians to stack.
    """
    try:
        params_base = np.asarray(params, dtype=float)
        rng = np.random.default_rng(seed=seed)

        # The branch condition does not depend on the sample, so the residual
        # wrapper and its Jacobian function are built once, outside the loop.
        if hasattr(residual_fn, "jax_residual"):

            def residual_vector(p):
                return jnp.asarray(
                    residual_fn.jax_residual(jnp.asarray(p))
                ).reshape(-1)

        else:

            def residual_vector(p):
                return jnp.asarray(residual_fn(x_subset, *tuple(p))).reshape(-1)

        # Forward mode (jacfwd) does one JVP per parameter, which suits the
        # tall Jacobians expected here (many residuals, few parameters).
        jacobian_fn = jax.jacfwd(residual_vector)

        jacobians = []
        for _ in range(n_samples):
            # Relative perturbation: noise scales with each parameter's size.
            noise = (
                rng.standard_normal(len(params_base))
                * perturbation
                * np.abs(params_base)
            )
            jac = jacobian_fn(jnp.asarray(params_base + noise))
            jacobians.append(np.asarray(jac))

        # Coefficient of variation across samples, per Jacobian entry.
        jac_stack = np.stack(jacobians, axis=0)
        jac_mean = np.mean(jac_stack, axis=0)
        jac_std = np.std(jac_stack, axis=0)

        # np.where evaluates both branches, so silence the divide warnings
        # for the near-zero means that are masked out anyway.
        with np.errstate(divide="ignore", invalid="ignore"):
            cv = np.where(np.abs(jac_mean) > 1e-10, jac_std / np.abs(jac_mean), 0.0)

        return float(np.median(cv))
    except (TypeError, ValueError, RuntimeError, np.linalg.LinAlgError):
        # TypeError covers JAX tracer-conversion/concretization errors raised
        # when the residual function is not JAX-traceable.
        return None