# Source code for homodyne.optimization.nlsq.strategies.residual_jit

"""
JAX JIT-compatible stratified residual function for NLSQ optimization.

This module provides a JIT-compatible version of StratifiedResidualFunction that uses
static shapes and vmap for vectorization, solving the JAX tracing incompatibility.

Key Improvements over original StratifiedResidualFunction:
- Uses jax.vmap for parallel chunk processing (no Python loops)
- Pads chunks to uniform size for static shapes (JIT-compatible)
- Fully JIT-compiled for maximum performance
- Maintains angle stratification guarantee

Author: Homodyne Development Team
Date: 2025-11-13
Version: 2.4.0
"""

from __future__ import annotations

import logging
from typing import Any, cast

import jax
import jax.numpy as jnp
import numpy as np

from homodyne.core.physics_nlsq import compute_g2_scaled
from homodyne.utils.logging import get_logger, log_phase


class StratifiedResidualFunctionJIT:
    """
    JIT-compatible stratified residual function using padded vmap.

    This class solves the JAX JIT incompatibility by:
    1. Padding all chunks to uniform size (static shapes)
    2. Using jax.vmap for vectorized parallel processing
    3. Masking padded values in the final residuals

    The function maintains angle stratification (all chunks contain all angles)
    while being JIT-compilable.

    Attributes:
        phi_padded: Padded phi arrays (n_chunks, max_chunk_size)
        t1_padded: Padded t1 arrays (n_chunks, max_chunk_size)
        t2_padded: Padded t2 arrays (n_chunks, max_chunk_size)
        g2_padded: Padded g2 observations (n_chunks, max_chunk_size)
        mask: Boolean mask for real vs padded data (n_chunks, max_chunk_size)
        n_chunks: Number of stratified chunks
        max_chunk_size: Maximum points per chunk (for padding)
        n_real_points: Total number of real (non-padded) data points
    """

    def __init__(
        self,
        stratified_data: Any,
        per_angle_scaling: bool,
        physical_param_names: list[str],
        logger: logging.Logger | None = None,
        fixed_contrast_per_angle: np.ndarray | None = None,
        fixed_offset_per_angle: np.ndarray | None = None,
    ) -> None:
        """
        Initialize JIT-compatible stratified residual function.

        Args:
            stratified_data: Object with a ``.chunks`` attribute containing
                angle-stratified chunks and a ``.sigma`` attribute with the
                per-point uncertainties.
            per_angle_scaling: Whether per-angle scaling parameters are used.
            physical_param_names: List of physical parameter names.
            logger: Optional logger for diagnostics.
            fixed_contrast_per_angle: Fixed per-angle contrast values (constant
                mode). When provided together with ``fixed_offset_per_angle``,
                contrast is NOT included in the parameter vector.
            fixed_offset_per_angle: Fixed per-angle offset values (constant
                mode). When provided together with ``fixed_contrast_per_angle``,
                offset is NOT included in the parameter vector.

        Raises:
            ValueError: If ``stratified_data.chunks`` is empty or holds no data
                points, if q/L metadata differ across chunks, or if the fixed
                per-angle arrays do not supply one value per unique phi angle.
        """
        self.logger = logger or get_logger(__name__)
        self.chunks = stratified_data.chunks
        self.per_angle_scaling = per_angle_scaling
        self.physical_param_names = physical_param_names

        # Fixed per-angle scaling for constant mode (v2.17.0+).
        # When both are provided, params contains ONLY physical parameters.
        self.fixed_contrast_per_angle: jnp.ndarray | None = None
        self.fixed_offset_per_angle: jnp.ndarray | None = None
        self.use_fixed_scaling = False
        has_contrast = fixed_contrast_per_angle is not None
        has_offset = fixed_offset_per_angle is not None
        if has_contrast and has_offset:
            self.fixed_contrast_per_angle = jnp.asarray(fixed_contrast_per_angle)
            self.fixed_offset_per_angle = jnp.asarray(fixed_offset_per_angle)
            self.use_fixed_scaling = True
            self.logger.info(
                "CONSTANT MODE: Using fixed per-angle scaling from quantiles. "
                "Parameter vector contains ONLY physical parameters."
            )
        elif has_contrast or has_offset:
            # A lone fixed array cannot enable constant mode; make that visible
            # because it changes which parameter-vector layout is expected.
            self.logger.warning(
                "Only one of fixed_contrast_per_angle/fixed_offset_per_angle was "
                "provided; constant mode requires both, so the fixed scaling is "
                "ignored and scaling parameters are expected in the parameter vector."
            )

        if not self.chunks:
            raise ValueError("stratified_data.chunks is empty")

        self.n_chunks = len(self.chunks)

        # Extract global metadata (same across all chunks)
        self.q, self.L, self.dt = self._extract_global_metadata()
        self.phi_unique, self.t1_unique, self.t2_unique = self._extract_unique_values()
        self.n_phi = len(self.phi_unique)

        if self.use_fixed_scaling:
            self._validate_fixed_scaling()

        # Sigma is later gathered with flat (phi, t1, t2) grid indices.
        sigma_array = np.asarray(stratified_data.sigma, dtype=np.float64)
        self._check_sigma_size(sigma_array)
        self.sigma_jax = jnp.asarray(sigma_array)

        # Create padded arrays with static shapes
        self.logger.info(f"Creating padded arrays for {self.n_chunks} chunks...")
        (
            self.phi_padded,
            self.t1_padded,
            self.t2_padded,
            self.g2_padded,
            self.mask,
            self.max_chunk_size,
            self.n_real_points,
        ) = self._create_padded_arrays()

        self.logger.info(
            f"Padded arrays created: shape ({self.n_chunks}, {self.max_chunk_size}), "
            f"real points: {self.n_real_points:,}, "
            f"padding overhead: {self._padding_overhead_pct():.2f}%"
        )

        # Wrap the main residual computation with jax.jit.
        # Buffer donation (donate_argnums) is not used: the params array (small,
        # e.g. 9 elements) never matches the output shape
        # (n_chunks * max_chunk_size), so JAX cannot reuse the buffer.
        # NOTE: jax.jit is lazy; tracing/compilation happens on the first call,
        # so the timed phase below covers only the wrapping step.
        self.logger.info("JIT-compiling residual function...")
        # T035: log_phase for JIT compilation timing with memory tracking
        with log_phase(
            "jit_residual_compilation", logger=self.logger, track_memory=True
        ) as phase:
            self._residual_fn_jit = jax.jit(self._compute_all_residuals)
        self.logger.info(f"JIT compilation setup complete in {phase.duration:.3f}s")

    def _padding_overhead_pct(self) -> float:
        """Return the percentage of padded slots; 0.0 when there are no slots."""
        total_slots = self.n_chunks * self.max_chunk_size
        if total_slots == 0:
            return 0.0
        return (1 - self.n_real_points / total_slots) * 100

    def _validate_fixed_scaling(self) -> None:
        """Ensure fixed per-angle arrays supply one value per unique phi angle.

        The arrays are vmapped along axis 0 together with ``phi_unique``, so a
        scalar or a mismatched leading dimension would otherwise fail only at
        the first (traced) call with a less descriptive vmap error.

        Raises:
            ValueError: If either array is 0-d or its leading dimension differs
                from ``n_phi``.
        """
        for name, values in (
            ("fixed_contrast_per_angle", self.fixed_contrast_per_angle),
            ("fixed_offset_per_angle", self.fixed_offset_per_angle),
        ):
            if values is None:
                continue
            if values.ndim == 0 or values.shape[0] != self.n_phi:
                raise ValueError(
                    f"{name} must provide one value per unique phi angle "
                    f"(expected leading dimension {self.n_phi}, got shape "
                    f"{tuple(values.shape)})"
                )

    def _check_sigma_size(self, sigma_array: np.ndarray) -> None:
        """Warn when sigma cannot be addressed by flat (phi, t1, t2) grid indices.

        Residuals gather sigma with ``phi_idx * n_t1 * n_t2 + t1_idx * n_t2 +
        t2_idx``. JAX clamps out-of-bounds gathers instead of raising, so a
        size mismatch would silently pair points with the wrong sigma. A single
        (scalar) sigma is tolerated: every gather then reads that one value.
        """
        grid_size = self.n_phi * len(self.t1_unique) * len(self.t2_unique)
        if sigma_array.size not in (1, grid_size):
            self.logger.warning(
                f"sigma has {sigma_array.size} elements but the (phi, t1, t2) grid "
                f"has {grid_size}; sigma values may be mis-assigned to data points"
            )

    @staticmethod
    def _dt_matches(a: float | None, b: float | None) -> bool:
        """Return True when two dt values agree (both None, or equal to ~1e-9 relative)."""
        if a is None or b is None:
            return a is None and b is None
        return abs(a - b) <= 1e-9 * max(abs(a), abs(b))

    def _extract_global_metadata(self) -> tuple[float, float, float | None]:
        """Extract q, L, dt from chunks (expected to be the same for all chunks).

        Returns:
            ``(q, L, dt)`` taken from the first chunk; ``dt`` may be None.

        Raises:
            ValueError: If q or L differ across chunks.
        """
        q_values = [float(chunk.q) for chunk in self.chunks]
        L_values = [float(chunk.L) for chunk in self.chunks]
        dt_values = [
            None if chunk.dt is None else float(chunk.dt) for chunk in self.chunks
        ]

        # Validate consistency
        if not all(abs(q - q_values[0]) < 1e-9 for q in q_values):
            raise ValueError("Inconsistent q values across chunks")

        if not all(abs(L - L_values[0]) < 1e-6 for L in L_values):
            raise ValueError("Inconsistent L values across chunks")

        q = q_values[0]
        L = L_values[0]
        dt = dt_values[0]

        # dt was previously taken from the first chunk without any check; keep
        # that choice (no new hard failure) but surface the inconsistency.
        if not all(self._dt_matches(d, dt) for d in dt_values):
            self.logger.warning(
                f"Inconsistent dt values across chunks; using dt={dt} from the "
                f"first chunk"
            )

        self.logger.debug(f"Global metadata: q={q:.6f}, L={L:.1f}, dt={dt}")

        return q, L, dt

    def _extract_unique_values(self) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Extract sorted unique phi, t1, t2 values from ALL chunks.

        CRITICAL: Must extract from all chunks, not just the first chunk,
        because stratified chunking may distribute different subsets of t1/t2
        values across different chunks.

        Returns:
            ``(phi_unique, t1_unique, t2_unique)``, each sorted ascending
            (``jnp.unique`` returns sorted values).
        """
        # Concatenate values from all chunks to get the complete set
        all_phi = np.concatenate([chunk.phi for chunk in self.chunks])
        all_t1 = np.concatenate([chunk.t1 for chunk in self.chunks])
        all_t2 = np.concatenate([chunk.t2 for chunk in self.chunks])

        phi_unique = jnp.unique(jnp.asarray(all_phi))
        t1_unique = jnp.unique(jnp.asarray(all_t1))
        t2_unique = jnp.unique(jnp.asarray(all_t2))

        self.logger.debug(
            f"Unique values (from all chunks): {len(phi_unique)} phi, {len(t1_unique)} t1, {len(t2_unique)} t2"
        )

        # Diagnostic only: report when the first chunk alone would have missed
        # time points.
        first_chunk = self.chunks[0]
        t1_first = jnp.unique(jnp.asarray(first_chunk.t1))
        t2_first = jnp.unique(jnp.asarray(first_chunk.t2))

        if len(t1_unique) != len(t1_first) or len(t2_unique) != len(t2_first):
            self.logger.debug(
                f"Stratified chunking: chunks have different time point subsets "
                f"(first chunk: {len(t1_first)} t1, all chunks: {len(t1_unique)} t1) - "
                f"using complete set from all chunks"
            )

        return phi_unique, t1_unique, t2_unique

    def _create_padded_arrays(
        self,
    ) -> tuple[
        jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int
    ]:
        """
        Create padded arrays with uniform size across all chunks.

        Padding slots repeat a real data point so that the grid lookup
        (``searchsorted`` into the unique arrays) stays in range; the returned
        mask marks them False so they never contribute to the residuals.

        Returns:
            phi_padded, t1_padded, t2_padded, g2_padded, mask, max_chunk_size,
            n_real_points

        Raises:
            ValueError: If no chunk contains any data points.
        """
        chunk_sizes = [len(chunk.phi) for chunk in self.chunks]
        max_chunk_size = max(chunk_sizes)
        n_real_points = sum(chunk_sizes)

        if n_real_points == 0:
            raise ValueError("stratified_data.chunks contain no data points")

        self.logger.debug(
            f"Max chunk size: {max_chunk_size:,}, total real points: {n_real_points:,}"
        )

        shape = (self.n_chunks, max_chunk_size)
        phi_padded = np.zeros(shape, dtype=np.float64)
        t1_padded = np.zeros(shape, dtype=np.float64)
        t2_padded = np.zeros(shape, dtype=np.float64)
        g2_padded = np.zeros(shape, dtype=np.float64)
        mask = np.zeros(shape, dtype=bool)

        # Fill point for chunks with no data at all (chunk.phi[-1] would raise
        # IndexError there). Any real point keeps the grid lookup in range.
        ref_chunk = next(chunk for chunk in self.chunks if len(chunk.phi) > 0)
        fallback_fill = (
            ref_chunk.phi[-1],
            ref_chunk.t1[-1],
            ref_chunk.t2[-1],
            ref_chunk.g2[-1],
        )

        for i, chunk in enumerate(self.chunks):
            n_points = len(chunk.phi)

            # Copy real data
            phi_padded[i, :n_points] = chunk.phi
            t1_padded[i, :n_points] = chunk.t1
            t2_padded[i, :n_points] = chunk.t2
            g2_padded[i, :n_points] = chunk.g2
            mask[i, :n_points] = True

            # Pad with the chunk's last valid value (prevents out-of-range
            # grid indices); mask is already False for these slots.
            if n_points < max_chunk_size:
                if n_points > 0:
                    fill = (chunk.phi[-1], chunk.t1[-1], chunk.t2[-1], chunk.g2[-1])
                else:
                    fill = fallback_fill
                phi_padded[i, n_points:] = fill[0]
                t1_padded[i, n_points:] = fill[1]
                t2_padded[i, n_points:] = fill[2]
                g2_padded[i, n_points:] = fill[3]

        return (
            jnp.asarray(phi_padded),
            jnp.asarray(t1_padded),
            jnp.asarray(t2_padded),
            jnp.asarray(g2_padded),
            jnp.asarray(mask),
            max_chunk_size,
            n_real_points,
        )

    def _compute_single_chunk_residuals(
        self,
        phi_chunk: jnp.ndarray,
        t1_chunk: jnp.ndarray,
        t2_chunk: jnp.ndarray,
        g2_obs_chunk: jnp.ndarray,
        mask_chunk: jnp.ndarray,
        params_all: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute residuals for a single padded chunk.

        This function is designed to be vmapped over the chunk dimension.

        Args:
            phi_chunk: Phi values for this chunk (max_chunk_size,)
            t1_chunk: T1 values for this chunk (max_chunk_size,)
            t2_chunk: T2 values for this chunk (max_chunk_size,)
            g2_obs_chunk: Observed g2 for this chunk (max_chunk_size,)
            mask_chunk: Mask for real vs padded data (max_chunk_size,)
            params_all: All parameters [scaling_params, physical_params], or
                just [physical_params] when use_fixed_scaling=True

        Returns:
            Residuals (max_chunk_size,); padded points, diagonal (t1 == t2)
            points and points with sigma <= 1e-10 are zero.
        """
        # Parameter layouts:
        # 1. use_fixed_scaling=True: params_all = [physical]; contrast/offset come
        #    from self.fixed_contrast_per_angle / self.fixed_offset_per_angle
        # 2. per_angle_scaling=True: params_all = [contrast(n_phi), offset(n_phi), physical]
        # 3. per_angle_scaling=False: params_all = [contrast, offset, physical]
        if self.use_fixed_scaling:
            contrast = self.fixed_contrast_per_angle
            offset = self.fixed_offset_per_angle
            physical_params = params_all
        elif self.per_angle_scaling:
            contrast = params_all[: self.n_phi]
            offset = params_all[self.n_phi : 2 * self.n_phi]
            physical_params = params_all[2 * self.n_phi :]
        else:
            contrast = params_all[0]
            offset = params_all[1]
            physical_params = params_all[2:]

        # The dt=None warning is emitted in __call__ (outside the JIT trace).
        dt_value = self.dt if self.dt is not None else 0.001

        if self.use_fixed_scaling or self.per_angle_scaling:
            # Vectorize over phi with the corresponding contrast/offset
            def compute_for_angle(
                phi_val: float, contrast_val: float, offset_val: float
            ) -> jnp.ndarray:
                return jnp.squeeze(
                    compute_g2_scaled(
                        params=physical_params,
                        t1=self.t1_unique,
                        t2=self.t2_unique,
                        phi=jnp.asarray(phi_val),
                        q=self.q,
                        L=self.L,
                        contrast=contrast_val,
                        offset=offset_val,
                        dt=dt_value,
                    ),
                    axis=0,
                )

            compute_g2_vmap = jax.vmap(compute_for_angle, in_axes=(0, 0, 0))
            g2_theory_grid = compute_g2_vmap(self.phi_unique, contrast, offset)  # type: ignore[arg-type]
        else:
            # Legacy: single contrast/offset shared by all angles
            def compute_for_angle_scalar(phi_val: float) -> jnp.ndarray:
                # cast(float, ...) only satisfies mypy; at runtime these are JAX
                # tracers, which compute_g2_scaled receives unchanged.
                return jnp.squeeze(
                    compute_g2_scaled(
                        params=physical_params,
                        t1=self.t1_unique,
                        t2=self.t2_unique,
                        phi=jnp.asarray(phi_val),
                        q=self.q,
                        L=self.L,
                        contrast=cast(float, contrast),
                        offset=cast(float, offset),
                        dt=dt_value,
                    ),
                    axis=0,
                )

            compute_g2_vmap_scalar = jax.vmap(compute_for_angle_scalar, in_axes=0)
            g2_theory_grid = compute_g2_vmap_scalar(self.phi_unique)  # type: ignore[arg-type]

        # NOTE: Diagonal correction is intentionally skipped here.
        # Residuals for t1==t2 points are masked out below via `non_diagonal`,
        # so theory grid diagonal values are never used in the optimization.
        # Skipping this call removes ~38% of residual computation time.

        # Flatten theory grid (n_phi, n_t1, n_t2) for flat-index gathering
        g2_theory_flat = g2_theory_grid.flatten()

        n_t1 = len(self.t1_unique)
        n_t2 = len(self.t2_unique)

        # Note: no clip here - stratified LS data comes from the same chunks that
        # build the unique arrays, so all values are guaranteed to be in range.
        # A clip previously caused convergence to wrong local minima
        # (D0=91342 vs 19253); it was added in ae4848c for the streaming
        # optimizer and is not needed here.
        #
        # Cast to int64 BEFORE multiplication to prevent int32 overflow.
        # jnp.searchsorted returns int32; for large datasets (n_phi=100,
        # n_t1=5000, n_t2=5000) the product 99*25_000_000=2.475B exceeds
        # int32 max (2.147B), silently wrapping to a negative index.
        # NOTE(review): int64 requires jax_enable_x64; without it JAX truncates
        # the requested int64 to int32.
        phi_indices = jnp.searchsorted(self.phi_unique, phi_chunk).astype(jnp.int64)
        t1_indices = jnp.searchsorted(self.t1_unique, t1_chunk).astype(jnp.int64)
        t2_indices = jnp.searchsorted(self.t2_unique, t2_chunk).astype(jnp.int64)

        flat_indices = phi_indices * (n_t1 * n_t2) + t1_indices * n_t2 + t2_indices

        g2_theory_chunk = g2_theory_flat[flat_indices]

        sigma_flat = self.sigma_jax.flatten()
        sigma_chunk = sigma_flat[flat_indices]

        # Weighted residuals; zero-sigma points are masked out entirely
        EPS = 1e-10
        valid_sigma = sigma_chunk > EPS
        safe_sigma = jnp.where(valid_sigma, sigma_chunk, 1.0)
        residuals_raw = jnp.where(
            valid_sigma, (g2_obs_chunk - g2_theory_chunk) / safe_sigma, 0.0
        )

        # v2.14.2+: Mask out both padded values AND diagonal values (t1 == t2).
        # Diagonal points are autocorrelation artifacts, not physics.
        # Compare actual time VALUES, not indices: t1_indices and t2_indices
        # reference DIFFERENT arrays (t1_unique vs t2_unique).
        non_diagonal = jnp.abs(t1_chunk - t2_chunk) > 1e-15
        residuals_masked = jnp.where(mask_chunk & non_diagonal, residuals_raw, 0.0)

        return residuals_masked

    def _compute_all_residuals(self, params: jnp.ndarray) -> jnp.ndarray:
        """
        Compute residuals for all chunks using vmap (JIT-compiled via __init__).

        Args:
            params: All parameters (scaling + physical, or physical only in
                constant mode)

        Returns:
            Flattened residuals INCLUDING padding, shape
            (n_chunks * max_chunk_size,), with zeros for padded values.
        """
        # Cache the vmap'd function. params is passed as an explicit unbatched
        # argument (in_axes=None) instead of via closure capture, so the same
        # callable object is reused across calls.
        if not hasattr(self, "_cached_chunk_vmap"):
            self._cached_chunk_vmap = jax.vmap(
                lambda phi, t1, t2, g2, mask, p: self._compute_single_chunk_residuals(
                    phi, t1, t2, g2, mask, p
                ),
                in_axes=(0, 0, 0, 0, 0, None),  # params (p) not batched
            )

        residuals_padded = self._cached_chunk_vmap(
            self.phi_padded,
            self.t1_padded,
            self.t2_padded,
            self.g2_padded,
            self.mask,
            params,
        )  # Shape: (n_chunks, max_chunk_size)

        # Padding is already masked to zero; return the full flat array (no
        # boolean filtering, which would break static shapes under JIT).
        return residuals_padded.flatten()

    def __call__(self, params: np.ndarray | jnp.ndarray) -> jnp.ndarray:
        """
        Compute residuals (interface for NLSQ least_squares).

        Only JAX operations are used on the data path, so the call may be
        traced by NLSQ. Padded values are masked to zero, so they don't
        contribute to the objective (sum of squared residuals).

        Args:
            params: Parameters (numpy or JAX array)

        Returns:
            Residuals as JAX array (n_chunks * max_chunk_size,) with zeros for
            padding. Padding zeros don't affect optimization but increase
            array size.
        """
        # Warn once per instance; the optimizer evaluates this many times.
        if self.dt is None and not getattr(self, "_dt_fallback_warned", False):
            self.logger.warning(
                "StratifiedResidualFunctionJIT: dt is None; "
                "using dt=0.001 s as fallback. Physics factors may be incorrect."
            )
            self._dt_fallback_warned = True
        params_jax = jnp.asarray(params, dtype=jnp.float64)
        residuals_jax = self._residual_fn_jit(params_jax)
        # Keep as JAX array for JIT compatibility
        return cast(jnp.ndarray, residuals_jax)

    def validate_chunk_structure(self) -> bool:
        """
        Validate that all chunks contain all phi angles.

        Angles are compared after rounding to 6 decimals.

        Returns:
            True if validation passes

        Raises:
            ValueError: If any chunk is missing an angle or has an extra one.
        """
        expected_angles = set(
            np.unique(np.round(np.asarray(self.phi_unique), decimals=6))
        )
        n_expected = len(expected_angles)

        self.logger.info(
            f"Validating chunk structure: {self.n_chunks} chunks, "
            f"{n_expected} expected angles per chunk"
        )

        for i in range(self.n_chunks):
            # Only check real data (padding is always at the end of each row)
            n_real = int(np.sum(self.mask[i]))
            phi_real = self.phi_padded[i, :n_real]
            chunk_angles = set(np.unique(np.round(np.asarray(phi_real), decimals=6)))

            if chunk_angles != expected_angles:
                missing = expected_angles - chunk_angles
                extra = chunk_angles - expected_angles
                raise ValueError(
                    f"Chunk {i} has invalid angle distribution:\n"
                    f"  Missing angles: {sorted(missing)}\n"
                    f"  Extra angles: {sorted(extra)}\n"
                    f"  Expected {n_expected} angles, got {len(chunk_angles)}"
                )

        self.logger.info("Chunk structure validation passed: all chunks angle-complete")
        return True

    def get_diagnostics(self) -> dict:
        """Get diagnostic information about the residual function."""
        return {
            "n_chunks": self.n_chunks,
            "max_chunk_size": self.max_chunk_size,
            "n_real_points": self.n_real_points,
            "padding_overhead_pct": self._padding_overhead_pct(),
            "n_phi": self.n_phi,
            "n_t1": len(self.t1_unique),
            "n_t2": len(self.t2_unique),
            "per_angle_scaling": self.per_angle_scaling,
            "jit_compiled": True,
        }

    def log_diagnostics(self) -> None:
        """Log diagnostic information about the residual function."""
        diag = self.get_diagnostics()
        self.logger.info("Stratified Residual Function Diagnostics:")
        self.logger.info(f"  Chunks: {diag['n_chunks']}")
        self.logger.info(f"  Max chunk size: {diag['max_chunk_size']:,}")
        self.logger.info(f"  Real points: {diag['n_real_points']:,}")
        self.logger.info(f"  Padding overhead: {diag['padding_overhead_pct']:.2f}%")
        self.logger.info(f"  Angles (phi): {diag['n_phi']}")
        self.logger.info(f"  Time points (t1): {diag['n_t1']}")
        self.logger.info(f"  Time points (t2): {diag['n_t2']}")
        self.logger.info(f"  Per-angle scaling: {diag['per_angle_scaling']}")
        self.logger.info(f"  JIT compiled: {diag['jit_compiled']}")