# Source code for homodyne.optimization.nlsq.strategies.out_of_core

"""Out-of-Core Global Accumulation strategy for NLSQ optimization.

Extracted from wrapper.py to reduce file size and improve maintainability.

This module provides:
- Out-of-core J^T J / J^T r accumulation for massive datasets
- Levenberg-Marquardt iteration with chunk-wise gradient accumulation
- Parallel chunk computation with shared memory pools
"""

from __future__ import annotations

import logging
import os
import time
from typing import Any

import numpy as np

from homodyne.optimization.nlsq.strategies.chunking import (
    calculate_adaptive_chunk_size,
    get_stratified_chunk_iterator,
)
from homodyne.utils.logging import get_logger

logger = get_logger(__name__)


def _effective_param_count_for_ooc(
    per_angle_scaling: bool,
    n_params: int,
    n_phi: int,
    n_physical: int,
    anti_degeneracy_config: dict | None = None,
) -> int:
    """Return the parameter count used for out-of-core covariance scaling."""
    if not per_angle_scaling:
        return n_params

    ad_config = anti_degeneracy_config or {}
    per_angle_mode = ad_config.get("per_angle_mode", "auto")
    threshold = int(ad_config.get("constant_scaling_threshold", 3))

    if per_angle_mode == "constant":
        return n_physical
    if per_angle_mode == "auto" and n_phi >= threshold and n_params == n_physical + 2:
        return 2 * n_phi + n_physical

    return n_params


def _flatten_xpcs_grid_for_ooc(
    d: Any,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
    """Flatten an XPCS data object into per-point (phi, t1, t2, g2, sigma) arrays.

    ``d.t1`` / ``d.t2`` may be 1D axes or 2D meshgrids. A 2D ``t1`` is reduced
    to its first column and a 2D ``t2`` to its first row, borrowing the logic
    of ``_prepare_data``. The coordinates are expanded with
    ``np.meshgrid(..., indexing="ij")`` and raveled. Point ``k`` of the
    coordinate arrays therefore pairs with ``d.g2.ravel()[k]`` provided ``g2``
    is laid out as ``(n_phi, n_t1, n_t2)``. NOTE(review): that layout is
    assumed here, not checked.

    Returns:
        ``(phi_flat, t1_flat, t2_flat, g2_flat, sigma_flat)``, where
        ``sigma_flat`` is ``None`` when ``d`` has no ``sigma`` attribute or it
        is ``None``.
    """
    phi_arr = np.asarray(d.phi)
    t1_arr = np.asarray(d.t1)
    t2_arr = np.asarray(d.t2)
    g2_arr = np.asarray(d.g2)
    sigma_arr = getattr(d, "sigma", None)

    # Extract the 1D time axes from meshgrids if needed.
    if t1_arr.ndim == 2 and t1_arr.size > 0:
        t1_arr = t1_arr[:, 0]
    if t2_arr.ndim == 2 and t2_arr.size > 0:
        t2_arr = t2_arr[0, :]

    phi_grid, t1_grid, t2_grid = np.meshgrid(phi_arr, t1_arr, t2_arr, indexing="ij")
    sigma_flat = None if sigma_arr is None else np.asarray(sigma_arr).ravel()

    # These ravels usually copy (~200 MB at 25M points). That is acceptable
    # once; the OOM this strategy avoids comes from the second and third
    # copies a materialized stratified dataset would need.
    return (
        phi_grid.ravel(),
        t1_grid.ravel(),
        t2_grid.ravel(),
        g2_arr.ravel(),
        sigma_flat,
    )


def _ooc_covariance(
    jtj: Any,
    chi2: float,
    n_points: int,
    n_params_effective: int,
    log: logging.Logger | logging.LoggerAdapter[logging.Logger],
) -> np.ndarray:
    """Return ``s^2 * (J^T J)^{-1}`` with ``s^2 = chi2 / max(n_points - p_eff, 1)``.

    Falls back to the pseudo-inverse, with a warning, when ``J^T J`` is
    singular.
    """
    s2 = float(chi2) / max(n_points - n_params_effective, 1)
    jtj_np = np.array(jtj)
    try:
        return s2 * np.linalg.inv(jtj_np)
    except np.linalg.LinAlgError:
        log.warning("Singular J^T J in OOC - using pseudo-inverse for covariance")
        return s2 * np.linalg.pinv(jtj_np)


def fit_with_out_of_core_accumulation(
    stratified_data: Any,
    data: Any,
    per_angle_scaling: bool,
    physical_param_names: list[str],
    initial_params: np.ndarray,
    bounds: tuple[np.ndarray, np.ndarray] | None,
    log: logging.Logger | logging.LoggerAdapter[logging.Logger],
    config: Any,
    fast_chi2_mode: bool = False,
    anti_degeneracy_config: dict | None = None,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Fit using Out-of-Core Global Accumulation for massive datasets.

    The dataset is chunked virtually. ``get_stratified_chunk_iterator`` yields
    index arrays into flat (phi, t1, t2, g2, sigma) arrays, so no stratified
    copy of the data is materialized.

    Each outer iteration accumulates the full-data ``J^T J``, ``J^T r`` and
    chi-squared over all chunks. This runs in a worker pool when
    ``should_use_parallel_compute`` allows it, otherwise sequentially. The
    iteration then takes one global Levenberg-Marquardt step, damped by
    adding ``lambda * diag(J^T J)`` to the diagonal. Candidate parameters are
    clipped to ``bounds`` and accepted only if they lower chi-squared.

    Since v2.14.1 the chunk kernels (built by ``create_ooc_kernels``) use the
    full homodyne physics, identical to the stratified least-squares path.
    Anti-Degeneracy Defense System support is planned for a future release.
    ``anti_degeneracy_config`` currently only selects the effective parameter
    count used for the covariance degrees of freedom.

    Convergence requires BOTH conditions:

    * the per-component relative change of the applied (post-clipping) step
      is below ``xtol = 1e-6``;
    * the relative chi-squared decrease is below ``ftol = 1e-6``.

    Args:
        stratified_data: Stratified data object (unused, kept for API compat).
        data: Original XPCS data object with ``.phi``, ``.t1``, ``.t2``,
            ``.g2``, ``.q``, ``.L`` and optionally ``.sigma`` / ``.dt``.
        per_angle_scaling: Whether per-angle scaling is enabled.
        physical_param_names: Names of physical parameters.
        initial_params: Initial parameter guess.
        bounds: Parameter bounds ``(lower, upper)`` or ``None``.
        log: Logger instance.
        config: Configuration object (with a ``.config`` dict) or a dict.
            ``optimization.max_iterations`` (default 50, at least 1 is run)
            and ``dt`` are read from it.
        fast_chi2_mode: If True, candidate chi-squared values are estimated
            from every 10th chunk, scaled up to the full chunk count.
        anti_degeneracy_config: Anti-degeneracy settings. Only used for the
            effective parameter count (see ``_effective_param_count_for_ooc``).

    Returns:
        ``(popt, pcov, info)``.

        * ``pcov = s^2 (J^T J)^{-1}`` with
          ``s^2 = chi2 / max(n_points - p_effective, 1)``. ``J^T J`` comes from
          the latest iteration whose accumulators were finite, so it is
          evaluated at the parameters before the last accepted step.
        * ``info`` holds ``chi_squared`` (for the returned parameters; an
          estimate in ``fast_chi2_mode``), ``iterations``,
          ``convergence_status`` and ``message``.

    Raises:
        RuntimeError: If the accumulators at ``initial_params`` are not finite.
    """
    import jax.numpy as jnp

    from homodyne.optimization.nlsq.parallel_accumulator import (
        OOCComputePool,
        OOCSharedArrays,
        accumulate_chunks_parallel,
        accumulate_chunks_sequential,
        create_ooc_kernels,
        should_use_parallel_accumulation,
        should_use_parallel_compute,
    )

    log.info(
        "Initializing Out-of-Core Global Stratified Optimization (Full Physics)..."
    )

    # ------------------------------------------------------------------
    # 1. Chunking over the ORIGINAL flattened data (no stratified copy)
    # ------------------------------------------------------------------
    phi_flat, t1_flat, t2_flat, g2_flat, sigma_flat = _flatten_xpcs_grid_for_ooc(data)

    n_points = len(phi_flat)
    n_params = len(initial_params)
    # One host-side unique pass serves both the chunk sizing and the JAX
    # phi mapping below.
    phi_unique_host = np.unique(phi_flat)
    n_angles = len(phi_unique_host)
    chunk_size = calculate_adaptive_chunk_size(
        total_points=n_points,
        n_params=n_params,
        n_angles=n_angles,
        safety_factor=5.0,
    )

    # The iterator yields INDEX arrays for stratified chunks. It is iterated
    # many times below, so it must be re-iterable and support len().
    iterator = get_stratified_chunk_iterator(phi_flat, chunk_size)
    n_total_chunks = len(iterator)
    log.info(
        f"Out-of-Core Strategy: {n_total_chunks} chunks of size ~{chunk_size}\n"
        f" Pipeline: Chunk(Indices) -> Load -> JIT(Acc) -> Global Step"
    )

    # Deduplicate on the host first, then let jnp.unique (which returns sorted
    # values) merge anything that collapses under JAX's dtype conversion.
    # The result equals jnp.unique(full_array) but avoids a full-size device copy.
    phi_unique = jnp.unique(jnp.asarray(phi_unique_host))

    # ------------------------------------------------------------------
    # 2. Optimization state and physics constants
    # ------------------------------------------------------------------
    params_curr = jnp.array(initial_params)
    cfg_dict = (
        config.config
        if hasattr(config, "config")
        else (config if isinstance(config, dict) else {})
    )

    q_val = float(data.q)
    L_val = float(data.L)
    # Prefer data.dt, but fall back to the config when it is absent OR None.
    dt_raw = getattr(data, "dt", None)
    if dt_raw is None:
        dt_raw = cfg_dict.get("dt", None)
    if dt_raw is None:
        log.warning(
            "fit_with_out_of_core_accumulation: dt not found in data or config; "
            "using dt=0.001 s as fallback."
        )
        dt_val = 0.001
    else:
        dt_val = float(dt_raw)

    # t1 and t2 must remain separate. Merging them into one union array would
    # build a padded square (n_t x n_t) grid, which is wrong for
    # non-symmetric data with n_t1 != n_t2. All flat-index arithmetic
    # downstream uses (n_t1, n_t2) as the grid shape.
    t1_unique_global = jnp.unique(jnp.asarray(np.unique(t1_flat)))
    t2_unique_global = jnp.unique(jnp.asarray(np.unique(t2_flat)))
    n_phi = len(phi_unique)
    n_t1 = len(t1_unique_global)
    n_t2 = len(t2_unique_global)
    n_physical = len(physical_param_names)

    # Effective parameter count for the DOF in s^2. auto_averaged uses a
    # compressed vector but consumes the expanded DOF. Constant mode keeps
    # scaling fixed and must not be expanded.
    n_params_effective = _effective_param_count_for_ooc(
        per_angle_scaling,
        n_params,
        n_phi,
        n_physical,
        anti_degeneracy_config,
    )

    log.info(
        f"Full Physics Setup: n_phi={n_phi}, n_t1={n_t1}, n_t2={n_t2}, "
        f"q={q_val:.4e}, L={L_val:.4e}, dt={dt_val:.4e}"
    )

    optimization_cfg = cfg_dict.get("optimization") or {}
    # At least one iteration is required. The post-loop bookkeeping
    # (iteration count, chi2, J^T J) needs one accumulation pass.
    max_iter = max(1, int(optimization_cfg.get("max_iterations", 50)))

    # Convergence tolerances (v2.22.0: multi-criteria).
    xtol = 1e-6  # Relative parameter change (per-component max, not norm)
    ftol = 1e-6  # Relative cost function change
    lm_lambda = 0.01  # Initial damping
    rel_change = float("inf")
    cost_change = float("inf")

    # JIT-compiled chunk kernels via factory (single source of truth).
    compute_chunk_accumulators, compute_chunk_chi2 = create_ooc_kernels(
        per_angle_scaling=per_angle_scaling,
        n_phi=n_phi,
        phi_unique=phi_unique,
        t1_unique_global=t1_unique_global,
        t2_unique_global=t2_unique_global,
        n_t1=n_t1,
        n_t2=n_t2,
        q_val=q_val,
        L_val=L_val,
        dt_val=dt_val,
    )

    ooc_pool: OOCComputePool | None = None
    ooc_shared: OOCSharedArrays | None = None
    pool_point_count = 0  # number of points covered by the pool's chunks

    def evaluate_total_chi2(params_eval: Any) -> float:
        """Return chi-squared at ``params_eval`` summed over all chunks.

        In ``fast_chi2_mode`` only every 10th chunk is evaluated and the sum
        is scaled by ``n_total_chunks / evaluated_chunks``.
        """
        stride = 10 if fast_chi2_mode else 1
        if ooc_pool is not None:
            return ooc_pool.compute_chi2(np.asarray(params_eval), stride=stride)

        total_c2 = 0.0
        eval_count = 0
        for chunk_idx, ind_c in enumerate(iterator):
            if chunk_idx % stride != 0:
                continue
            sigma_c = sigma_flat[ind_c] if sigma_flat is not None else 1.0
            total_c2 += compute_chunk_chi2(
                params_eval,
                phi_flat[ind_c],
                t1_flat[ind_c],
                t2_flat[ind_c],
                g2_flat[ind_c],
                sigma_c,
            )
            eval_count += 1
        if eval_count > 0:
            return total_c2 * (n_total_chunks / eval_count)
        return 0.0

    _early_result: tuple[np.ndarray, np.ndarray, dict] | None = None
    count = 0
    chi2_curr = float("nan")  # chi-squared at params_curr
    jtj_curr: Any = None  # latest finite J^T J

    # Everything from pool creation onward runs under try/finally, so shared
    # memory is released on ANY exception, not only the ones handled below.
    try:
        if should_use_parallel_compute(n_total_chunks):
            try:
                # Lay the chunks out contiguously in iterator order.
                chunk_boundaries: list[tuple[int, int]] = []
                all_indices = []
                offset = 0
                for indices_chunk in iterator:
                    all_indices.append(indices_chunk)
                    chunk_boundaries.append((offset, offset + len(indices_chunk)))
                    offset += len(indices_chunk)
                order = np.concatenate(all_indices)
                del all_indices

                # The reordered copies are passed directly and not kept as
                # locals, so they are not held for the whole fit.
                ooc_shared = OOCSharedArrays(
                    phi_flat[order],
                    t1_flat[order],
                    t2_flat[order],
                    g2_flat[order],
                    sigma_flat[order] if sigma_flat is not None else None,
                    chunk_boundaries,
                )
                del order

                physics_config = {
                    "per_angle_scaling": per_angle_scaling,
                    "n_phi": n_phi,
                    "phi_unique": np.asarray(phi_unique),
                    "t1_unique": np.asarray(t1_unique_global),
                    "t2_unique": np.asarray(t2_unique_global),
                    "n_t1": n_t1,
                    "n_t2": n_t2,
                    "q": q_val,
                    "L": L_val,
                    "dt": dt_val,
                }
                n_ooc_workers = max(1, min(4, os.cpu_count() or 1))
                ooc_pool = OOCComputePool(
                    n_workers=n_ooc_workers,
                    shared_arrays=ooc_shared,
                    physics_config=physics_config,
                    chunk_boundaries=chunk_boundaries,
                    threads_per_worker=max(1, (os.cpu_count() or 4) // n_ooc_workers),
                )
                pool_point_count = offset
                log.info(
                    "Parallel OOC compute: %d chunks across %d workers",
                    n_total_chunks,
                    n_ooc_workers,
                )
            except (OSError, RuntimeError, MemoryError) as exc:
                log.warning(
                    "Parallel OOC pool creation failed (%s), using sequential",
                    exc,
                )
                # Detach before cleanup so the outer finally never cleans twice.
                shared_to_release, ooc_shared = ooc_shared, None
                ooc_pool = None
                if shared_to_release is not None:
                    shared_to_release.cleanup()

        if sigma_flat is None:
            log.info("No per-point sigma available - using unit weighting for OOC")

        log.info(f"Starting Out-of-Core Loop (Max iter: {max_iter})...")

        for i in range(max_iter):
            # ---- Accumulate per-chunk J^T J, J^T r, chi2 at params_curr ----
            if ooc_pool is not None:
                chunk_results = ooc_pool.compute_accumulators(np.asarray(params_curr))
                count = pool_point_count
            else:
                chunk_results = []
                count = 0
                for indices_chunk in iterator:
                    sigma_c = (
                        sigma_flat[indices_chunk] if sigma_flat is not None else 1.0
                    )
                    JtJ, Jtr, chi2 = compute_chunk_accumulators(
                        params_curr,
                        phi_flat[indices_chunk],
                        t1_flat[indices_chunk],
                        t2_flat[indices_chunk],
                        g2_flat[indices_chunk],
                        sigma_c,
                    )
                    chunk_results.append((np.asarray(JtJ), np.asarray(Jtr), float(chi2)))
                    count += len(indices_chunk)

            # ---- Reduce chunk results ----
            n_chunks = len(chunk_results)
            if n_chunks == 0:
                total_JtJ = jnp.zeros((n_params, n_params))
                total_Jtr = jnp.zeros(n_params)
                total_chi2 = 0.0
            else:
                if should_use_parallel_accumulation(n_chunks):
                    if i == 0:
                        log.info("Parallel chunk reduction: %d chunks", n_chunks)
                    jtj_np, jtr_np, total_chi2, _ = accumulate_chunks_parallel(
                        chunk_results,
                        n_workers=max(1, min(4, n_chunks // 4)),
                    )
                else:
                    if i == 0:
                        log.debug("Sequential chunk reduction: %d chunks", n_chunks)
                    jtj_np, jtr_np, total_chi2, _ = accumulate_chunks_sequential(
                        chunk_results
                    )
                total_JtJ = jnp.asarray(jtj_np)
                total_Jtr = jnp.asarray(jtr_np)

            # ---- Reject non-finite accumulators (NaN OR Inf, in any of them) ----
            accumulators_finite = (
                bool(jnp.all(jnp.isfinite(total_JtJ)))
                and bool(jnp.all(jnp.isfinite(total_Jtr)))
                and bool(np.isfinite(float(total_chi2)))
            )
            if not accumulators_finite:
                log.warning("Gradient/Hessian contains NaNs/Infs. Checking params.")
                if i == 0:
                    raise RuntimeError("Initial parameters produced invalid gradients.")
                # Keep the previous finite chi2_curr / jtj_curr for the result.
                break
            chi2_curr = float(total_chi2)
            jtj_curr = total_JtJ

            # ---- Levenberg-Marquardt step search ----
            step_accepted = False
            diag_idx = jnp.diag_indices_from(total_JtJ)
            jtj_diag = jnp.diag(total_JtJ)
            for _lm_iter in range(10):  # max dampings per outer iteration
                solver_matrix = total_JtJ.at[diag_idx].add(lm_lambda * jtj_diag)
                try:
                    # lstsq for robustness against singular matrices
                    step, _, _, _ = jnp.linalg.lstsq(
                        solver_matrix, -total_Jtr, rcond=1e-5
                    )
                except (ValueError, RuntimeError, FloatingPointError):
                    step = jnp.full_like(total_Jtr, jnp.nan)  # signal failure

                if not bool(jnp.all(jnp.isfinite(step))):
                    log.warning(
                        f"Bad step (NaN). Increasing damping ({lm_lambda:.1e} -> {lm_lambda * 10:.1e})"
                    )
                    lm_lambda *= 10
                    continue

                params_new = params_curr + step
                if bounds is not None:
                    lower, upper = bounds
                    params_new = jnp.clip(
                        params_new, jnp.asarray(lower), jnp.asarray(upper)
                    )

                try:
                    chi2_new = evaluate_total_chi2(params_new)
                except (ValueError, RuntimeError, FloatingPointError) as e:
                    log.warning(f"Eval failed: {e}")
                    chi2_new = jnp.inf

                # NOTE(review): in fast_chi2_mode chi2_new is a subsampled
                # estimate, while total_chi2 is the exact accumulated value.
                if not chi2_new < total_chi2:
                    log.debug(
                        f"Reject step (chi2 {float(chi2_new):.4e} >= {float(total_chi2):.4e}). Damping up."
                    )
                    lm_lambda *= 10
                    continue

                # Accept the step.
                ratio = (total_chi2 - chi2_new) / total_chi2
                log.info(
                    f"Iter {i + 1}: chi2={float(chi2_new):.4e} (dec {ratio:.1%}), "
                    f"lambda={lm_lambda:.1e}"
                )
                # Measure the step actually applied (after bound clipping). The
                # raw step would keep rel_change large for a parameter pinned at
                # a bound, so xtol could never be satisfied.
                applied_step = params_new - params_curr
                params_curr = params_new
                chi2_curr = float(chi2_new)
                lm_lambda = max(lm_lambda * 0.1, 1e-7)  # trust more, bounded below
                step_accepted = True

                # Multi-criteria convergence (v2.22.0).
                param_scale = jnp.maximum(jnp.abs(params_curr), 1e-10)
                rel_change = float(jnp.max(jnp.abs(applied_step) / param_scale))
                cost_change = float(ratio)
                log.debug(
                    f"  Convergence: xtol={rel_change:.2e} "
                    f"(thresh={xtol:.0e}), "
                    f"ftol={cost_change:.2e} "
                    f"(thresh={ftol:.0e})"
                )
                if rel_change < xtol and cost_change < ftol:
                    log.info(
                        f"Out-of-Core converged: xtol={rel_change:.2e}<{xtol:.0e}, "
                        f"ftol={cost_change:.2e}<{ftol:.0e}"
                    )
                    pcov = _ooc_covariance(
                        total_JtJ, chi2_new, count, n_params_effective, log
                    )
                    _early_result = (
                        np.array(params_curr),
                        pcov,
                        {
                            "chi_squared": float(chi2_new),
                            "iterations": i + 1,
                            "convergence_status": "converged",
                            "message": "Out-of-Core converged (xtol+ftol)",
                        },
                    )
                break  # leave the LM loop; re-accumulate at the new params

            if _early_result is not None:
                break
            if not step_accepted:
                log.warning("Could not find better step. Stopping.")
                break
    finally:
        if ooc_pool is not None:
            ooc_pool.shutdown()
        if ooc_shared is not None:
            ooc_shared.cleanup()

    if _early_result is not None:
        return _early_result

    # Converged runs return early above, so this is normally False. It is
    # kept so the status always reflects the last convergence measurements.
    converged = rel_change < xtol and cost_change < ftol
    info = {
        "chi_squared": chi2_curr,
        "iterations": i + 1,
        "convergence_status": "converged" if converged else "max_iter",
        "message": "Out-of-Core accumulation completed",
    }
    # pcov = s^2 * (J^T J)^{-1}, with s^2 = RSS / (n - p_effective), using
    # chi2 at the returned parameters and the latest finite J^T J.
    pcov = _ooc_covariance(jtj_curr, chi2_curr, count, n_params_effective, log)
    return np.array(params_curr), pcov, info