"""Multi-start NLSQ optimization with Latin Hypercube Sampling.
This module implements multi-start optimization to explore the parameter space
and avoid local minima. All datasets use the FULL strategy (N complete fits).
NOTE: Subsampling is explicitly NOT supported per project requirements.
Numerical precision and reproducibility take priority over computational speed.
Part of homodyne v2.6.0 architecture.
"""
from __future__ import annotations
import multiprocessing
import os
import time
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, TimeoutError, as_completed
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
from scipy.stats import qmc
from homodyne.optimization.nlsq.progress import MultiStartProgressTracker
from homodyne.utils.logging import get_logger
# Timeout for individual worker results (seconds)
# If a worker doesn't return within this time, we fall back to sequential
_WORKER_TIMEOUT = 1800 # 30 minutes per worker
# Maximum data points for parallel execution
# Beyond this, sequential execution is used to avoid memory/serialization issues
# 5 workers × 1M points × ~100 bytes/point = ~500MB total serialization overhead (~100MB per worker)
_MAX_POINTS_FOR_PARALLEL = 500_000 # 500K points
if TYPE_CHECKING:
from numpy.typing import NDArray
from homodyne.optimization.nlsq.results import OptimizationResult
logger = get_logger(__name__)
# =============================================================================
# Configuration Dataclass
# =============================================================================
@dataclass
class MultiStartConfig:
    """Settings controlling a multi-start optimization run.

    Attributes
    ----------
    enable : bool
        Turns multi-start optimization on. Default: False.
    n_starts : int
        How many starting points to generate. Default: 10.
    seed : int
        Seed used when sampling starting points. Default: 42.
    sampling_strategy : str
        Sampling method name, "latin_hypercube" or "random".
    custom_starts : list[list[float]] | None
        Optional user-supplied starting points used in addition to the
        generated ones.
    n_workers : int
        Parallel worker count. 0 = auto (min of n_starts, cpu_count).
    use_screening : bool
        Whether starting points are pre-filtered by their initial cost.
    screen_keep_fraction : float
        Share of starting points retained by screening.
    refine_top_k : int
        How many of the best solutions are refined with a tighter tolerance.
    refinement_ftol : float
        Function tolerance used in the refinement phase.
    degeneracy_threshold : float
        Chi-squared similarity threshold used for degeneracy detection.
    """

    enable: bool = False
    n_starts: int = 10
    seed: int = 42
    sampling_strategy: str = "latin_hypercube"
    custom_starts: list[list[float]] | None = None
    n_workers: int = 0
    use_screening: bool = True
    screen_keep_fraction: float = 0.5
    refine_top_k: int = 3
    refinement_ftol: float = 1e-12
    degeneracy_threshold: float = 0.1

    @classmethod
    def from_nlsq_config(cls, nlsq_config: Any) -> MultiStartConfig:
        """Build a MultiStartConfig from an NLSQConfig-like object.

        Parameters
        ----------
        nlsq_config : NLSQConfig
            Object exposing ``enable_multi_start`` and the
            ``multi_start_<field>`` attributes. ``multi_start_custom_starts``
            is optional and defaults to None when absent.

        Returns
        -------
        MultiStartConfig
            Configuration populated from ``nlsq_config``.
        """
        # The custom starts attribute is the only one allowed to be missing.
        optional_starts = getattr(nlsq_config, "multi_start_custom_starts", None)

        settings: dict[str, Any] = {"enable": nlsq_config.enable_multi_start}
        # Every remaining field maps 1:1 onto a "multi_start_"-prefixed attribute.
        for field_name in (
            "n_starts",
            "seed",
            "sampling_strategy",
            "n_workers",
            "use_screening",
            "screen_keep_fraction",
            "refine_top_k",
            "refinement_ftol",
            "degeneracy_threshold",
        ):
            settings[field_name] = getattr(nlsq_config, f"multi_start_{field_name}")
        settings["custom_starts"] = optional_starts
        return cls(**settings)

    def to_nlsq_global_config(self) -> Any:
        """Translate this configuration into NLSQ's GlobalOptimizationConfig.

        Returns
        -------
        GlobalOptimizationConfig
            NLSQ global optimization configuration.

        Raises
        ------
        ImportError
            If ``nlsq.global_optimization.GlobalOptimizationConfig`` cannot
            be imported.

        Notes
        -----
        Field mapping:

        - sampling_strategy -> sampler ("lhs", "sobol" or "halton"; "random"
          and any unrecognized name fall back to "lhs")
        - use_screening -> elimination_rounds (3 when enabled, 0 otherwise)
        - screen_keep_fraction -> elimination_fraction (1 - keep fraction)
        """
        try:
            from nlsq.global_optimization import GlobalOptimizationConfig
        except ImportError as exc:
            raise ImportError(
                "NLSQ GlobalOptimizationConfig not available. "
                "Please install NLSQ >= 0.4.0: pip install nlsq>=0.4.0"
            ) from exc

        # Strategy name -> NLSQ sampler name; "random" has no NLSQ
        # counterpart here and is mapped to LHS.
        nlsq_sampler_names = {
            "latin_hypercube": "lhs",
            "lhs": "lhs",
            "sobol": "sobol",
            "halton": "halton",
            "random": "lhs",
        }
        chosen_sampler = nlsq_sampler_names.get(self.sampling_strategy, "lhs")

        # Keeping 50% of the starts is the same as eliminating 50% of them.
        fraction_dropped = 1.0 - self.screen_keep_fraction
        if self.use_screening:
            rounds = 3
        else:
            rounds = 0

        return GlobalOptimizationConfig(
            n_starts=self.n_starts,
            sampler=chosen_sampler,
            elimination_rounds=rounds,
            elimination_fraction=fraction_dropped,
        )
# =============================================================================
# Result Dataclasses
# =============================================================================
@dataclass
class SingleStartResult:
    """Outcome of optimizing from one starting point.

    Only ``start_idx``, ``initial_params``, ``final_params`` and
    ``chi_squared`` are required; every other field has a default describing
    a run that has not (yet) succeeded.

    Attributes
    ----------
    start_idx : int
        Position of the starting point in the sequence of starts.
    initial_params : NDArray[np.float64]
        Parameter vector the optimization started from.
    final_params : NDArray[np.float64]
        Parameter vector after optimization.
    chi_squared : float
        Chi-squared at the final parameters.
    reduced_chi_squared : float
        Chi-squared divided by the degrees of freedom. Default: inf.
    success : bool
        True when the optimizer reported convergence. Default: False.
    status : int
        Status code reported by the optimizer.
    message : str
        Status message reported by the optimizer.
    n_iterations : int
        Number of optimizer iterations.
    n_fev : int
        Number of function evaluations.
    wall_time : float
        Elapsed time of the fit in seconds.
    hessian : NDArray[np.float64] | None
        Hessian at the solution (for CMC initialization), when available.
    covariance : NDArray[np.float64] | None
        Parameter covariance matrix, when available.
    jacobian : NDArray[np.float64] | None
        Jacobian at the solution, when available.
    """

    # Required identification and solution fields.
    start_idx: int
    initial_params: NDArray[np.float64]
    final_params: NDArray[np.float64]
    chi_squared: float

    # Fit quality and optimizer bookkeeping.
    reduced_chi_squared: float = float("inf")
    success: bool = False
    status: int = 0
    message: str = ""
    n_iterations: int = 0
    n_fev: int = 0
    wall_time: float = 0.0

    # Optional second-order information.
    hessian: NDArray[np.float64] | None = None
    covariance: NDArray[np.float64] | None = None
    jacobian: NDArray[np.float64] | None = None
@dataclass
class MultiStartResult:
    """Combined outcome of a multi-start optimization.

    Attributes
    ----------
    best : SingleStartResult
        The result selected as best (lowest chi-squared).
    all_results : list[SingleStartResult]
        Every per-start result.
    config : MultiStartConfig
        Configuration the run was executed with.
    strategy_used : str
        Name of the execution strategy that produced the results.
    n_successful : int
        How many starts converged.
    n_unique_basins : int
        How many distinct local minima were identified.
    degeneracy_detected : bool
        True when parameter degeneracy was detected.
    total_wall_time : float
        Total elapsed time in seconds.
    screening_costs : NDArray[np.float64] | None
        Initial costs computed during screening, if screening ran.
    basin_labels : NDArray[np.int64] | None
        Basin labels produced by degeneracy detection, if available.
    """

    best: SingleStartResult
    all_results: list[SingleStartResult]
    config: MultiStartConfig
    strategy_used: str
    n_successful: int = 0
    n_unique_basins: int = 1
    degeneracy_detected: bool = False
    total_wall_time: float = 0.0
    screening_costs: NDArray[np.float64] | None = None
    basin_labels: NDArray[np.int64] | None = None

    def to_optimization_result(self) -> OptimizationResult:
        """Wrap the best solution in an OptimizationResult for CLI use.

        Returns
        -------
        OptimizationResult
            Result built from ``self.best``; multi-start metadata is passed
            as ``nlsq_diagnostics``. When the best result has no covariance,
            zero uncertainties and an identity covariance are substituted.
        """
        from homodyne.optimization.nlsq.results import OptimizationResult

        winner = self.best
        n_params = len(winner.final_params)

        status_text = "converged" if winner.success else "failed"

        # First threshold the reduced chi-squared falls below decides the
        # flag; anything else (including NaN/inf) is "poor".
        flag = "poor"
        for upper_limit, label in ((2.0, "good"), (10.0, "marginal")):
            if winner.reduced_chi_squared < upper_limit:
                flag = label
                break

        if winner.covariance is None:
            # No covariance available: use neutral placeholders.
            sigma = np.zeros(n_params)
            cov_matrix = np.eye(n_params)
        else:
            sigma = np.sqrt(np.diag(winner.covariance))
            cov_matrix = winner.covariance

        diagnostics = {
            "strategy_used": self.strategy_used,
            "n_starts": len(self.all_results),
            "n_successful": self.n_successful,
            "n_unique_basins": self.n_unique_basins,
            "degeneracy_detected": self.degeneracy_detected,
            "total_wall_time": self.total_wall_time,
            "best_start_idx": winner.start_idx,
        }

        return OptimizationResult(
            parameters=winner.final_params,
            uncertainties=sigma,
            covariance=cov_matrix,
            chi_squared=winner.chi_squared,
            reduced_chi_squared=winner.reduced_chi_squared,
            convergence_status=status_text,
            iterations=winner.n_iterations,
            execution_time=self.total_wall_time,
            device_info={"type": "cpu", "multistart": True},
            recovery_actions=[],
            quality_flag=flag,
            nlsq_diagnostics=diagnostics,
        )
# =============================================================================
# Helper Functions
# =============================================================================
def _get_phi_from_data(data: dict[str, Any]) -> NDArray[np.float64] | None:
"""Extract phi array from data dictionary, handling numpy array truthiness.
Parameters
----------
data : dict
Data dictionary that may contain 'phi' or 'phi_angles_list'.
Returns
-------
NDArray | None
Phi array if found, None otherwise.
"""
phi = data.get("phi")
if phi is not None:
return np.asarray(phi)
phi = data.get("phi_angles_list")
if phi is not None:
return np.asarray(phi)
return None
def _get_dataset_size(data: dict[str, Any]) -> int:
"""Calculate total number of data points from data dictionary.
This function handles both test fixtures (where 'phi' is flattened with
repeated angles) and actual XPCS data (where 'phi_angles_list' contains
only unique angles and 'g2'/'c2_exp' contains the actual 3D data).
Parameters
----------
data : dict
Data dictionary that may contain:
- 'g2' or 'c2_exp': Experimental data array (n_phi, n_t1, n_t2)
- 'phi' or 'phi_angles_list': Phi angles
Returns
-------
int
Total number of data points.
Raises
------
ValueError
If no valid data array is found.
"""
# First, try to get size from actual data arrays (most reliable)
for key in ("g2", "c2_exp"):
arr = data.get(key)
if arr is not None:
arr = np.asarray(arr)
return int(arr.size)
# Fallback: calculate from phi array
# This handles test fixtures where phi is already flattened
phi = _get_phi_from_data(data)
if phi is not None:
return len(np.asarray(phi).ravel())
raise ValueError("Cannot determine dataset size: no 'g2', 'c2_exp', or 'phi' found")
# =============================================================================
# Core Functions: Bounds Validation
# =============================================================================
def check_zero_volume_bounds(bounds: NDArray[np.float64]) -> bool:
    """Check if parameter bounds have zero volume (all lower == upper).

    Parameters
    ----------
    bounds : NDArray[np.float64]
        Parameter bounds as (n_params, 2) array (or array-like) with
        [lower, upper] for each parameter.

    Returns
    -------
    bool
        True if every parameter's bound width is below 1e-15 in absolute
        value, i.e. all parameters are fixed. Always a built-in ``bool``.
    """
    bounds_arr = np.asarray(bounds, dtype=np.float64)
    widths = bounds_arr[:, 1] - bounds_arr[:, 0]
    # np.all returns np.bool_; convert so the result matches the annotation
    # and identity checks such as `is True` behave as expected.
    return bool(np.all(np.abs(widths) < 1e-15))
def validate_n_starts_for_lhs(
    n_starts: int,
    n_params: int,
    warn: bool = True,
) -> int:
    """Sanity-check the number of starts against the problem dimension.

    This function never alters ``n_starts``; it only logs warnings. Fewer
    starts than parameters gives poor Latin Hypercube coverage, while more
    than 1000 starts per parameter is treated as likely redundant.

    Parameters
    ----------
    n_starts : int
        Requested number of starting points.
    n_params : int
        Number of parameters (dimensions).
    warn : bool
        When False, no warnings are logged.

    Returns
    -------
    int
        ``n_starts``, unchanged.
    """
    if not warn:
        return n_starts

    # Lower guideline: at least one start per dimension.
    if n_starts < n_params:
        logger.warning(
            f"n_starts ({n_starts}) < n_params ({n_params}): "
            f"LHS coverage may be inadequate. Consider n_starts >= {n_params}."
        )

    # Upper guideline (heuristic): beyond 1000 starts per dimension.
    max_meaningful = n_params * 1000
    if n_starts > max_meaningful:
        logger.warning(
            f"n_starts ({n_starts}) is very large for {n_params} parameters. "
            f"This may produce redundant samples with diminishing returns. "
            f"Consider n_starts <= {max_meaningful}."
        )

    return n_starts
# =============================================================================
# Core Functions: LHS Generation
# =============================================================================
def generate_lhs_starts(
    bounds: NDArray[np.float64],
    n_starts: int,
    seed: int = 42,
) -> NDArray[np.float64]:
    """Generate starting points via Latin Hypercube Sampling.

    Parameters
    ----------
    bounds : NDArray[np.float64]
        Parameter bounds as (n_params, 2) array with [lower, upper] for each.
        A parameter may be fixed (lower == upper); every sample then takes
        that fixed value for the parameter.
    n_starts : int
        Number of starting points to generate.
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    NDArray[np.float64]
        Starting points as (n_starts, n_params) array.

    Raises
    ------
    ValueError
        If any lower bound exceeds its upper bound (or a bound is NaN).
    """
    n_params = bounds.shape[0]
    lower = bounds[:, 0]
    upper = bounds[:, 1]

    # Reject inverted/NaN bounds explicitly (the comparison is False for NaN).
    if not np.all(lower <= upper):
        raise ValueError(
            "Bounds are not consistent: each lower bound must be <= its upper bound"
        )

    # Use scipy's Latin Hypercube Sampling
    sampler = qmc.LatinHypercube(d=n_params, seed=seed, optimization="random-cd")
    unit_samples = sampler.random(n=n_starts)  # Samples in [0, 1]^d

    # Scale to parameter bounds. This is the same affine map qmc.scale
    # applies, done manually because qmc.scale requires lower < upper
    # strictly and would raise for a partially fixed parameter set.
    scaled_samples = unit_samples * (upper - lower) + lower

    logger.debug(f"Generated {n_starts} LHS starting points for {n_params} parameters")
    return scaled_samples
def generate_random_starts(
    bounds: NDArray[np.float64],
    n_starts: int,
    seed: int = 42,
) -> NDArray[np.float64]:
    """Draw starting points uniformly at random inside the bounds.

    Parameters
    ----------
    bounds : NDArray[np.float64]
        Parameter bounds as (n_params, 2) array; column 0 holds the lower
        and column 1 the upper bounds.
    n_starts : int
        Number of starting points to draw.
    seed : int
        Seed for ``np.random.default_rng``, making the draw reproducible.

    Returns
    -------
    NDArray[np.float64]
        Starting points as (n_starts, n_params) array.
    """
    generator = np.random.default_rng(seed)
    n_params = bounds.shape[0]
    lower, upper = bounds[:, 0], bounds[:, 1]

    # lower/upper broadcast across rows, so each column is drawn from its
    # own parameter's interval.
    draws = generator.uniform(lower, upper, size=(n_starts, n_params))

    logger.debug(
        f"Generated {n_starts} random starting points for {n_params} parameters"
    )
    return draws
def include_custom_starts(
    generated_starts: NDArray[np.float64],
    custom_starts: list[list[float]] | NDArray[np.float64] | None,
    bounds: NDArray[np.float64],
) -> NDArray[np.float64]:
    """Merge user-provided starting points with the generated ones.

    Custom points that pass validation are placed in front of the generated
    starts. Validation drops all custom points when their dimension does not
    match the bounds, and drops individual points lying outside the bounds.

    Parameters
    ----------
    generated_starts : NDArray[np.float64]
        Starting points generated by LHS or random sampling.
    custom_starts : list[list[float]] | NDArray[np.float64] | None
        User-provided custom starting points; a single point may be given as
        a flat sequence.
    bounds : NDArray[np.float64]
        Parameter bounds as (n_params, 2) array, used for validation.

    Returns
    -------
    NDArray[np.float64]
        Valid custom starts stacked above ``generated_starts``, or
        ``generated_starts`` itself when no custom start is usable.
    """
    if custom_starts is None or len(custom_starts) == 0:
        return generated_starts

    user_points = np.asarray(custom_starts, dtype=np.float64)
    if user_points.ndim == 1:
        # A single flat point becomes a one-row matrix.
        user_points = user_points[np.newaxis, :]

    # Dimension check against the number of bounded parameters.
    n_params = bounds.shape[0]
    if user_points.shape[1] != n_params:
        logger.warning(
            f"Custom starts have wrong dimension: {user_points.shape[1]} != {n_params}. "
            f"Ignoring custom starts."
        )
        return generated_starts

    # Keep only the rows that lie inside [lower, upper] for every parameter.
    lower, upper = bounds[:, 0], bounds[:, 1]
    inside = np.all((user_points >= lower) & (user_points <= upper), axis=1)
    n_invalid = len(user_points) - int(np.count_nonzero(inside))
    if n_invalid > 0:
        logger.warning(
            f"{n_invalid} custom starting point(s) are outside bounds and will be skipped."
        )
        user_points = user_points[inside]

    if len(user_points) == 0:
        return generated_starts

    logger.info(f"Including {len(user_points)} custom starting point(s)")
    # Custom rows first, generated rows after.
    return np.vstack([user_points, generated_starts])
# =============================================================================
# Core Functions: Screening
# =============================================================================
def screen_starts(
    cost_func: Callable[[NDArray[np.float64]], float],
    starts: NDArray[np.float64],
    keep_fraction: float = 0.5,
    min_keep: int = 3,
    n_workers: int = 0,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Pre-filter starting points by initial cost.

    Parameters
    ----------
    cost_func : Callable
        Function that computes cost (chi-squared) for a parameter vector.
    starts : NDArray[np.float64]
        Starting points as (n_starts, n_params) array.
    keep_fraction : float
        Fraction of starting points to keep (0, 1].
    min_keep : int
        Minimum number of starting points to keep (capped at n_starts).
    n_workers : int
        Number of parallel workers for cost evaluation. 0 = auto (cpu_count - 1).

    Returns
    -------
    tuple[NDArray[np.float64], NDArray[np.float64]]
        ``(filtered_starts, costs)``. ``filtered_starts`` holds the kept
        starting points ordered by ascending cost. ``costs`` holds the
        initial cost of EVERY input start, in input order (not only the
        kept ones). For empty ``starts`` both are returned empty.
    """
    n_starts = len(starts)
    if n_starts == 0:
        # Nothing to evaluate; also avoids indexing an empty cost array in
        # the summary log below.
        return starts, np.empty(0, dtype=np.float64)

    n_keep = max(min_keep, int(n_starts * keep_fraction))
    n_keep = min(n_keep, n_starts)  # Don't keep more than we have

    # Evaluate initial cost for each starting point
    if n_workers == 0:
        n_workers = max(1, (os.cpu_count() or 1) - 1)

    # Only parallelize if we have enough starts to benefit
    if n_starts >= 4 and n_workers > 1:
        try:
            from concurrent.futures import ThreadPoolExecutor

            # Threads rather than processes: cost_func is expected to spend
            # its time in JAX/NumPy code, and needs no pickling this way.
            with ThreadPoolExecutor(max_workers=n_workers) as executor:
                costs = np.array(list(executor.map(cost_func, starts)))
        except (RuntimeError, OSError, ValueError) as e:
            logger.warning(
                f"Parallel screening failed, falling back to sequential: {e}"
            )
            costs = np.array([cost_func(start) for start in starts])
    else:
        costs = np.array([cost_func(start) for start in starts])

    # Sort by cost and keep top n_keep
    sorted_indices = np.argsort(costs)
    keep_indices = sorted_indices[:n_keep]
    filtered_starts = starts[keep_indices]
    filtered_costs = costs[keep_indices]

    logger.info(
        f"Screening: kept {n_keep}/{n_starts} starts "
        f"(best cost: {filtered_costs[0]:.4g}, worst kept: {filtered_costs[-1]:.4g})"
    )
    return filtered_starts, costs
# =============================================================================
# Core Functions: Degeneracy Detection
# =============================================================================
def detect_degeneracy(
    results: list[SingleStartResult],
    chi_sq_threshold: float = 0.1,
    param_threshold: float = 0.2,
) -> tuple[bool, int, NDArray[np.int64] | None]:
    """Group converged fits into basins and flag parameter degeneracy.

    Only successful results are considered, processed in ascending
    chi-squared order. A result whose chi-squared differs from the best by
    more than ``chi_sq_threshold`` (relative) is labelled -1 and never forms
    a basin. Each remaining result joins the first basin whose founding
    member is within ``param_threshold`` (relative Euclidean distance of the
    parameter vectors); otherwise it founds a new basin.

    Parameters
    ----------
    results : list[SingleStartResult]
        List of optimization results.
    chi_sq_threshold : float
        Maximum relative chi-squared difference to consider similar.
    param_threshold : float
        Maximum relative parameter distance to consider same basin.

    Returns
    -------
    tuple[bool, int, NDArray[np.int64] | None]
        ``(degeneracy_detected, n_unique_basins, basin_labels)``. Degeneracy
        means more than one basin was found. ``basin_labels`` has one entry
        per successful result in ascending chi-squared order. With fewer
        than two successful results the return is ``(False, 1, None)``.
    """
    converged = sorted(
        (res for res in results if res.success),
        key=lambda res: res.chi_squared,
    )
    if len(converged) < 2:
        return False, 1, None

    reference_chi_sq = converged[0].chi_squared

    # Only the founding member of a basin is ever compared against, so just
    # the founders' parameter vectors are tracked.
    basin_centers: list[Any] = []
    assignments: list[int] = []

    for res in converged:
        relative_gap = abs(res.chi_squared - reference_chi_sq) / (
            reference_chi_sq + 1e-10
        )
        if relative_gap > chi_sq_threshold:
            # Too far from the best fit to count as an equivalent minimum.
            assignments.append(-1)
            continue

        for basin_idx, center in enumerate(basin_centers):
            relative_dist = np.linalg.norm(res.final_params - center) / (
                np.linalg.norm(center) + 1e-10
            )
            if relative_dist < param_threshold:
                assignments.append(basin_idx)
                break
        else:
            # No existing basin is close enough: this result founds one.
            basin_centers.append(res.final_params)
            assignments.append(len(basin_centers) - 1)

    n_unique_basins = len(basin_centers)
    degeneracy_detected = n_unique_basins > 1
    labels = np.array(assignments, dtype=np.int64)

    if degeneracy_detected:
        logger.warning(
            f"Parameter degeneracy detected: {n_unique_basins} distinct basins "
            f"with similar chi-squared values"
        )
    return degeneracy_detected, n_unique_basins, labels
# =============================================================================
# Core Functions: Parallel Execution
# =============================================================================
def get_n_workers(config: MultiStartConfig, n_starts: int) -> int:
    """Determine number of parallel workers.

    Parameters
    ----------
    config : MultiStartConfig
        Multi-start configuration. A positive ``config.n_workers`` is used
        as requested; otherwise the CPU count is used (4 if it cannot be
        determined).
    n_starts : int
        Number of starting points.

    Returns
    -------
    int
        Number of workers to use: never more than ``n_starts`` and never
        less than 1.
    """
    if config.n_workers > 0:
        n_workers = config.n_workers
    else:
        n_workers = os.cpu_count() or 4

    # Don't use more workers than starts, but never report zero workers:
    # a worker count of 0 is not a valid pool size.
    n_workers = max(1, min(n_workers, n_starts))

    logger.debug(f"Using {n_workers} parallel workers for {n_starts} starts")
    return n_workers
def _run_sequential(
optimize_func: Callable[[int, NDArray[np.float64]], SingleStartResult],
starts: NDArray[np.float64],
) -> list[SingleStartResult]:
"""Run optimizations sequentially."""
results: list[SingleStartResult] = []
for idx, start in enumerate(starts):
try:
result = optimize_func(idx, start)
results.append(result)
except (ValueError, RuntimeError, OSError) as e:
logger.warning(f"Start {idx} failed: {e}")
results.append(
SingleStartResult(
start_idx=idx,
initial_params=start,
final_params=start,
chi_squared=np.inf,
success=False,
message=str(e),
)
)
return results
def _is_pickle_error(error_msg: str) -> bool:
"""Check if an error message indicates a pickle/serialization issue."""
pickle_indicators = [
"pickle",
"local object",
"can't get local",
"cannot serialize",
"attributeerror",
]
error_lower = error_msg.lower()
return any(indicator in error_lower for indicator in pickle_indicators)
# =============================================================================
# Main Orchestration Function
# =============================================================================
def run_multistart_nlsq(
    data: dict[str, Any],
    bounds: NDArray[np.float64],
    config: MultiStartConfig,
    single_fit_func: Callable[[dict[str, Any], NDArray[np.float64]], SingleStartResult],
    cost_func: Callable[[NDArray[np.float64]], float] | None = None,
    custom_starts: list[list[float]] | NDArray[np.float64] | None = None,
) -> MultiStartResult:
    """Run multi-start NLSQ optimization with FULL strategy.

    NOTE: Only FULL strategy is supported. Subsampling is explicitly NOT used
    per project requirements - numerical precision takes priority over speed.

    The run proceeds in four phases: generate starting points, optionally
    screen them by initial cost, run one complete fit per remaining start,
    then pick the best successful fit and check for degeneracy.

    Parameters
    ----------
    data : dict
        XPCS data dictionary.
    bounds : NDArray[np.float64]
        Parameter bounds as (n_params, 2) array.
    config : MultiStartConfig
        Multi-start configuration. ``sampling_strategy`` values
        "latin_hypercube" and "lhs" select Latin Hypercube Sampling; any
        other value selects random uniform sampling.
    single_fit_func : Callable
        Function that runs a single NLSQ fit.
        Signature: (data, initial_params) -> SingleStartResult
    cost_func : Callable, optional
        Function that computes cost for screening.
        Signature: (params) -> float
        Screening only runs when this is given and ``config.use_screening``
        is True.
    custom_starts : list[list[float]] | NDArray, optional
        User-provided custom starting points (overrides config.custom_starts).
        Accepted custom starts are never removed by screening.

    Returns
    -------
    MultiStartResult
        Aggregated results from all starting points. If every fit fails,
        ``best`` is the first (failed) result. If all bounds are fixed, a
        single fit at the bounds center is returned with
        ``strategy_used="single_start_fallback"``.
    """
    start_time = time.perf_counter()

    # Log configuration
    logger.info("=" * 60)
    logger.info("MULTI-START NLSQ OPTIMIZATION")
    logger.info("=" * 60)
    logger.info(
        f"Configuration: n_starts={config.n_starts}, "
        f"sampling={config.sampling_strategy}, seed={config.seed}"
    )
    logger.info(
        f"Options: screening={'ON' if config.use_screening else 'OFF'}, "
        f"keep_fraction={config.screen_keep_fraction:.0%}, "
        f"n_workers={config.n_workers if config.n_workers > 0 else 'auto'}"
    )

    # Check for zero-volume bounds (all parameters fixed)
    if check_zero_volume_bounds(bounds):
        logger.warning(
            "Parameter bounds have zero volume (all lower == upper). "
            "Falling back to single-start at bounds center."
        )
        center = (bounds[:, 0] + bounds[:, 1]) / 2
        result = single_fit_func(data, center)
        result.start_idx = 0
        result.initial_params = center
        return MultiStartResult(
            best=result,
            all_results=[result],
            config=config,
            strategy_used="single_start_fallback",
            n_successful=1 if result.success else 0,
            n_unique_basins=1,
            degeneracy_detected=False,
            total_wall_time=time.perf_counter() - start_time,
        )

    # Determine dataset size from data arrays (not just phi array length)
    n_points = _get_dataset_size(data)
    n_params = bounds.shape[0]

    # Log dataset and parameter info
    logger.info(f"Dataset: {n_points:,} total data points")
    logger.info(f"Parameters: {n_params} free parameters")
    logger.debug(f"Parameter bounds:\n{bounds}")

    # Always use FULL strategy - no subsampling
    logger.info("Strategy: FULL (all starting points run complete optimization)")

    # Validate n_starts for LHS (only logs warnings; value is not changed)
    validate_n_starts_for_lhs(config.n_starts, n_params)

    # ------------------------------------------------------------------
    # PHASE 1: starting points
    # ------------------------------------------------------------------
    logger.info("-" * 40)
    logger.info("PHASE 1: Generating starting points")
    logger.info("-" * 40)

    # "lhs" is accepted as an alias, matching
    # MultiStartConfig.to_nlsq_global_config.
    if config.sampling_strategy in ("latin_hypercube", "lhs"):
        logger.info(
            f"Using Latin Hypercube Sampling (n={config.n_starts}, seed={config.seed})"
        )
        starts = generate_lhs_starts(bounds, config.n_starts, config.seed)
    else:
        if config.sampling_strategy != "random":
            # Make the fallback visible instead of silently sampling randomly.
            logger.warning(
                f"Sampling strategy '{config.sampling_strategy}' is not implemented "
                f"here; falling back to random uniform sampling."
            )
        logger.info(
            f"Using random uniform sampling (n={config.n_starts}, seed={config.seed})"
        )
        starts = generate_random_starts(bounds, config.n_starts, config.seed)
    logger.info(f"Generated {len(starts)} starting points")

    # Include custom starting points (from argument or config)
    custom = custom_starts if custom_starts is not None else config.custom_starts
    n_before = len(starts)
    starts = include_custom_starts(starts, custom, bounds)
    n_added = len(starts) - n_before
    if n_added > 0:
        logger.info(f"Added {n_added} custom starting point(s), total: {len(starts)}")

    # ------------------------------------------------------------------
    # PHASE 2: screening (optional)
    # ------------------------------------------------------------------
    screening_costs = None
    if config.use_screening and cost_func is not None:
        logger.info("-" * 40)
        logger.info("PHASE 2: Screening starting points")
        logger.info("-" * 40)
        n_before_screen = len(starts)

        if n_added == 0:
            starts, screening_costs = screen_starts(
                cost_func, starts, config.screen_keep_fraction
            )
        else:
            # include_custom_starts prepends the accepted custom points, so
            # the first n_added rows are user-supplied. They must always be
            # optimized, so only the generated rows are screened.
            protected = starts[:n_added]
            candidates = starts[n_added:]
            protected_costs = np.array(
                [cost_func(point) for point in protected], dtype=np.float64
            )
            if len(candidates) > 0:
                kept, candidate_costs = screen_starts(
                    cost_func, candidates, config.screen_keep_fraction
                )
            else:
                kept = candidates
                candidate_costs = np.empty(0, dtype=np.float64)
            starts = np.vstack([protected, kept])
            # Costs stay aligned with the pre-screening order:
            # custom starts first, then every generated start.
            screening_costs = np.concatenate(
                [protected_costs, np.asarray(candidate_costs, dtype=np.float64)]
            )

        n_filtered = n_before_screen - len(starts)
        logger.info(f"Screening filtered {n_filtered} starts, keeping {len(starts)}")
    else:
        logger.debug("Screening disabled, proceeding with all starting points")

    # Get worker count
    n_workers = get_n_workers(config, len(starts))
    cpu_count = os.cpu_count() or 1

    # ------------------------------------------------------------------
    # PHASE 3: N complete fits
    # ------------------------------------------------------------------
    logger.info("-" * 40)
    logger.info("PHASE 3: Running optimizations")
    logger.info("-" * 40)
    logger.info(
        f"Starting {len(starts)} optimizations with "
        f"{n_workers} worker(s) (CPUs available: {cpu_count})"
    )
    if n_points > _MAX_POINTS_FOR_PARALLEL:
        logger.info(
            f"Note: Large dataset ({n_points:,} > {_MAX_POINTS_FOR_PARALLEL:,}), "
            f"forcing sequential execution to avoid serialization overhead"
        )
    results = _run_full_strategy(
        data,
        starts,
        single_fit_func,
        n_workers,
        enable_progress_bar=True,  # Always show progress for multi-start
        verbose=1,
    )

    # ------------------------------------------------------------------
    # PHASE 4: pick the best fit and analyze
    # ------------------------------------------------------------------
    logger.info("-" * 40)
    logger.info("PHASE 4: Analyzing results")
    logger.info("-" * 40)
    successful = [r for r in results if r.success]
    failed = [r for r in results if not r.success]

    if not successful:
        logger.error("All multi-start optimizations failed!")
        for r in failed[:5]:  # Show first 5 failures
            logger.error(f" Start {r.start_idx}: {r.message}")
        if len(failed) > 5:
            logger.error(f" ... and {len(failed) - 5} more failures")
        # Still hand back a result object so callers can inspect the failure.
        best = (
            results[0]
            if results
            else SingleStartResult(
                start_idx=0,
                initial_params=starts[0],
                final_params=starts[0],
                chi_squared=np.inf,
                success=False,
                message="All optimizations failed",
            )
        )
    else:
        best = min(successful, key=lambda r: r.chi_squared)
        logger.info(f"Successful optimizations: {len(successful)}/{len(results)}")
        if failed:
            logger.warning(f"Failed optimizations: {len(failed)}")
            for r in failed:
                logger.debug(f" Start {r.start_idx} failed: {r.message}")

    # Degeneracy detection
    degeneracy_detected, n_unique_basins, basin_labels = detect_degeneracy(
        results, config.degeneracy_threshold
    )
    total_time = time.perf_counter() - start_time

    # Final summary
    logger.info("=" * 60)
    logger.info("MULTI-START OPTIMIZATION COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Best result: chi2={best.chi_squared:.6e} (start {best.start_idx})")
    logger.info(f"Best reduced chi2: {best.reduced_chi_squared:.6f}")
    logger.info(f"Unique basins found: {n_unique_basins}")
    if degeneracy_detected:
        logger.warning(
            f"DEGENERACY DETECTED: {n_unique_basins} distinct minima with similar chi2"
        )
    logger.info(f"Total wall time: {total_time:.1f}s")
    logger.info("=" * 60)

    return MultiStartResult(
        best=best,
        all_results=results,
        config=config,
        strategy_used="full",
        n_successful=len(successful),
        n_unique_basins=n_unique_basins,
        degeneracy_detected=degeneracy_detected,
        total_wall_time=total_time,
        screening_costs=screening_costs,
        basin_labels=basin_labels,
    )
# =============================================================================
# Strategy Implementation
# =============================================================================
class _OptimizeWorker:
"""Picklable worker class for parallel optimization.
This class wraps the single_fit_func and data, making them picklable
for use with ProcessPoolExecutor. Unlike nested closures, class instances
with __call__ can be pickled as long as their attributes are picklable.
"""
def __init__(
self,
data: dict[str, Any],
single_fit_func: Callable[
[dict[str, Any], NDArray[np.float64]], SingleStartResult
],
) -> None:
self.data = data
self.single_fit_func = single_fit_func
def __call__(self, idx: int, start: NDArray[np.float64]) -> SingleStartResult:
"""Run optimization for a single starting point."""
start_time = time.perf_counter()
result = self.single_fit_func(self.data, start)
result.start_idx = idx
result.initial_params = start
result.wall_time = time.perf_counter() - start_time
return result
def _run_full_strategy(
    data: dict[str, Any],
    starts: NDArray[np.float64],
    single_fit_func: Callable[[dict[str, Any], NDArray[np.float64]], SingleStartResult],
    n_workers: int,
    enable_progress_bar: bool = True,
    verbose: int = 1,
) -> list[SingleStartResult]:
    """Run one complete fit per starting point (the FULL strategy).

    This is the only supported strategy; the data is never subsampled.
    When the dataset size reported by ``_get_dataset_size`` exceeds
    ``_MAX_POINTS_FOR_PARALLEL`` and more than one worker was requested, the
    worker count is forced to 1 so the data is not shipped to worker
    processes.

    Parameters
    ----------
    data : dict
        XPCS data dictionary.
    starts : NDArray
        Starting points as (n_starts, n_params) array.
    single_fit_func : Callable
        Function that runs a single NLSQ fit from one starting point.
    n_workers : int
        Requested number of parallel workers; exactly 1 selects the
        in-process sequential loop.
    enable_progress_bar : bool
        Whether to show a progress bar.
    verbose : int
        Verbosity level forwarded to the progress tracker.

    Returns
    -------
    list[SingleStartResult]
        One result per starting point. In sequential mode a start whose fit
        raises ``ValueError``, ``RuntimeError`` or ``OSError`` is recorded as
        a failed result with ``chi_squared=np.inf``.
    """
    total_starts = len(starts)

    # Pick the execution mode. The caller (run_multistart_nlsq) already logs
    # the large-dataset downgrade, so only a debug line is emitted here.
    point_count = _get_dataset_size(data)
    if point_count > _MAX_POINTS_FOR_PARALLEL and n_workers > 1:
        n_workers = 1
        mode_label = "sequential (large dataset)"
    elif n_workers == 1:
        mode_label = "sequential"
    else:
        mode_label = "parallel"
    logger.debug(f"Execution mode: {mode_label}, workers: {n_workers}")

    # A class instance rather than a closure, so it can be pickled.
    fit_one = _OptimizeWorker(data, single_fit_func)

    tracker_options = {
        "n_starts": total_starts,
        "enable_progress_bar": enable_progress_bar,
        "verbose": verbose,
    }

    # Parallel path: the helper updates the tracker as results arrive.
    if n_workers != 1:
        logger.info(
            f"Running {total_starts} optimizations in parallel with {n_workers} workers"
        )
        with MultiStartProgressTracker(**tracker_options) as tracker:
            collected = _run_parallel_with_progress(
                fit_one, starts, n_workers, tracker
            )
        return collected

    # Sequential path: fit each start in this process, in order.
    logger.info(f"Running {total_starts} optimizations sequentially")
    collected: list[SingleStartResult] = []
    with MultiStartProgressTracker(**tracker_options) as tracker:
        for idx, start in enumerate(starts):
            run_number = idx + 1
            logger.debug(f"Starting optimization {run_number}/{total_starts}")
            try:
                outcome = fit_one(idx, start)
                collected.append(outcome)
                tracker.update(
                    start_idx=idx,
                    success=outcome.success,
                    chi_squared=outcome.chi_squared,
                    message=outcome.message,
                    wall_time=outcome.wall_time,
                )
            except (ValueError, RuntimeError, OSError) as e:
                # One bad start must not abort the whole multi-start run.
                reason = str(e)
                logger.debug(f"Optimization {run_number} raised exception: {e}")
                collected.append(
                    SingleStartResult(
                        start_idx=idx,
                        initial_params=start,
                        final_params=start,
                        chi_squared=np.inf,
                        success=False,
                        message=reason,
                    )
                )
                tracker.update(
                    start_idx=idx,
                    success=False,
                    chi_squared=np.inf,
                    message=reason,
                )
    return collected
def _run_parallel_with_progress(
    optimize_func: Callable[[int, NDArray[np.float64]], SingleStartResult],
    starts: NDArray[np.float64],
    n_workers: int,
    progress: MultiStartProgressTracker,
) -> list[SingleStartResult]:
    """Run optimizations in parallel with progress tracking.

    Tasks are submitted to a ``ProcessPoolExecutor`` created with the 'spawn'
    multiprocessing context to avoid JAX JIT compilation deadlocks. If the
    pool cannot be used (pickling failure, broken pool, executor error or
    timeout), the starts that have no collected result yet are run
    sequentially in the calling process. Results already collected from the
    pool are kept, so no start is reported to ``progress`` twice.

    Parameters
    ----------
    optimize_func : Callable
        Function that takes (start_idx, initial_params) and returns
        SingleStartResult. Must be picklable for the parallel path.
    starts : NDArray[np.float64]
        Starting points as (n_starts, n_params) array.
    n_workers : int
        Number of parallel workers.
    progress : MultiStartProgressTracker
        Progress tracker to update, once per starting point.

    Returns
    -------
    list[SingleStartResult]
        Results from all starting points, sorted by ``start_idx``. A start
        whose fit raised is represented by a failed result with
        ``chi_squared=np.inf``.
    """
    # Function-scope imports: only this function's error handling needs them.
    import pickle
    from concurrent.futures import BrokenExecutor

    results: list[SingleStartResult] = []
    # Indices whose result (success or recorded failure) is already in
    # ``results``; the sequential fallback skips these.
    finished: set[int] = set()
    fallback_to_sequential = False
    fallback_reason = ""

    # Use 'spawn' context to avoid JAX/XLA deadlocks
    # 'fork' can cause issues with JAX's XLA compilation locks
    mp_context = multiprocessing.get_context("spawn")

    try:
        logger.info(
            f"Launching parallel execution: {n_workers} workers, "
            f"{len(starts)} tasks, spawn context"
        )
        logger.debug(f"Worker timeout: {_WORKER_TIMEOUT}s per task")
        parallel_start_time = time.perf_counter()

        with ProcessPoolExecutor(
            max_workers=n_workers, mp_context=mp_context
        ) as executor:
            futures = {
                executor.submit(optimize_func, idx, start): idx
                for idx, start in enumerate(starts)
            }
            logger.debug(f"Submitted {len(futures)} tasks to executor")

            completed_count = 0
            total_count = len(futures)

            try:
                # NOTE: the as_completed timeout is measured from this call,
                # i.e. it bounds the whole batch rather than each task.
                for future in as_completed(futures, timeout=_WORKER_TIMEOUT):
                    idx = futures[future]
                    try:
                        # Timeout for individual result retrieval
                        result = future.result(timeout=60)
                        results.append(result)
                        finished.add(idx)
                        completed_count += 1
                        progress.update(
                            start_idx=idx,
                            success=result.success,
                            chi_squared=result.chi_squared,
                            message=result.message,
                            wall_time=result.wall_time,
                        )
                        logger.debug(
                            f"Worker {idx} completed: chi2={result.chi_squared:.4e}, "
                            f"time={result.wall_time:.1f}s ({completed_count}/{total_count})"
                        )
                    except TimeoutError:
                        logger.warning(
                            f"Worker {idx} timed out after 60s waiting for result"
                        )
                        logger.info("Falling back to sequential execution")
                        fallback_to_sequential = True
                        fallback_reason = f"Worker {idx} timeout"
                        break
                    except BrokenExecutor as e:
                        # A dead worker process breaks the whole pool: every
                        # remaining future would raise the same error, so
                        # fall back instead of marking each start as failed.
                        logger.warning(
                            f"Process pool broken while retrieving worker {idx}: {e}"
                        )
                        logger.info("Falling back to sequential execution")
                        fallback_to_sequential = True
                        fallback_reason = f"Broken process pool: {str(e)[:100]}"
                        break
                    except (
                        pickle.PickleError,
                        ValueError,
                        RuntimeError,
                        TypeError,
                        OSError,
                        AttributeError,
                    ) as e:
                        error_msg = str(e)
                        # pickle.PickleError does not derive from any other
                        # type in this tuple, so it is listed explicitly and
                        # always treated as a serialization failure.
                        if isinstance(e, pickle.PickleError) or _is_pickle_error(
                            error_msg
                        ):
                            fallback_to_sequential = True
                            fallback_reason = f"Pickle error: {error_msg[:100]}"
                            logger.warning(
                                f"Pickle/serialization error detected: {e}"
                            )
                            logger.info("Falling back to sequential execution")
                            break
                        # Non-fatal error for this worker
                        logger.warning(f"Worker {idx} failed: {e}")
                        failed_result = SingleStartResult(
                            start_idx=idx,
                            initial_params=starts[idx],
                            final_params=starts[idx],
                            chi_squared=np.inf,
                            success=False,
                            message=str(e),
                        )
                        results.append(failed_result)
                        finished.add(idx)
                        completed_count += 1
                        progress.update(
                            start_idx=idx,
                            success=False,
                            chi_squared=np.inf,
                            message=str(e),
                        )
            finally:
                # Leaving the ``with`` block waits for the pool to drain.
                # Cancel queued tasks on every exit path (break, overall
                # timeout, unexpected error) so that wait covers only tasks
                # already running. cancel() is a no-op on finished futures.
                # NOTE: a task that is already running cannot be cancelled.
                for pending in futures:
                    pending.cancel()

            if not fallback_to_sequential:
                parallel_time = time.perf_counter() - parallel_start_time
                logger.info(
                    f"Parallel execution complete: {completed_count}/{total_count} tasks "
                    f"in {parallel_time:.1f}s"
                )

    except TimeoutError:
        logger.warning(f"Parallel execution timed out after {_WORKER_TIMEOUT}s")
        logger.info("Falling back to sequential execution")
        fallback_to_sequential = True
        fallback_reason = f"Overall timeout ({_WORKER_TIMEOUT}s)"
    except (
        pickle.PickleError,
        ValueError,
        RuntimeError,
        TypeError,
        OSError,
        AttributeError,
    ) as e:
        error_msg = str(e)
        fallback_to_sequential = True
        if isinstance(e, pickle.PickleError) or _is_pickle_error(error_msg):
            fallback_reason = f"Pickle error: {error_msg[:100]}"
            logger.warning(f"ProcessPoolExecutor pickle error: {e}")
        else:
            fallback_reason = f"Executor error: {error_msg[:100]}"
            logger.warning(f"ProcessPoolExecutor failed: {e}")
        logger.info("Falling back to sequential execution")

    # If parallel failed, run whatever is still missing in this process.
    if fallback_to_sequential:
        n_remaining = len(starts) - len(finished)
        logger.info(f"Sequential fallback reason: {fallback_reason}")
        logger.info(
            f"Running {n_remaining} optimizations sequentially "
            f"({len(finished)} already completed in parallel)"
        )

        for idx, start in enumerate(starts):
            if idx in finished:
                # Already collected and reported to the tracker; re-running
                # would waste a full fit and double-count progress.
                continue
            logger.debug(f"Starting sequential optimization {idx + 1}/{len(starts)}")
            try:
                result = optimize_func(idx, start)
                results.append(result)
                progress.update(
                    start_idx=idx,
                    success=result.success,
                    chi_squared=result.chi_squared,
                    message=result.message,
                    wall_time=result.wall_time,
                )
            except (ValueError, RuntimeError, OSError) as e:
                logger.debug(f"Sequential optimization {idx + 1} failed: {e}")
                failed_result = SingleStartResult(
                    start_idx=idx,
                    initial_params=start,
                    final_params=start,
                    chi_squared=np.inf,
                    success=False,
                    message=str(e),
                )
                results.append(failed_result)
                progress.update(
                    start_idx=idx,
                    success=False,
                    chi_squared=np.inf,
                    message=str(e),
                )

    # Sort by start_idx for consistent ordering
    results.sort(key=lambda r: r.start_idx)
    return results