Source code for homodyne.cli.commands

"""Command Dispatcher for Homodyne CLI
======================================

Handles command execution and coordination between CLI arguments,
configuration, and optimization methods.

This module serves as the main orchestrator and re-export hub.
Implementation details are in focused submodules:
- config_handling: Device config, config loading, CLI overrides
- data_pipeline: Data loading, t=0 exclusion, angle filtering, MCMC pooling
- optimization_runner: NLSQ/CMC optimization, warm-start
- result_saving: JSON/NPZ saving for NLSQ and MCMC results
- plot_dispatch: Plotting dispatch for experimental/simulated data

Functions that are mock-patched by tests via @patch("homodyne.cli.commands.X")
remain in this module to preserve test compatibility.
"""

from __future__ import annotations

import argparse
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Any, cast

import numpy as np
from numpy.typing import NDArray

from homodyne.cli.args_parser import validate_args

# Re-export from config_handling
from homodyne.cli.config_handling import (  # noqa: F401
    _apply_cli_overrides,
    _build_mcmc_runtime_kwargs,
    _configure_device,
    _get_default_config,
    _load_configuration,
)

# Re-export from data_pipeline
from homodyne.cli.data_pipeline import (  # noqa: F401
    COMMON_XPCS_ANGLES,
    _apply_angle_filtering,
    _apply_angle_filtering_for_optimization,
    _exclude_t0_from_analysis,
    _pool_mcmc_data,
    _prepare_cmc_config,
)

# Re-export from optimization_runner
from homodyne.cli.optimization_runner import (  # noqa: F401
    _resolve_nlsq_warmstart,
    _run_nlsq_optimization,
    load_nlsq_result_from_file,
)

# Re-export from plot_dispatch
from homodyne.cli.plot_dispatch import (  # noqa: F401
    _apply_angle_filtering_for_plot,
    _handle_plotting,
    generate_nlsq_plots,
)

# Re-export from result_saving
from homodyne.cli.result_saving import (  # noqa: F401
    _compute_theoretical_c2_from_mcmc,
    _extract_nlsq_metadata,
    _json_safe,
    _json_serializer,
    _prepare_parameter_data,
    _save_nlsq_json_files,
    _save_nlsq_npz_file,
    _save_results,
    save_mcmc_results,
    save_nlsq_results,
)
from homodyne.config.parameter_space import ParameterSpace
from homodyne.data.angle_filtering import (
    angle_in_range as _data_angle_in_range,
)
from homodyne.data.angle_filtering import (
    normalize_angle_to_symmetric_range as _data_normalize_angle_to_symmetric_range,
)
from homodyne.utils.logging import (
    AnalysisSummaryLogger,
    configure_logging,
    get_logger,
    log_exception,
    log_phase,
)

logger = get_logger(__name__)

# Import core modules with fallback.
# These module-level names are also used as mock patch targets by tests
# (e.g. @patch("homodyne.cli.commands.XPCSDataLoader")), so they must
# remain importable from this module.
try:
    from homodyne.config.manager import ConfigManager
    from homodyne.data.xpcs_loader import XPCSDataLoader
    from homodyne.optimization import fit_mcmc_jax

    HAS_CORE_MODULES = True
    HAS_XPCS_LOADER = True
except ImportError as e:
    HAS_CORE_MODULES = False
    HAS_XPCS_LOADER = False
    logger.error(f"Core modules not available: {e}")

    # Fallback for missing XPCSDataLoader
    class XPCSDataLoader:  # type: ignore[no-redef]
        """Placeholder when XPCSDataLoader is not available."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            raise ImportError("XPCSDataLoader not available")


[docs] def normalize_angle_to_symmetric_range( angle: float | NDArray[np.floating[Any]], ) -> float | NDArray[np.floating[Any]]: """Normalize angle(s) to [-180, 180] range. This is a wrapper that delegates to homodyne.data.angle_filtering. """ return _data_normalize_angle_to_symmetric_range(angle)
def _angle_in_range(angle: float, min_angle: float, max_angle: float) -> bool: """Check if angle is in range, accounting for wrap-around at +/-180. This is a wrapper that delegates to homodyne.data.angle_filtering. """ return _data_angle_in_range(angle, min_angle, max_angle)
[docs] def dispatch_command(args: argparse.Namespace) -> dict[str, Any]: """Dispatch command based on parsed CLI arguments. Parameters ---------- args : argparse.Namespace Parsed command-line arguments Returns ------- dict Command execution result with success status and details """ run_id = getattr(args, "run_id", None) or datetime.now().strftime("%Y%m%d_%H%M%S") args.run_id = run_id logger.info(f"[CLI] Dispatching homodyne analysis command (run_id={run_id})") # Log resolved command-line arguments at DEBUG level (T023) logger.debug(f"[CLI] Resolved arguments: {vars(args)}") # Validate arguments if not validate_args(args): return {"success": False, "error": "Invalid command-line arguments"} if not HAS_CORE_MODULES: return { "success": False, "error": "Core modules not available. Please check installation.", } # Initialize analysis summary logger (T024) cli_mode = "laminar_flow" if getattr(args, "laminar_flow", False) else "static" if getattr(args, "static_mode", False): cli_mode = "static_isotropic" summary = AnalysisSummaryLogger(run_id=run_id, analysis_mode=cli_mode) try: # Create output directory args.output_dir.mkdir(parents=True, exist_ok=True) # Phase 1: Load configuration (T019) summary.start_phase("config_loading") with log_phase("config_loading"): config = _load_configuration(args) # Configure logging using config + CLI verbosity flags config_dict = ( config.get_config() if hasattr(config, "get_config") else config ) logging_cfg = ( config_dict.get("logging", {}) if isinstance(config_dict, dict) else {} ) log_file = configure_logging( logging_cfg, verbose=getattr(args, "verbose", 0) > 0, quiet=getattr(args, "quiet", False), output_dir=args.output_dir, run_id=run_id, ) if log_file: logger.info(f"[CLI] Log file created: {log_file}") summary.add_output_file(log_file) summary.end_phase("config_loading") # Update analysis mode from loaded config (T055) config_dict = config.get_config() if hasattr(config, "get_config") else config if isinstance(config_dict, dict): config_analysis_mode = config_dict.get("analysis_mode", cli_mode) if config_analysis_mode != summary.analysis_mode: logger.debug( f"[CLI] Updated analysis_mode from '{summary.analysis_mode}' " f"to '{config_analysis_mode}' (from config)" ) summary.analysis_mode = config_analysis_mode # Configure device (CPU/GPU) device_config = _configure_device(args) # Check if only simulated data plotting is requested plot_exp = getattr(args, "plot_experimental_data", False) plot_sim = getattr(args, "plot_simulated_data", False) save_plots = getattr(args, "save_plots", False) # Simulated data plotting doesn't need experimental data or optimization if plot_sim and not plot_exp and not save_plots: logger.info( "[CLI] Plotting simulated data only (skipping data loading and optimization)", ) config_dict_plot: dict[str, Any] = ( config.get_config() if hasattr(config, "get_config") else cast(dict[str, Any], config) ) _handle_plotting(args, None, {}, config_dict_plot) summary.set_convergence_status("skipped_simulated_only") summary.log_summary(logger) return { "success": True, "result": None, "device_config": device_config, "output_dir": str(args.output_dir), } # Phase 2: Load data (T020) summary.start_phase("data_loading") with log_phase("data_loading", track_memory=True) as data_phase: data = _load_data(args, config) summary.end_phase("data_loading", memory_peak_gb=data_phase.memory_peak_gb) # Plot experimental data only (no optimization needed) plot_only = plot_exp and not save_plots and not plot_sim if plot_only: logger.info("[CLI] Plotting experimental data only (skipping optimization)") result = None summary.set_convergence_status("skipped_plot_only") else: # Phase 3: Run optimization (T021) summary.start_phase("optimization") with log_phase("optimization", track_memory=True) as opt_phase: result = _run_optimization(args, config, data) summary.end_phase("optimization", memory_peak_gb=opt_phase.memory_peak_gb) # Record optimization metrics if result is not None: is_cmc = ( callable(getattr(result, "is_cmc_result", None)) and result.is_cmc_result() ) if hasattr(result, "chi_squared") and not is_cmc: summary.record_metric("chi_squared", float(result.chi_squared)) if hasattr(result, "n_iterations"): summary.record_metric("n_iterations", float(result.n_iterations)) if hasattr(result, "convergence_status"): summary.set_convergence_status(result.convergence_status) elif hasattr(result, "converged"): summary.set_convergence_status( "converged" if result.converged else "not_converged" ) elif hasattr(result, "success"): summary.set_convergence_status( "converged" if result.success else "failed" ) # Phase 4: Save results (T022) summary.start_phase("result_saving") with log_phase("result_saving"): _save_results(args, result, device_config, data, config) summary.end_phase("result_saving") # Handle plotting options config_dict2: dict[str, Any] = ( config.get_config() if hasattr(config, "get_config") else cast(dict[str, Any], config) ) _handle_plotting(args, result, data, config_dict2) logger.info("[CLI] Analysis completed successfully") # Log analysis summary (T024) summary.log_summary(logger) # Summary message if log_file: logger.info(f"[CLI] Analysis log saved to: {log_file}") else: log_dir = args.output_dir / "logs" log_files = list(log_dir.glob("homodyne_analysis_*.log")) if log_files: logger.info(f"[CLI] Analysis log saved to: {log_files[-1]}") return { "success": True, "result": result, "device_config": device_config, "output_dir": str(args.output_dir), "summary": summary.as_dict(), } except Exception as e: log_exception(logger, e, context={"run_id": run_id, "phase": "dispatch"}) summary.set_convergence_status("failed") summary.increment_error_count() return {"success": False, "error": str(e)}
def _load_data(args: argparse.Namespace, config: ConfigManager) -> dict[str, Any]: """Load experimental data using XPCSDataLoader. Uses XPCSDataLoader which properly handles the config format (data_folder_path + data_file_name) internally. """ logger.info("Loading experimental data...") if not HAS_XPCS_LOADER: raise RuntimeError( "XPCSDataLoader not available. " "Please ensure homodyne.data module is properly installed", ) try: if args.data_file: data_file_path = Path(args.data_file).resolve() parent_dir = data_file_path.parent if parent_dir == Path.cwd(): logger.debug( f"Using current directory for data file: {data_file_path.name}", ) temp_config = { "experimental_data": { "data_folder_path": str(parent_dir), "data_file_name": data_file_path.name, }, "analyzer_parameters": ( config.config.get("analyzer_parameters", {}) if hasattr(config, "config") and config.config is not None else {"dt": 0.1, "start_frame": 1, "end_frame": -1} ), } logger.info(f"Loading data from CLI override: {data_file_path}") loader = XPCSDataLoader(config_dict=temp_config) else: if not hasattr(config, "config") or not config.config: raise ValueError("No configuration loaded") exp_data = config.config.get("experimental_data", {}) if not exp_data.get("data_folder_path") and not exp_data.get("file_path"): raise ValueError( "No data file specified in configuration.\n" "Config must have either:\n" " experimental_data:\n" " data_folder_path: ./path/to/data/\n" " data_file_name: experiment.hdf\n" "Or:\n" " experimental_data:\n" " file_path: ./path/to/data/experiment.hdf\n" "Or use: --data-file path/to/data.hdf", ) logger.info("Loading data from configuration") loader = XPCSDataLoader(config_dict=config.config) data = loader.load_experimental_data() data_size = 0 if "c2_exp" in data: c2_exp = data["c2_exp"] data_size = c2_exp.size if hasattr(c2_exp, "size") else len(c2_exp) logger.info(f"OK: Data loaded successfully: {data_size:,} data points") return data except FileNotFoundError as e: logger.error(f"Data file not found: {e}") raise RuntimeError(f"Data file not found: {e}") from e except Exception as e: logger.error(f"Data loading failed: {e}") raise RuntimeError(f"Failed to load experimental data: {e}") from e def _run_optimization( args: argparse.Namespace, config: ConfigManager, data: dict[str, Any] ) -> Any: """Run the specified optimization method.""" method = args.method logger.info(f"Running {method.upper()} optimization...") start_time = time.perf_counter() # Apply angle filtering before optimization (if configured) filtered_data = _apply_angle_filtering_for_optimization(data, config) # CRITICAL FIX: Exclude t=0 from analysis to prevent D(t) singularity filtered_data = _exclude_t0_from_analysis(filtered_data) logger.debug("Using NLSQ native large dataset handling") try: if method == "nlsq": result = _run_nlsq_optimization(filtered_data, config, args) elif method == "cmc": cmc_config = _prepare_cmc_config(args, config) _backend_cfg = cmc_config["backend"] # noqa: F841 nlsq_result = _resolve_nlsq_warmstart(args, filtered_data, config) config_config_early = ( config.config if hasattr(config, "config") and config.config is not None else {} ) analysis_mode_early = cast( str, config_config_early.get("analysis_mode", "static_isotropic") ) require_warmstart = cmc_config.get("validation", {}).get( "require_nlsq_warmstart", False ) if ( require_warmstart and nlsq_result is None and "laminar" in analysis_mode_early.lower() ): raise ValueError( "CMC WARM-START REQUIRED: laminar_flow mode requires NLSQ warm-start " "when require_nlsq_warmstart=True. Remove --no-nlsq-warmstart flag or set " "validation.require_nlsq_warmstart=false in CMC config." ) logger.info(f"Method: {method.upper()} (Consensus Monte Carlo)") sharding = cmc_config.get("sharding", {}) backend = cmc_config.get("backend", {}) if isinstance(backend, str): backend_str = backend backend_config = cmc_config.get("backend_config", {}) parallel_backend = ( backend_config.get("name", "auto") if backend_config else "auto" ) backend_display = f"{backend_str}/{parallel_backend}" else: backend_display = backend.get("name", "auto") logger.debug( f"CMC sharding: strategy={sharding.get('strategy', 'auto')}, " f"num_shards={sharding.get('num_shards', 'auto')}, " f"backend={backend_display}", ) pooled = _pool_mcmc_data(filtered_data) mcmc_data = pooled["mcmc_data"] t1_pooled = pooled["t1_pooled"] t2_pooled = pooled["t2_pooled"] phi_pooled = pooled["phi_pooled"] _n_phi = pooled["n_phi"] # noqa: F841 _n_t = pooled["n_t"] # noqa: F841 initial_values = ( config.get_initial_parameters() if hasattr(config, "get_initial_parameters") else {} ) if initial_values: logger.debug( f"MCMC initial values from config: {list(initial_values.keys())} = " f"{[f'{v:.4g}' for v in initial_values.values()]}" ) else: logger.debug( "MCMC will use mid-point defaults (no initial_parameters.values in config)" ) config_config = ( config.config if hasattr(config, "config") and config.config is not None else {} ) analysis_mode_str = cast( str, config_config.get("analysis_mode", "static_isotropic") ) parameter_space = ParameterSpace.from_config( config_dict=config_config, analysis_mode=analysis_mode_str, ) logger.debug( f"Created ParameterSpace with config for {analysis_mode_str} mode" ) mcmc_runtime_kwargs = _build_mcmc_runtime_kwargs(args, config) result = fit_mcmc_jax( mcmc_data, t1=t1_pooled, t2=t2_pooled, phi=phi_pooled, q=( filtered_data.get("wavevector_q_list", [1.0])[0] if ( filtered_data.get("wavevector_q_list") is not None and len(filtered_data.get("wavevector_q_list", [])) > 0 ) else 1.0 ), L=float( config_config.get("analyzer_parameters", {}) .get("geometry", {}) .get("stator_rotor_gap", 2000000.0) ), analysis_mode=cast( str, config_config.get("analysis_mode", "static_isotropic") ), method=method, cmc_config=cmc_config, initial_values=initial_values, parameter_space=parameter_space, dt=config_config.get("analyzer_parameters", {}).get("dt"), nlsq_result=nlsq_result, **mcmc_runtime_kwargs, ) if hasattr(result, "inference_data") and result.inference_data is not None: analysis_mode_for_plot = cast( str, config_config.get("analysis_mode", "static_isotropic") ) _generate_cmc_diagnostic_plots( result, args.output_dir, analysis_mode_for_plot ) else: logger.warning( "Cannot generate ArviZ diagnostic plots: inference_data not available" ) elif method == "both": # Phase 1: Run NLSQ logger.info("Running sequential NLSQ -> CMC pipeline...") logger.info("Phase 1/2: NLSQ optimization...") nlsq_result = _run_nlsq_optimization(filtered_data, config, args) nlsq_time = time.perf_counter() - start_time logger.info(f"Phase 1/2: NLSQ completed in {nlsq_time:.3f}s") # Save NLSQ results before proceeding to CMC _save_results(args, nlsq_result, {}, filtered_data, config) logger.info("NLSQ results saved, proceeding to CMC...") # Phase 2: Run CMC with NLSQ warm-start logger.info("Phase 2/2: CMC optimization with NLSQ warm-start...") cmc_config = _prepare_cmc_config(args, config) config_config = ( config.config if hasattr(config, "config") and config.config is not None else {} ) analysis_mode_str = cast( str, config_config.get("analysis_mode", "static_isotropic") ) parameter_space = ParameterSpace.from_config( config_dict=config_config, analysis_mode=analysis_mode_str, ) pooled = _pool_mcmc_data(filtered_data) mcmc_data = pooled["mcmc_data"] t1_pooled = pooled["t1_pooled"] t2_pooled = pooled["t2_pooled"] phi_pooled = pooled["phi_pooled"] initial_values = ( config.get_initial_parameters() if hasattr(config, "get_initial_parameters") else {} ) mcmc_runtime_kwargs = _build_mcmc_runtime_kwargs(args, config) result = fit_mcmc_jax( mcmc_data, t1=t1_pooled, t2=t2_pooled, phi=phi_pooled, q=( filtered_data.get("wavevector_q_list", [1.0])[0] if ( filtered_data.get("wavevector_q_list") is not None and len(filtered_data.get("wavevector_q_list", [])) > 0 ) else 1.0 ), L=float( config_config.get("analyzer_parameters", {}) .get("geometry", {}) .get("stator_rotor_gap", 2000000.0) ), analysis_mode=analysis_mode_str, method="cmc", cmc_config=cmc_config, initial_values=initial_values, parameter_space=parameter_space, dt=config_config.get("analyzer_parameters", {}).get("dt"), nlsq_result=nlsq_result, **mcmc_runtime_kwargs, ) if hasattr(result, "inference_data") and result.inference_data is not None: _generate_cmc_diagnostic_plots( result, args.output_dir, analysis_mode_str ) else: raise ValueError(f"Unknown optimization method: {method}") optimization_time = time.perf_counter() - start_time logger.info( f"OK: {method.upper()} optimization completed in {optimization_time:.3f}s", ) return result except Exception as e: optimization_time = time.perf_counter() - start_time logger.error( f"{method.upper()} optimization failed after {optimization_time:.3f}s: {e}", ) raise def _generate_cmc_diagnostic_plots( result: Any, output_dir: Path, analysis_mode: str ) -> None: """Generate CMC/MCMC diagnostic plots using ArviZ. This function generates 6 standard ArviZ diagnostic plots: 1. Pair plot (corner plot) - pairwise parameter correlations 2. Forest plot - posterior distributions with HDI 3. Energy plot - HMC energy diagnostics 4. Autocorrelation plot - sample independence 5. Rank plot - chain mixing diagnostics 6. ESS plot - effective sample size evolution These plots are generated REGARDLESS of convergence status to help diagnose sampling problems. Parameters ---------- result : Any MCMC result object with inference_data (CMCResult or similar) output_dir : Path Output directory for saving plots analysis_mode : str Analysis mode (static_isotropic or laminar_flow) """ if not hasattr(result, "inference_data") or result.inference_data is None: logger.warning( "No inference_data available in result - skipping ArviZ diagnostic plots" ) return try: from homodyne.optimization.cmc.plotting import generate_diagnostic_plots diag_dir = output_dir / "diagnostics" diag_dir.mkdir(parents=True, exist_ok=True) logger.info("Generating ArviZ diagnostic plots...") saved_plots = generate_diagnostic_plots( result=result, output_dir=diag_dir, ) if saved_plots: logger.info(f"Generated {len(saved_plots)} ArviZ diagnostic plots:") for plot_path in saved_plots: logger.info(f" - {plot_path.name}") else: logger.warning("No diagnostic plots were generated") if hasattr(result, "cmc_diagnostics") and result.cmc_diagnostics is not None: diag_data = { "per_shard_diagnostics": result.cmc_diagnostics.get( "per_shard_diagnostics", [] ), "between_shard_kl": result.cmc_diagnostics.get("kl_matrix", []), "success_rate": result.cmc_diagnostics.get("success_rate", 0.0), "combined_diagnostics": result.cmc_diagnostics.get( "combined_diagnostics", {} ), } diag_file = diag_dir / "cmc_diagnostics.json" with open(diag_file, "w", encoding="utf-8") as f: json.dump(diag_data, f, indent=2, default=_json_serializer) logger.debug(f"CMC diagnostic data saved to: {diag_file}") except ImportError as e: logger.warning(f"ArviZ plotting not available: {e}") except Exception as e: logger.warning(f"Failed to generate diagnostic plots: {e}") logger.debug(f"Diagnostic plot error details: {e}", exc_info=True)