"""JAX-First Optimization for Homodyne.4
==========================================
Simplified optimization system using NLSQ package (primary) and CMC
(high-accuracy Bayesian) for robust parameter estimation in homodyne analysis.
This module implements the streamlined optimization philosophy:
1. NLSQ as primary method (fast, reliable parameter estimation)
2. CMC (NumPyro/NUTS) for uncertainty quantification
3. Unified homodyne model: c2_fitted = c2_theory * contrast + offset
Key Features:
- NLSQ trust-region optimization (Levenberg-Marquardt) as foundation
- CMC: Fresh reimplementation with ArviZ-native output
- CPU-primary architecture (GPU removed in v2.3.0)
- Dataset size-aware optimization strategies
Performance Comparison:
- NLSQ: Fast, reliable parameter estimation
- CMC: Full posterior sampling, publication-quality uncertainty
Note: Legacy mcmc/ package removed in v3.0. CMC is the sole MCMC backend.
"""
from __future__ import annotations
import logging
from typing import Any
# Import submodules as attributes for hasattr() checks
# These imports expose the submodule packages even if their contents fail to import
from homodyne.optimization import nlsq
_logger = logging.getLogger(__name__)
# Handle NLSQ imports with intelligent fallback
try:
from homodyne.optimization.nlsq import ( # Chunking; Residual; Sequential
MultiStartConfig,
MultiStartResult,
NLSQResult,
NLSQWrapper,
OptimizationResult,
StratificationDiagnostics,
StratifiedResidualFunction,
StratifiedResidualFunctionJIT,
create_angle_stratified_data,
create_angle_stratified_indices,
create_stratified_residual_function,
fit_nlsq_jax,
fit_nlsq_multistart,
optimize_per_angle_sequential,
should_use_stratification,
)
# NOTE: DatasetSizeStrategy, OptimizationStrategy, estimate_memory_requirements
# removed from public API in v2.12.0. Use NLSQ's WorkflowSelector instead.
NLSQ_AVAILABLE = True
except ImportError as e:
_logger.warning("Could not import NLSQ optimization: %s", e)
fit_nlsq_jax = None # type: ignore[assignment]
fit_nlsq_multistart = None # type: ignore[assignment]
MultiStartConfig = None # type: ignore[assignment,misc]
MultiStartResult = None # type: ignore[assignment,misc]
NLSQResult = None # type: ignore[assignment,misc]
NLSQWrapper = None # type: ignore[assignment,misc]
OptimizationResult = None # type: ignore[assignment,misc]
StratificationDiagnostics = None # type: ignore[assignment,misc]
create_angle_stratified_data = None # type: ignore[assignment]
create_angle_stratified_indices = None # type: ignore[assignment]
should_use_stratification = None # type: ignore[assignment]
StratifiedResidualFunction = None # type: ignore[assignment,misc]
StratifiedResidualFunctionJIT = None # type: ignore[assignment,misc]
create_stratified_residual_function = None # type: ignore[assignment]
optimize_per_angle_sequential = None # type: ignore[assignment]
NLSQ_AVAILABLE = False
# Handle CMC imports (NO FALLBACK to legacy mcmc - it's removed)
# Try to import the cmc module for hasattr() checks
try:
from homodyne.optimization import cmc
CMC_SUBMODULE_AVAILABLE = True
except ImportError:
cmc = None # type: ignore[assignment]
CMC_SUBMODULE_AVAILABLE = False
try:
from homodyne.optimization.cmc import (
CMCConfig,
CMCResult,
fit_mcmc_jax,
)
# Aliases for backward compatibility
MCMCResult = CMCResult
# CMC uses NumPyro/JAX
MCMC_JAX_AVAILABLE = True
NUMPYRO_AVAILABLE = True
MCMC_AVAILABLE = True
# Check BlackJAX availability separately (optional dependency)
try:
import blackjax as _blackjax_check # noqa: F401
BLACKJAX_AVAILABLE = True
except ImportError:
BLACKJAX_AVAILABLE = False
except ImportError as e:
_logger.warning("Could not import CMC optimization: %s", e)
fit_mcmc_jax = None # type: ignore[assignment]
CMCConfig = None # type: ignore[assignment,misc]
CMCResult = None # type: ignore[assignment,misc]
MCMCResult = None # type: ignore[misc,assignment]
MCMC_JAX_AVAILABLE = False
NUMPYRO_AVAILABLE = False
BLACKJAX_AVAILABLE = False
MCMC_AVAILABLE = False
# Module status
OPTIMIZATION_STATUS = {
"nlsq_available": NLSQ_AVAILABLE,
"mcmc_available": MCMC_AVAILABLE,
"cmc_available": MCMC_AVAILABLE, # CMC is the MCMC backend
"jax_available": MCMC_JAX_AVAILABLE if MCMC_AVAILABLE else False,
"numpyro_available": NUMPYRO_AVAILABLE if MCMC_AVAILABLE else False,
"blackjax_available": BLACKJAX_AVAILABLE if MCMC_AVAILABLE else False,
}
# Primary API functions
__all__ = [
# Primary optimization methods
"fit_nlsq_jax", # NLSQ trust-region (PRIMARY)
"fit_nlsq_multistart", # Multi-start NLSQ (v2.6.0)
"fit_mcmc_jax", # CMC NumPyro/NUTS (SECONDARY)
# Result classes
"NLSQResult",
"MultiStartConfig",
"MultiStartResult",
"CMCResult", # CMC result class
"MCMCResult", # Alias for backward compatibility
"CMCConfig", # CMC configuration
# NLSQ components
"NLSQWrapper",
"OptimizationResult",
"StratificationDiagnostics",
"create_angle_stratified_data",
"create_angle_stratified_indices",
"should_use_stratification",
"StratifiedResidualFunction",
"StratifiedResidualFunctionJIT",
"create_stratified_residual_function",
"optimize_per_angle_sequential",
# NOTE: DatasetSizeStrategy, OptimizationStrategy, estimate_memory_requirements
# removed from public API in v2.12.0. Use NLSQ's WorkflowSelector instead.
# Status information
"OPTIMIZATION_STATUS",
"NLSQ_AVAILABLE",
"MCMC_AVAILABLE",
# Submodules
"nlsq",
"cmc",
]
[docs]
def get_optimization_info() -> dict[str, Any]:
"""Get information about available optimization methods.
Returns
-------
dict
Dictionary with availability status and recommendations
"""
info: dict[str, Any] = {
"status": OPTIMIZATION_STATUS.copy(),
"primary_method": "nlsq" if NLSQ_AVAILABLE else None,
"secondary_method": "cmc" if MCMC_AVAILABLE else None,
"recommendations": [],
}
if NLSQ_AVAILABLE:
info["recommendations"].append(
"Use fit_nlsq_jax() for fast, reliable parameter estimation",
)
if MCMC_AVAILABLE:
info["recommendations"].append(
"Use fit_mcmc_jax() for uncertainty quantification (CMC)",
)
if not NLSQ_AVAILABLE and not MCMC_AVAILABLE:
info["recommendations"].append(
"Install NLSQ and NumPyro for optimization capabilities",
)
return info