Source code for homodyne.optimization.cmc.core

"""CMC core module - main entry point.

This module provides the fit_mcmc_jax() function that serves as the
main entry point for CMC analysis, matching the CLI signature.
"""

from __future__ import annotations

import math
import time
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any

import jax.numpy as jnp
import numpy as np

from homodyne.core.scaling_utils import estimate_per_angle_scaling
from homodyne.optimization.cmc.backends import select_backend
from homodyne.optimization.cmc.config import CMCConfig
from homodyne.optimization.cmc.data_prep import (
    prepare_mcmc_data,
    shard_data_angle_balanced,
    shard_data_random,
    shard_data_stratified,
)
from homodyne.optimization.cmc.diagnostics import (
    compute_precision_analysis,
    log_precision_analysis,
    summarize_diagnostics,
)
from homodyne.optimization.cmc.model import get_xpcs_model
from homodyne.optimization.cmc.priors import get_param_names_in_order
from homodyne.optimization.cmc.reparameterization import (
    compute_t_ref,
    transform_nlsq_to_reparam_space,
)
from homodyne.optimization.cmc.results import CMCResult
from homodyne.optimization.cmc.sampler import run_nuts_sampling
from homodyne.optimization.cmc.scaling import compute_scaling_factors
from homodyne.utils.logging import get_logger, log_exception, with_context

if TYPE_CHECKING:
    from homodyne.config.parameter_space import ParameterSpace

logger = get_logger(__name__)


def _resolve_max_points_per_shard(
    analysis_mode: str,
    n_total: int,
    max_points_per_shard: int | str | None,
    max_shards: int | None = None,
    n_phi: int = 1,
    iteration_ratio: float = 1.0,
) -> int:
    """Determine optimal shard size based on mode, data volume, and angle count.

    NUTS MCMC is O(n) per iteration - evaluates ALL points in a shard.
    Laminar flow (7 params) needs ~10x smaller shards than static (3 params)
    due to complex gradient computation (trigonometric functions, cumulative integrals).

    CRITICAL (Jan 2026): Angle-aware scaling for multi-angle datasets.
    When n_phi is small (e.g., 3 angles), each shard contains points from all angles,
    making gradient computation ~n_phi times more expensive than single-angle shards.
    We scale the base shard size inversely with angle count to compensate.

    FIX (Jan 2026 v2): Increased minimum shard size to 10K for laminar_flow.
    Previous 3K shards with 3 angles created 999 data-starved shards that caused:
    - High per-shard heterogeneity (D_offset CV=1.80, gamma_dot_t0 CV=1.50)
    - CMC posteriors diverging 37% from NLSQ on D0
    - Artificially narrow uncertainties from corrupted precision-weighted combination

    FIX (Feb 2026): Iteration-aware shard sizing.
    Base shard sizes are calibrated for the default 2000 iterations (500+1500).
    When users configure more iterations (e.g., 4000), the per-shard cost doubles,
    causing timeouts.  Scale shard size inversely with iteration_ratio to compensate.

    Scaling guidelines (laminar_flow mode with 4 chains, 2000 iterations):
    - 10K points → ~8-12 min/shard (single angle)
    - 10K points with 3 angles → ~20-35 min/shard (angle scaling effect)
    - 15K points with 3 angles → ~30-50 min/shard

    Memory scalability for shard combination:
    - Each shard result: ~100KB (13 params × 4 chains × 1500 samples × 8 bytes)
    - Peak memory: ~6 × K MB where K = number of shards
    - Safe limits: K=1000 → ~6GB, K=5000 → ~30GB, K=10000 → ~60GB

    Dataset size guidelines:
    - 1M points: ~100 shards (10K/shard) → ~600MB memory
    - 10M points: ~1000 shards (10K/shard) → ~6GB memory
    - 100M points: ~10000 shards (10K/shard) → ~60GB memory
    - 1B points: ~50000 shards (20K/shard) → ~300GB memory (HPC only)

    Parameters
    ----------
    analysis_mode : str
        Analysis mode: "static" or "laminar_flow".
    n_total : int
        Total number of data points.
    max_points_per_shard : int | str | None
        User-specified shard size or "auto".
    max_shards : int | None
        Maximum number of shards to create (caps memory usage).
        If None, dynamically computed based on dataset size:
        - Small (<10M): up to 2000 shards
        - Medium (10M-100M): up to 10000 shards
        - Large (100M-1B): up to 50000 shards
        - Very large (>1B): up to 100000 shards
    n_phi : int
        Number of phi angles in the dataset. Used for angle-aware scaling.
        Default 1 (single angle - no scaling applied).
    iteration_ratio : float
        Ratio of default iterations to actual iterations.  Default 1.0 means
        no adjustment.  Values < 1.0 (user configured more iterations than
        default) shrink shards; values > 1.0 (fewer iterations) grow them.
        Clamped to [0.25, 2.0] to avoid extreme shard sizes.
    """
    # Minimum shard sizes to prevent data starvation (Jan 2026 fix)
    MIN_SHARD_SIZE_LAMINAR = (
        3_000  # Reduced: reparameterization fixes bimodal posteriors
    )
    MIN_SHARD_SIZE_STATIC = 5_000  # 5K minimum for static

    if max_points_per_shard is not None and max_points_per_shard != "auto":
        user_specified = int(max_points_per_shard)
        # Still enforce minimum for laminar_flow even with user specification
        if analysis_mode == "laminar_flow" and user_specified < MIN_SHARD_SIZE_LAMINAR:
            logger.warning(
                f"Enforcing minimum shard size for laminar_flow: "
                f"{user_specified:,} -> {MIN_SHARD_SIZE_LAMINAR:,} points "
                "(to prevent data-starved shards)"
            )
            return MIN_SHARD_SIZE_LAMINAR
        return user_specified

    # =========================================================================
    # Dynamic max_shards based on dataset size (Jan 2026 v3)
    # =========================================================================
    # Scale max_shards with dataset size to handle 1M to 1B+ point datasets
    if max_shards is None:
        if n_total >= 1_000_000_000:  # 1B+ points
            max_shards = 100_000  # ~600GB memory for combination
        elif n_total >= 100_000_000:  # 100M+ points
            max_shards = 50_000  # ~300GB memory
        elif n_total >= 10_000_000:  # 10M+ points
            max_shards = 10_000  # ~60GB memory
        else:
            max_shards = 2_000  # ~12GB memory (suitable for workstations)

    # =========================================================================
    # Angle-aware scaling factor (Jan 2026 fix v2: Less aggressive scaling)
    # =========================================================================
    # Multi-angle datasets with random sharding have ALL angles in each shard.
    # This makes gradient computation ~n_phi times more expensive.
    # Scale shard size inversely, but with higher floor to prevent data starvation.
    if n_phi <= 3:
        angle_factor = 0.6  # 60% of base size for 1-3 angles (was 0.3)
    elif n_phi <= 5:
        angle_factor = 0.7  # 70% for 4-5 angles (was 0.5)
    elif n_phi <= 10:
        angle_factor = 0.85  # 85% for 6-10 angles (was 0.7)
    else:
        angle_factor = 1.0  # Full size for many angles (stratified sharding preferred)

    # Auto-detection based on analysis mode and dataset size
    if analysis_mode == "laminar_flow":
        # Laminar flow: scale shard size with dataset to balance parallelism vs. data per shard
        # Feb 2026: Reduced from 30K-50K to 5K-10K. Reparameterization (D_ref, gamma_ref)
        # fixes bimodal posteriors, so shards no longer need 20K+ points per mode.
        # Adaptive sampling + prior tempering handle small shards correctly.
        if n_total >= 1_000_000_000:  # 1B+ points
            base = 10_000  # Large datasets: moderate shards for combination overhead
        elif n_total >= 100_000_000:  # 100M+ points
            base = 8_000
        elif n_total >= 50_000_000:  # 50M+ points
            base = 5_000
        elif n_total >= 20_000_000:  # 20M+ points
            base = 5_000
        elif n_total >= 2_000_000:  # 2M+ points
            base = 5_000
        else:
            base = 8_000  # Small datasets: fewer, larger shards
    else:
        # Static mode (3 params) - simpler gradients, ~2x larger shards than laminar_flow.
        # NUTS is still O(n) per leapfrog step; 100K base caused 2h+ shard timeouts (Feb 2026).
        if n_total >= 100_000_000:  # 100M+ points
            base = 20_000  # ~5K shards
        elif n_total >= 50_000_000:  # 50M+ points
            base = 15_000  # ~3.3K shards
        else:
            base = 10_000  # ~2x laminar_flow base; yields 5-6K after scaling

    # Apply angle-aware scaling
    scaled_base = int(base * angle_factor)

    # Enforce minimum shard size to avoid data-starved shards (Jan 2026 fix)
    if analysis_mode == "laminar_flow":
        scaled_base = max(scaled_base, MIN_SHARD_SIZE_LAMINAR)
    else:
        scaled_base = max(scaled_base, MIN_SHARD_SIZE_STATIC)

    # Log angle-aware scaling if factor < 1.0
    if angle_factor < 1.0:
        logger.info(
            f"Angle-aware shard sizing: n_phi={n_phi} -> factor={angle_factor:.1f}, "
            f"base={base:,} -> scaled={scaled_base:,} (min={MIN_SHARD_SIZE_LAMINAR if analysis_mode == 'laminar_flow' else MIN_SHARD_SIZE_STATIC:,})"
        )

    # Apply iteration-aware scaling (Feb 2026):
    # Base sizes are calibrated for 2000 default iterations.
    # When users configure more iterations, shrink shards proportionally.
    clamped_ratio = max(0.25, min(2.0, iteration_ratio))
    if clamped_ratio != 1.0:
        pre_iter = scaled_base
        scaled_base = int(scaled_base * clamped_ratio)
        min_size = (
            MIN_SHARD_SIZE_LAMINAR
            if analysis_mode == "laminar_flow"
            else MIN_SHARD_SIZE_STATIC
        )
        scaled_base = max(scaled_base, min_size)
        if scaled_base != pre_iter:
            logger.info(
                f"Iteration-aware shard sizing: ratio={clamped_ratio:.2f} "
                f"(default/actual iterations), shard_size {pre_iter:,} -> {scaled_base:,}"
            )

    # Cap shard count to prevent memory exhaustion during combination
    estimated_shards = n_total // scaled_base
    if estimated_shards > max_shards:
        # Increase shard size to respect max_shards limit
        adjusted = (n_total + max_shards - 1) // max_shards
        # Don't go too large for laminar_flow - cap at 100K for runtime
        if analysis_mode == "laminar_flow":
            MAX_LAMINAR_SHARD = 100_000  # 100K max to keep runtime reasonable
            if adjusted > MAX_LAMINAR_SHARD:
                final_shards = n_total // MAX_LAMINAR_SHARD
                # Warn user: need more memory for very large laminar_flow datasets
                logger.warning(
                    f"Dataset ({n_total:,} points) requires {final_shards:,} shards with "
                    f"max {MAX_LAMINAR_SHARD:,} pts/shard. "
                    f"Ensure sufficient memory (~{final_shards * 6 // 1000}GB) for shard combination."
                )
            adjusted = min(adjusted, MAX_LAMINAR_SHARD)
        return adjusted

    return scaled_base


def _cap_laminar_max_points(max_points_per_shard: int, logger) -> int:
    """Guard rails for laminar_flow to keep shards within reasonable runtime.

    Caps overly large user values that would routinely exceed per-shard timeouts.
    """

    # Allow larger shards for smoke runs; still guard against runaway values.
    cap = 3_000_000
    if max_points_per_shard > cap:
        logger.warning(
            f"max_points_per_shard={max_points_per_shard:,} is high for laminar_flow; "
            f"capping to {cap:,} to keep per-shard runtime tractable"
        )
        return cap
    return max_points_per_shard


def _compute_suggested_timeout(
    *,
    cost_per_shard: int,
    max_timeout: int,
    secs_per_unit: float = 5.0e-5,
    safety_factor: float = 5.0,
    min_timeout: int = 600,
) -> tuple[int, bool]:
    """Derive a timeout (seconds) from shard cost with clamping.

    cost_per_shard = num_chains * (num_warmup + num_samples) * max_points_per_shard

    Note: secs_per_unit=5.0e-5 with safety_factor=5.0 provides ~2.5x headroom
    above observed real-world runtimes to handle variance across different machines.

    Returns
    -------
    tuple[int, bool]
        (clamped timeout in seconds, whether the raw estimate exceeded max_timeout).
    """

    raw = safety_factor * secs_per_unit * cost_per_shard
    exceeded = raw > max_timeout
    clamped = min(max_timeout, max(min_timeout, raw))
    return int(clamped), exceeded


def _fmt_time(secs: float) -> str:
    """Format time nicely for display."""
    if secs < 60:
        return f"{secs:.0f}s"
    elif secs < 3600:
        return f"{secs / 60:.1f}min"
    else:
        return f"{secs / 3600:.1f}h"


def _estimate_n_workers() -> int:
    """Estimate the number of workers that will be used by the multiprocessing backend.

    This mirrors the logic in MultiprocessingBackend.__init__ to provide
    accurate runtime estimates before the backend is instantiated.

    Returns
    -------
    int
        Estimated number of worker processes.
    """
    import multiprocessing as mp

    # Try to get physical core count (same logic as multiprocessing backend)
    try:
        logical_cores = mp.cpu_count()
    except NotImplementedError:
        logical_cores = 4  # Conservative default

    # Estimate physical cores (assume 2 threads per core for HT)
    physical_cores_estimate = max(1, logical_cores // 2)

    # Reserve 1 core for main process (same as backend)
    n_workers = max(1, physical_cores_estimate - 1)

    return n_workers


def _log_runtime_estimate(
    logger,
    n_shards: int,
    n_chains: int,
    n_warmup: int,
    n_samples: int,
    avg_points_per_shard: int,
    n_workers: int | None = None,
    analysis_mode: str = "static",
    per_shard_timeout: int = 7200,
) -> float:
    """Log estimated CMC runtime for user awareness.

    Provides rough estimates based on empirical observations:
    - JIT compilation: scales with shard size (~45s for 5K, ~180s for 60K+)
    - MCMC step: ~0.1-0.5s per iteration (varies with point count)

    Returns
    -------
    float
        Estimated total runtime in seconds.
    """
    # Estimate worker count if not provided
    if n_workers is None:
        n_workers = _estimate_n_workers()

    # Estimate per-shard time
    # JIT overhead scales with shard size: larger shards compile bigger XLA graphs
    jit_overhead_per_shard = 45 + (avg_points_per_shard / 10_000) * 20
    iterations_per_shard = n_chains * (n_warmup + n_samples)

    # MCMC step time scales roughly with point count
    # Empirical: ~0.0001s per point per iteration for moderate complexity
    base_secs_per_iteration = 0.2 + (avg_points_per_shard / 100_000) * 0.3

    # Analysis mode factor - laminar_flow has more parameters and complexity
    mode_factor = 1.5 if analysis_mode == "laminar_flow" else 1.0
    secs_per_iteration = base_secs_per_iteration * mode_factor

    sampling_time_per_shard = iterations_per_shard * secs_per_iteration

    total_per_shard = jit_overhead_per_shard + sampling_time_per_shard

    # Parallel execution estimate
    batches = (n_shards + n_workers - 1) // n_workers
    total_parallel = batches * total_per_shard

    logger.info(
        f"Runtime estimate: {_fmt_time(total_parallel)} total "
        f"({n_shards} shards / {n_workers} workers, "
        f"~{_fmt_time(total_per_shard)}/shard with {iterations_per_shard:,} iterations)"
    )

    # Warn if per-shard estimate is close to timeout
    if total_per_shard > 0.8 * per_shard_timeout:
        margin_pct = (per_shard_timeout - total_per_shard) / per_shard_timeout * 100
        logger.warning(
            f"Per-shard estimate ({_fmt_time(total_per_shard)}) is within "
            f"{margin_pct:.0f}% of timeout ({_fmt_time(per_shard_timeout)}). "
            f"Timeouts are likely if JIT compilation or resource contention add overhead."
        )

    return total_parallel


def _log_runtime_comparison(
    logger,
    estimated_time: float,
    actual_time: float,
) -> None:
    """Log comparison of estimated vs actual runtime.

    Parameters
    ----------
    logger
        Logger instance.
    estimated_time : float
        Estimated runtime in seconds.
    actual_time : float
        Actual runtime in seconds.
    """
    if estimated_time <= 0:
        return

    accuracy = actual_time / estimated_time * 100

    if accuracy < 50:
        status = "much faster than estimated"
    elif accuracy < 90:
        status = "faster than estimated"
    elif accuracy <= 110:
        status = "close to estimate"
    elif accuracy <= 150:
        status = "slower than estimated"
    else:
        status = "much slower than estimated"

    logger.info(
        f"Runtime: {_fmt_time(actual_time)} actual vs {_fmt_time(estimated_time)} estimated "
        f"({accuracy:.0f}% - {status})"
    )

    # Provide suggestions if significantly off
    if accuracy > 200:
        logger.info("  -> Consider reducing num_samples or num_chains for faster runs")
    elif accuracy < 30:
        logger.info(
            "  -> Actual runtime much faster than expected - estimate may be conservative"
        )


def _infer_time_step(t1: np.ndarray, t2: np.ndarray) -> float:
    """Infer time step from pooled t1/t2 arrays (seconds).

    Uses the median positive difference across all unique time points to avoid
    being skewed by repeated values from meshgrid flattening.
    """
    time_values = np.unique(np.concatenate([np.asarray(t1), np.asarray(t2)]))
    if time_values.size < 2:
        logger.warning(
            f"Fewer than 2 unique time values ({time_values.size}); "
            "falling back to dt=1.0"
        )
        return 1.0

    diffs = np.diff(time_values)
    positive_diffs = diffs[diffs > 0]
    if positive_diffs.size == 0:
        logger.warning("No positive time differences found; falling back to dt=1.0")
        return 1.0

    return float(np.median(positive_diffs))


def _populate_reparam_priors(
    nlsq_prior_config: dict[str, Any],
    nlsq_values: dict[str, float],
    nlsq_uncertainties: dict[str, float] | None,
    t_ref: float,
    log: Any,
) -> None:
    """Populate reparameterized prior values in nlsq_prior_config.

    This is deferred from section 2e because t_ref is not available
    until the time grid is constructed in section 4.
    """
    reparam_vals, reparam_uncs = transform_nlsq_to_reparam_space(
        nlsq_values, nlsq_uncertainties, t_ref
    )
    nlsq_prior_config["reparam_values"] = reparam_vals
    nlsq_prior_config["reparam_uncertainties"] = reparam_uncs
    if reparam_vals:
        log.info(
            "NLSQ reparam values: "
            + ", ".join(f"{k}={v:.4g}" for k, v in reparam_vals.items())
        )


[docs] def fit_mcmc_jax( data: np.ndarray, t1: np.ndarray, t2: np.ndarray, phi: np.ndarray, q: float, L: float, analysis_mode: str, method: str = "mcmc", cmc_config: dict[str, Any] | None = None, initial_values: dict[str, float] | None = None, parameter_space: ParameterSpace | None = None, dt: float | None = None, output_dir: Path | str | None = None, progress_bar: bool = True, run_id: str | None = None, nlsq_result: dict | None = None, **kwargs, ) -> CMCResult: """Run CMC (Consensus Monte Carlo) analysis on XPCS data. This function signature matches the CLI call in cli/commands.py:1201. Parameters ---------- data : np.ndarray Pooled C2 correlation data, shape (n_total,). t1 : np.ndarray Pooled time coordinates t1, shape (n_total,). t2 : np.ndarray Pooled time coordinates t2, shape (n_total,). phi : np.ndarray Pooled phi angles, shape (n_total,). q : float Wavevector magnitude. L : float Stator-rotor gap length (nm). analysis_mode : str Analysis mode: "static" or "laminar_flow". method : str Method identifier (always "mcmc" for CMC). cmc_config : dict[str, Any] | None CMC configuration from ConfigManager.get_cmc_config(). initial_values : dict[str, float] | None Initial parameter values from ConfigManager.get_initial_parameters(). parameter_space : ParameterSpace Parameter space with bounds and priors from ParameterSpace.from_config(). dt : float | None Time step for physics model. If None, inferred from pooled time arrays. output_dir : Path | str | None Output directory for saving results. progress_bar : bool Whether to show progress bar during sampling. run_id : str | None Optional identifier used to correlate logs across shards/backends. nlsq_result : dict | None Optional NLSQ result dictionary for warm-start priors. When provided, builds informative priors centered on NLSQ estimates, improving convergence speed and reducing divergences. Should contain parameter values and optionally uncertainties (see extract_nlsq_values_for_cmc). **kwargs Additional keyword arguments (for compatibility). Returns ------- CMCResult Complete result with posterior samples and diagnostics. Raises ------ ValueError If data validation fails. RuntimeError If MCMC sampling fails. Examples -------- >>> from homodyne.optimization.cmc import fit_mcmc_jax >>> result = fit_mcmc_jax( ... data=c2_pooled, ... t1=t1_pooled, ... t2=t2_pooled, ... phi=phi_pooled, ... q=0.01, ... L=2000000.0, ... analysis_mode="laminar_flow", ... method="mcmc", ... cmc_config=config.get_cmc_config(), ... initial_values=config.get_initial_parameters(), ... parameter_space=parameter_space, ... ) >>> print(result.convergence_status) converged """ try: return _fit_mcmc_jax_impl( data=data, t1=t1, t2=t2, phi=phi, q=q, L=L, analysis_mode=analysis_mode, method=method, cmc_config=cmc_config, initial_values=initial_values, parameter_space=parameter_space, dt=dt, output_dir=output_dir, progress_bar=progress_bar, run_id=run_id, nlsq_result=nlsq_result, **kwargs, ) except Exception as exc: log_exception(logger, exc, context={"phase": "CMC fitting"}) raise
def _fit_mcmc_jax_impl( data: np.ndarray, t1: np.ndarray, t2: np.ndarray, phi: np.ndarray, q: float, L: float, analysis_mode: str, method: str = "mcmc", cmc_config: dict[str, Any] | None = None, initial_values: dict[str, float] | None = None, parameter_space: ParameterSpace | None = None, dt: float | None = None, output_dir: Path | str | None = None, progress_bar: bool = True, run_id: str | None = None, nlsq_result: dict | None = None, **kwargs, ) -> CMCResult: """Internal implementation of fit_mcmc_jax.""" start_time = time.perf_counter() run_identifier = run_id or datetime.now().strftime("%Y%m%d_%H%M%S") run_logger = with_context(logger, run=run_identifier, analysis="cmc") # Normalize analysis mode to canonical strings if "static" in analysis_mode.lower(): analysis_mode = "static" elif "laminar" in analysis_mode.lower(): analysis_mode = "laminar_flow" run_logger.info( f"Starting CMC analysis: {len(data):,} points, mode={analysis_mode}, q={q:.4f}" ) # ========================================================================= # 1. Parse configuration # ========================================================================= if cmc_config is None: cmc_config = {} config = CMCConfig.from_dict(cmc_config) config.run_id = getattr(config, "run_id", None) or run_identifier # Log configuration run_logger.info( f"CMC config: {config.num_chains} chains, " f"{config.num_warmup} warmup, {config.num_samples} samples" ) # NLSQ warm-start state (initialized here to avoid temporal coupling). # Populated in section 2e if nlsq_result is provided and priors are enabled. nlsq_prior_config: dict[str, Any] | None = None nlsq_values: dict[str, float] = {} nlsq_uncertainties: dict[str, float] | None = None # ========================================================================= # 2. Prepare and validate data # ========================================================================= prepared = prepare_mcmc_data(data, t1, t2, phi) # ========================================================================= # 2b. Determine per-angle mode and select appropriate model (v2.18.0+) # ========================================================================= # Extract per-angle mode from NLSQ result for parameterization parity (Jan 2026) nlsq_per_angle_mode = None if nlsq_result is not None: # Try to extract per_angle_mode from NLSQ result metadata metadata = ( nlsq_result.get("metadata", {}) if isinstance(nlsq_result, dict) else {} ) nlsq_per_angle_mode = metadata.get("per_angle_mode") if nlsq_per_angle_mode: run_logger.info( f"NLSQ warm-start detected per_angle_mode='{nlsq_per_angle_mode}' from NLSQ result" ) effective_per_angle_mode = config.get_effective_per_angle_mode( prepared.n_phi, nlsq_per_angle_mode=nlsq_per_angle_mode, has_nlsq_warmstart=(nlsq_result is not None), ) # Only use reparameterization for auto mode + laminar_flow (not constant_averaged) # When Fix 1 is active (NLSQ warm-start → constant_averaged), the reparameterized # model is NOT used because effective_per_angle_mode is "constant_averaged". # Reparameterization is the fallback for runs without NLSQ warm-start. use_reparam = ( effective_per_angle_mode == "auto" and analysis_mode == "laminar_flow" and (config.reparameterization_d_total or config.reparameterization_log_gamma) ) xpcs_model = get_xpcs_model( effective_per_angle_mode, use_reparameterization=use_reparam, ) run_logger.info( f"CMC per-angle mode: {config.per_angle_mode} -> {effective_per_angle_mode} " f"(n_phi={prepared.n_phi}, threshold={config.constant_scaling_threshold})" ) # ========================================================================= # 2c. Estimate fixed per-angle scaling for constant mode (v2.18.0+) # ========================================================================= # Mode semantics (param counts depend on analysis_mode: static=3, laminar_flow=7): # - "auto": xpcs_model_averaged SAMPLES single averaged contrast/offset # No fixed arrays needed - the model samples them # - "constant": xpcs_model_constant uses FIXED per-angle arrays # Requires fixed_contrast/fixed_offset arrays from quantile estimation # - "individual": xpcs_model_scaled SAMPLES per-angle contrast/offset # No fixed arrays needed - the model samples them n_physical = 7 if analysis_mode == "laminar_flow" else 3 fixed_contrast: jnp.ndarray | None = None fixed_offset: jnp.ndarray | None = None if effective_per_angle_mode in ("constant", "constant_averaged"): # CONSTANT/CONSTANT_AVERAGED mode: Use FIXED values from quantile estimation # Get contrast/offset bounds from parameter_space contrast_bounds = parameter_space.get_bounds("contrast") offset_bounds = parameter_space.get_bounds("offset") # Estimate per-angle contrast/offset from quantile analysis run_logger.info( f"{effective_per_angle_mode.upper()} mode: Estimating FIXED scaling from data quantiles..." ) scaling_estimates = estimate_per_angle_scaling( c2_data=prepared.data, t1=prepared.t1, t2=prepared.t2, phi_indices=prepared.phi_indices, n_phi=prepared.n_phi, contrast_bounds=contrast_bounds, offset_bounds=offset_bounds, log=run_logger, ) # Build per-angle arrays from estimates contrast_per_angle = np.array( [scaling_estimates[f"contrast_{i}"] for i in range(prepared.n_phi)] ) offset_per_angle = np.array( [scaling_estimates[f"offset_{i}"] for i in range(prepared.n_phi)] ) # NaN guard: reject corrupted per-angle scaling estimates if np.any(np.isnan(contrast_per_angle)): raise ValueError( "NaN detected in per-angle contrast estimates. " "Check data quality for angles with insufficient points." ) if np.any(np.isnan(offset_per_angle)): raise ValueError( "NaN detected in per-angle offset estimates. " "Check data quality for angles with insufficient points." ) fixed_contrast = jnp.array(contrast_per_angle) fixed_offset = jnp.array(offset_per_angle) if effective_per_angle_mode == "constant_averaged": # CONSTANT_AVERAGED mode: Model will internally average these run_logger.info( f"CONSTANT_AVERAGED mode: Using FIXED AVERAGED scaling (NLSQ parity):\n" f" contrast: per-angle range=[{np.nanmin(contrast_per_angle):.4f}, {np.nanmax(contrast_per_angle):.4f}], " f"avg={np.nanmean(contrast_per_angle):.4f} (will be used)\n" f" offset: per-angle range=[{np.nanmin(offset_per_angle):.4f}, {np.nanmax(offset_per_angle):.4f}], " f"avg={np.nanmean(offset_per_angle):.4f} (will be used)\n" f" Parameters: {n_physical} physical + 1 sigma = {n_physical + 1} total (scaling fixed, averaged)" ) else: # CONSTANT mode: Different value per angle run_logger.info( f"CONSTANT mode: Using FIXED per-angle scaling (NOT sampled):\n" f" contrast: mean={np.nanmean(contrast_per_angle):.4f}, " f"range=[{np.nanmin(contrast_per_angle):.4f}, {np.nanmax(contrast_per_angle):.4f}]\n" f" offset: mean={np.nanmean(offset_per_angle):.4f}, " f"range=[{np.nanmin(offset_per_angle):.4f}, {np.nanmax(offset_per_angle):.4f}]\n" f" Parameters: {n_physical} physical + 1 sigma = {n_physical + 1} total (scaling fixed)" ) elif effective_per_angle_mode == "auto": # AUTO mode: xpcs_model_averaged will SAMPLE single averaged contrast/offset # No fixed arrays needed - log the expected behavior run_logger.info( f"AUTO mode: Will SAMPLE averaged contrast/offset (NLSQ parity):\n" f" Parameters: 2 averaged scaling + {n_physical} physical + 1 sigma = {n_physical + 3} total" ) else: # INDIVIDUAL mode: xpcs_model_scaled will SAMPLE per-angle contrast/offset run_logger.info( f"INDIVIDUAL mode: Will SAMPLE per-angle contrast/offset:\n" f" Parameters: {prepared.n_phi * 2} per-angle + {n_physical} physical + 1 sigma = " f"{prepared.n_phi * 2 + n_physical + 1} total" ) # Log initial values if provided if initial_values: run_logger.info( f"Initial values: {', '.join(f'{k}={v:.4g}' for k, v in list(initial_values.items())[:5])}..." ) else: run_logger.info("No initial values provided, using midpoint defaults") # ========================================================================= # 2d. NLSQ Warm-Start Validation (Jan 2026) # ========================================================================= # Warn or error if running laminar_flow without NLSQ warm-start # The 7-parameter laminar_flow model spans 6+ orders of magnitude # (D0 ~ 1e4, gamma_dot_t0 ~ 1e-3) and NUTS struggles to adapt without # good initial values. Cold-start runs showed 28% divergence rates vs <5% # with NLSQ warm-start. no_warmstart = nlsq_result is None and not initial_values if no_warmstart and analysis_mode == "laminar_flow": require_warmstart = config.require_nlsq_warmstart if require_warmstart: raise ValueError( "CMC WARM-START REQUIRED: laminar_flow mode requires NLSQ warm-start " "when require_nlsq_warmstart=True. Run NLSQ first and pass nlsq_result " "to fit_mcmc_jax(), or set require_nlsq_warmstart=False in CMC config." ) else: run_logger.warning( "CMC WARM-START ADVISORY: Running laminar_flow without NLSQ warm-start. " "This is strongly discouraged for production use because:\n" " 1. 7 parameters span 6+ orders of magnitude (D0~1e4, gamma_dot_t0~1e-3)\n" " 2. NUTS adaptation may waste warmup exploring implausible regions\n" " 3. Higher divergence rates and inflated posterior uncertainty expected\n" "Recommendation: Run NLSQ first and pass nlsq_result to fit_mcmc_jax()\n" "To enforce this, set validation.require_nlsq_warmstart=true in CMC config." ) # ========================================================================= # 2e. NLSQ Warm-Start (Jan 2026): Use NLSQ results for better init values # ========================================================================= if nlsq_result is not None: from homodyne.optimization.cmc.priors import extract_nlsq_values_for_cmc nlsq_values, nlsq_uncertainties = extract_nlsq_values_for_cmc(nlsq_result) # Override initial_values with NLSQ estimates (if not already provided) if initial_values is None: initial_values = {} # Merge NLSQ values into initial_values (NLSQ takes precedence for physical params) physical_params = ["D0", "alpha", "D_offset"] if analysis_mode == "laminar_flow": physical_params.extend( ["gamma_dot_t0", "beta", "gamma_dot_t_offset", "phi0"] ) nlsq_used = [] for param in physical_params: if param in nlsq_values: val = nlsq_values[param] # P2-3: Clip NLSQ values to CMC parameter bounds to prevent # out-of-bounds initial values from crashing NUTS initialization. if parameter_space is not None: try: lo, hi = parameter_space.get_bounds(param) if val < lo or val > hi: clipped = max(lo, min(hi, val)) run_logger.warning( f"NLSQ value {param}={val:.4g} outside CMC bounds " f"[{lo:.4g}, {hi:.4g}], clipping to {clipped:.4g}" ) val = clipped except KeyError: pass # Parameter not in space (e.g., reparameterized) initial_values[param] = val nlsq_used.append(param) run_logger.info( f"NLSQ warm-start: Using NLSQ estimates for {len(nlsq_used)} params: " f"{', '.join(f'{p}={nlsq_values[p]:.4g}' for p in nlsq_used[:5])}" + ("..." if len(nlsq_used) > 5 else "") ) # Log NLSQ uncertainties if available (useful for posterior comparison) if nlsq_uncertainties: unc_str = ", ".join( f"{p}+/-{nlsq_uncertainties[p]:.4g}" for p in nlsq_used[:5] if p in nlsq_uncertainties ) if unc_str: run_logger.info(f"NLSQ uncertainties: {unc_str}") # ===================================================================== # Feb 2026: Build NLSQ-informed prior config for model functions # ===================================================================== # Store as plain dict (serializable for multiprocessing workers) # Distribution objects aren't picklable, so we pass the raw values # and let each model function build its own TruncatedNormal priors. if config.use_nlsq_informed_priors: nlsq_prior_config = { "values": nlsq_values, "uncertainties": nlsq_uncertainties, "width_factor": config.nlsq_prior_width_factor, "reparam_values": {}, "reparam_uncertainties": {}, } run_logger.info( f"NLSQ-informed priors: enabled (width_factor={config.nlsq_prior_width_factor})" ) # NOTE: reparam_values/uncertainties are populated later (after t_ref # is computed from the time grid). See "Populate reparam priors" below. # ========================================================================= # 3. Determine if CMC sharding is needed # ========================================================================= def _int_like(val) -> bool: return isinstance(val, int) or (isinstance(val, str) and val.isdigit()) forced_shards = _int_like(config.num_shards) or _int_like( config.max_points_per_shard ) use_cmc = config.should_enable_cmc(prepared.n_total, analysis_mode) or forced_shards # Resolve max_points_per_shard - critical for NUTS tractability # Scale inversely with parameter count: more params = fewer points per shard # Also scale inversely with iteration count relative to default (Feb 2026) max_points_setting = config.max_points_per_shard _DEFAULT_ITERATIONS = 2000 # CMCConfig defaults: 500 warmup + 1500 samples actual_iterations = config.num_warmup + config.num_samples iteration_ratio = _DEFAULT_ITERATIONS / max(1, actual_iterations) max_per_shard = _resolve_max_points_per_shard( analysis_mode, prepared.n_total, max_points_setting, n_phi=prepared.n_phi, iteration_ratio=iteration_ratio, ) if analysis_mode == "laminar_flow": max_per_shard = _cap_laminar_max_points(max_per_shard, run_logger) if max_points_setting is None or max_points_setting == "auto": run_logger.info( f"Auto-selected max_points_per_shard={max_per_shard} for {analysis_mode} mode " f"(n_total={prepared.n_total:,}, n_phi={prepared.n_phi})" ) # Derive a suggested per-shard timeout from cost cost_per_shard = ( config.num_chains * (config.num_warmup + config.num_samples) * max_per_shard ) suggested_timeout, cost_exceeded = _compute_suggested_timeout( cost_per_shard=cost_per_shard, max_timeout=config.per_shard_timeout, ) run_logger.info( f"Suggested per-shard timeout: {suggested_timeout}s (cost={cost_per_shard:,}, " f"chains={config.num_chains}, warmup+samples={config.num_warmup + config.num_samples}, " f"max_points_per_shard={max_per_shard:,}, clamp=[600,{config.per_shard_timeout}])" ) if cost_exceeded: run_logger.warning( f"Per-shard cost ({cost_per_shard:,}) exceeds timeout budget. " f"Shards may timeout at {config.per_shard_timeout}s. " f"Consider reducing num_warmup/num_samples or max_points_per_shard." ) requested_shards = int(config.num_shards) if _int_like(config.num_shards) else None sharding_mode = config.sharding_strategy # CRITICAL FIX (Jan 2026): Force random sharding for multi-angle datasets with global parameters. # Stratified sharding (by angle) creates disjoint posteriors that cannot be combined # by Consensus MC for global parameters (like D0, alpha, phi0). if use_cmc and prepared.n_phi > 1 and sharding_mode == "stratified": run_logger.warning( "Overriding sharding_strategy='stratified' -> 'random' for multi-angle data. " "Stratified sharding violates Consensus MC assumptions for global parameters." ) sharding_mode = "random" # Safety constant: NUTS is O(n) per leapfrog step — never run 100K+ points. _SINGLE_SHARD_HARD_LIMIT = 100_000 # CLAUDE.md: "Never use 100K+" if sharding_mode == "stratified" and use_cmc: # Shard by phi angle (stratified) - Only valid for disjoint models (no global params) # or single-angle data (where n_phi=1, handled above) num_shards = ( requested_shards if requested_shards is not None else config.get_num_shards(prepared.n_total, prepared.n_phi) ) shards = shard_data_stratified( prepared, num_shards, max_points_per_shard=max_per_shard ) total_shard_points = sum(s.n_total for s in shards) run_logger.info( f"Using CMC with {len(shards)} shards (stratified by phi), " f"{total_shard_points:,} total points" ) estimated_runtime = _log_runtime_estimate( run_logger, n_shards=len(shards), n_chains=config.num_chains, n_warmup=config.num_warmup, n_samples=config.num_samples, avg_points_per_shard=total_shard_points // len(shards), analysis_mode=analysis_mode, per_shard_timeout=suggested_timeout, ) elif use_cmc: # Sharding strategy selection (Jan 2026 enhancement): # - Multi-angle (n_phi > 1): Use angle-balanced sharding for consistent posteriors # - Single-angle: Use random sharding (i.i.d. statistically correct for Consensus MC) # # Dynamic max_shards based on analysis mode and data size. # laminar_flow needs more shards (smaller size) due to O(n) NUTS cost # with 7 physical params vs 3 for static. if analysis_mode == "laminar_flow": max_shards_for_mode = 1000 else: max_shards_for_mode = 500 if prepared.n_phi > 1: # Angle-balanced sharding ensures each shard has proportional angle coverage # This prevents heterogeneous posteriors (e.g., D_offset CV=1.58) shards = shard_data_angle_balanced( prepared, num_shards=requested_shards, # Honor explicit shards when provided max_points_per_shard=max_per_shard, max_shards=max_shards_for_mode, min_angle_coverage=0.8, # Require 80% angle coverage per shard ) sharding_desc = "angle-balanced" else: # Single-angle: random sharding is fine shards = shard_data_random( prepared, num_shards=requested_shards, # Honor explicit shards when provided max_points_per_shard=max_per_shard, max_shards=max_shards_for_mode, # Dynamic cap matching multi-angle path ) sharding_desc = "random" total_shard_points = sum(s.n_total for s in shards) run_logger.info( f"Using CMC with {len(shards)} shards ({sharding_desc}), " f"{total_shard_points:,} total points" ) estimated_runtime = _log_runtime_estimate( run_logger, n_shards=len(shards), n_chains=config.num_chains, n_warmup=config.num_warmup, n_samples=config.num_samples, avg_points_per_shard=total_shard_points // len(shards), analysis_mode=analysis_mode, per_shard_timeout=suggested_timeout, ) else: shards = None estimated_runtime = 0.0 # No estimate for single-shard # Non-CMC single-shard: check dataset size if prepared.n_total > _SINGLE_SHARD_HARD_LIMIT: run_logger.warning( f"Single-shard MCMC requested but dataset ({prepared.n_total:,} pts) " f"exceeds hard limit ({_SINGLE_SHARD_HARD_LIMIT:,}). " f"Forcing CMC sharding with max_points_per_shard={max_per_shard:,}." ) shards = shard_data_random( prepared, num_shards=requested_shards, max_points_per_shard=max_per_shard, max_shards=500, ) total_shard_points = sum(s.n_total for s in shards) run_logger.info( f"Safety-fallback: Using CMC with {len(shards)} shards (random), " f"{total_shard_points:,} total points" ) estimated_runtime = _log_runtime_estimate( run_logger, n_shards=len(shards), n_chains=config.num_chains, n_warmup=config.num_warmup, n_samples=config.num_samples, avg_points_per_shard=total_shard_points // len(shards), analysis_mode=analysis_mode, per_shard_timeout=suggested_timeout, ) else: run_logger.info( f"Using single-shard MCMC ({prepared.n_total:,} pts, no CMC sharding)" ) # P1-7: Safety guard for forced single-shard CMC on large datasets. # If user set num_shards=1 explicitly, the sharding path creates 1 shard, # but the single shard runs NUTS on all points — O(n) per leapfrog step. if shards is not None and len(shards) == 1: shard_size = shards[0].n_total if shard_size > _SINGLE_SHARD_HARD_LIMIT: run_logger.warning( f"Forced single-shard CMC but shard ({shard_size:,} pts) exceeds " f"hard limit ({_SINGLE_SHARD_HARD_LIMIT:,}). Re-sharding with " f"max_points_per_shard={max_per_shard:,}." ) shards = shard_data_random( prepared, num_shards=None, max_points_per_shard=max_per_shard, max_shards=500, ) # ========================================================================= # 4. Build model function # ========================================================================= # CRITICAL FIX (Dec 2025): Construct time_grid with PROPER dt spacing # Previously used np.unique(t1, t2) which gave incorrect grid density # when data is subsampled or pooled from shards with different time points. # # The physics integration (trapezoidal cumsum) depends critically on grid density: # - With dt=0.1s and t_max=100s, need 1001 points for correct physics # - Using np.unique gave variable n_points (e.g., 201 with subsampled data) # - This caused up to 26% error in C2 values vs NLSQ (see scripts/compare_nlsq_cmc_c2.py) # # Fix: Construct time_grid from config dt, NOT from data unique values # First determine dt to use (config dt takes precedence) inferred_dt = _infer_time_step(prepared.t1, prepared.t2) dt_used = dt if dt is not None else inferred_dt if not np.isfinite(dt_used) or dt_used <= 0: dt_used = inferred_dt if np.isfinite(inferred_dt) and inferred_dt > 0 else 0.1 run_logger.warning( f"Invalid dt provided; using inferred fallback dt={dt_used:.6g} seconds" ) else: # Check for mismatch between config dt and data dt rel_diff = ( abs(dt_used - inferred_dt) / max(inferred_dt, 1e-12) if np.isfinite(inferred_dt) and inferred_dt > 0 else 0.0 ) if dt is None: run_logger.info(f"Inferred dt from pooled times: dt={dt_used:.6g} seconds") elif rel_diff > 1e-2: # >1% mismatch is significant # CRITICAL FIX (Jan 2026): Prioritize DATA TRUTH over config for dt # If config says dt=0.1s but data says dt=1e-5s, using dt=0.1s constructs # a coarse grid that collapses all data to index 0 (g1=1.0, no decay). # We MUST use the inferred_dt to match the actual data timestamps. run_logger.warning( f"[CMC] CRITICAL dt mismatch detected!\n" f" Config dt: {dt_used:.6g}s\n" f" Inferred dt: {inferred_dt:.6g}s\n" f" Mismatch: {rel_diff:.1%} (>1%)\n" f"Action: OVERRIDING config dt with Inferred dt to prevent physics collapse.\n" f"Please check your configuration or data timestamps." ) dt_used = inferred_dt elif rel_diff > 1e-4: run_logger.info( f"[CMC] Minor dt mismatch ({rel_diff:.2%}): {dt_used:.6g}s vs {inferred_dt:.6g}s. Using config dt." ) # CRITICAL: Construct time_grid with CORRECT dt spacing to match NLSQ physics # The grid must have the same density as NLSQ (e.g., dt=0.1s gives 1001 points for [0, 100]) t1_np = np.asarray(prepared.t1) t2_np = np.asarray(prepared.t2) t_min = 0.0 # Always start from t=0 for consistent integration t_max = float(max(float(np.nanmax(t1_np)), float(np.nanmax(t2_np)))) n_time_points = int(round(t_max / dt_used)) + 1 # Guard against OOM: cap time grid at 100K points. For typical XPCS # experiments (dt~0.1s, t_max~100s) this gives ~1001 points. # High-frequency data (dt~1e-5, t_max~10000) would request 1B points. MAX_TIME_GRID_POINTS = 100_000 if n_time_points > MAX_TIME_GRID_POINTS: run_logger.warning( f"[CMC] Time grid would require {n_time_points:,} points " f"(dt={dt_used:.6g}, t_max={t_max:.6g}). " f"Capping at {MAX_TIME_GRID_POINTS:,} to prevent OOM. " f"Effective dt will be {t_max / (MAX_TIME_GRID_POINTS - 1):.6g}s." ) n_time_points = MAX_TIME_GRID_POINTS time_grid_np = np.linspace(t_min, t_max, n_time_points) time_grid = jnp.array(time_grid_np) # Compute reference time for reparameterization (geometric mean of time range). # dt_used is validated positive above; t_max > 0 because prepared data is # non-empty (checked in prepare_mcmc_data). Guard defensively for future refactors. t_ref = compute_t_ref(dt_used, t_max, fallback_value=1.0) if t_ref == 1.0 and ( dt_used <= 0 or t_max <= 0 or not np.isfinite(dt_used) or not np.isfinite(t_max) ): run_logger.info("[CMC] Reference time: t_ref=1.0 (fallback)") else: run_logger.info( f"[CMC] Reference time: t_ref={t_ref:.6g}s " f"(sqrt({dt_used:.6g} * {t_max:.6g}))" ) # Populate reparam priors now that t_ref is available (deferred from section 2e) if nlsq_prior_config is not None: _populate_reparam_priors( nlsq_prior_config, nlsq_values, nlsq_uncertainties, t_ref, run_logger ) # Log time_grid construction details run_logger.info( f"[CMC] time_grid constructed with dt={dt_used:.6g}s: " f"n_points={n_time_points}, range=[{t_min:.6g}, {t_max:.6g}]" ) t1_lo, t1_hi = float(np.nanmin(t1_np)), float(np.nanmax(t1_np)) t2_lo, t2_hi = float(np.nanmin(t2_np)), float(np.nanmax(t2_np)) run_logger.info( f"[CMC] Data time ranges: t1=[{t1_lo:.6g}, {t1_hi:.6g}], " f"t2=[{t2_lo:.6g}, {t2_hi:.6g}]" ) # Verify grid spacing matches config dt actual_grid_dt = ( (time_grid_np[1] - time_grid_np[0]) if len(time_grid_np) > 1 else dt_used ) # P2-R6-08: Use relative tolerance; absolute 1e-6 is too tight for large dt # (e.g. dt=1.5s gives float rounding ~2e-16 which compares fine) and too # loose for small dt (e.g. dt=1e-5s capped to 100K grid gives ~0.01s). _grid_rel_diff = abs(actual_grid_dt - dt_used) / max(dt_used, 1e-15) if _grid_rel_diff > 1e-6: run_logger.warning( f"[CMC] Grid spacing {actual_grid_dt:.6g}s differs from config dt={dt_used:.6g}s " f"(relative diff={_grid_rel_diff:.2e})" ) phi_indices_arr = jnp.array(prepared.phi_indices) model_kwargs = { "data": jnp.array(prepared.data), "t1": jnp.array(prepared.t1), "t2": jnp.array(prepared.t2), "phi_unique": jnp.array(prepared.phi_unique), "phi_indices": phi_indices_arr, "q": q, "L": L, "dt": dt_used, "time_grid": time_grid, "analysis_mode": analysis_mode, "parameter_space": parameter_space, "n_phi": prepared.n_phi, "noise_scale": prepared.noise_scale, # P0-1: Pre-compute scaling factors once (pure function of static inputs, # avoids ~50K Python allocations per NUTS leapfrog step). "scalings": ( compute_scaling_factors(parameter_space, prepared.n_phi, analysis_mode) if parameter_space is not None else None ), # P0-2: Pre-compute wavevector constants (depend only on q, L, dt). "wavevector_q_squared_half_dt": jnp.asarray(0.5 * (q**2) * dt_used), "sinc_prefactor": jnp.asarray(0.5 / math.pi * q * L * dt_used), # P1-3: Pre-compute point index array (depends only on data shape). "point_idx": jnp.arange(phi_indices_arr.shape[0], dtype=jnp.int32), } # D2: Pre-compute shard-constant grid quantities for single-shard path only. # Multi-shard: each worker builds its own ShardGrid in _run_shard_worker. if shards is None or len(shards) <= 1: try: from homodyne.core.physics_cmc import precompute_shard_grid model_kwargs["shard_grid"] = precompute_shard_grid( time_grid, model_kwargs["t1"], model_kwargs["t2"], dt_used, ) except ImportError as exc: run_logger.debug(f"precompute_shard_grid not available: {exc}") except (ValueError, RuntimeError) as exc: run_logger.warning( f"precompute_shard_grid failed (non-fatal, using legacy path): {exc}" ) # Add fixed scaling arrays for constant/constant_averaged mode (v2.18.0+) if ( effective_per_angle_mode in ("constant", "constant_averaged") and fixed_contrast is not None ): model_kwargs["fixed_contrast"] = fixed_contrast model_kwargs["fixed_offset"] = fixed_offset # P0-4: Pass global phi_unique so shards can remap shard-local phi_indices # to global indices when looking up fixed_contrast/fixed_offset arrays. model_kwargs["global_phi_unique"] = jnp.array(prepared.phi_unique) # Prior tempering (Feb 2026): pass actual shard count so each shard's model # uses prior^(1/K) via Normal(0, sqrt(K)) instead of Normal(0, 1). # Single-shard mode (num_shards=1) produces identical behavior to untampered priors. if config.prior_tempering and shards is not None: model_kwargs["num_shards"] = len(shards) else: model_kwargs["num_shards"] = 1 # Propagate per_angle_mode for sampler and workers to build correct init dicts model_kwargs["per_angle_mode"] = effective_per_angle_mode # Add NLSQ-informed prior config if available (built in section 2e) if nlsq_prior_config is not None: model_kwargs["nlsq_prior_config"] = nlsq_prior_config # Add reparameterization config and t_ref when active (Feb 2026) if use_reparam: from homodyne.optimization.cmc.reparameterization import ReparamConfig model_kwargs["reparam_config"] = ReparamConfig( enable_d_ref=config.reparameterization_d_total, enable_gamma_ref=config.reparameterization_log_gamma, t_ref=t_ref, ) model_kwargs["t_ref"] = t_ref # DEBUG: Log model_kwargs for diagnosis run_logger.debug( f"[CMC DEBUG] model_kwargs: q={q:.6g}, L={L:.6g}, dt={dt_used:.6g}, " f"n_phi={prepared.n_phi}, noise_scale={prepared.noise_scale:.6g}" ) run_logger.debug(f"[CMC DEBUG] phi_unique: {prepared.phi_unique}") # DEBUG: Compute and log D values at sample times to verify physics if initial_values: D0_init = initial_values.get("D0", 1e10) alpha_init = initial_values.get("alpha", -0.5) D_offset_init = initial_values.get("D_offset", 1e9) # Sample D at a few time points t_samples = np.array([0.0, 1.0, 10.0, 50.0]) t_safe = t_samples + 1e-10 D_samples = D0_init * (t_safe**alpha_init) + D_offset_init run_logger.debug( f"[CMC DEBUG] D(t) at sample times with initial params:\n" f" D0={D0_init:.4g}, alpha={alpha_init:.4g}, D_offset={D_offset_init:.4g}\n" f" t=[0, 1, 10, 50] -> D={D_samples}" ) # Compute expected prefactor wavevector_q_squared_half_dt = 0.5 * (q**2) * dt_used run_logger.debug( f"[CMC DEBUG] Physics prefactor: 0.5*q^2*dt = 0.5*{q}^2*{dt_used} = {wavevector_q_squared_half_dt:.6g}" ) # ========================================================================= # 5. Select backend and run sampling # ========================================================================= stats = None # Only set for single-shard path if shards is not None and len(shards) > 1: # Use parallel backend for CMC backend = select_backend(config) run_logger.info(f"Using backend: {backend.get_name()}") # Enforce timeout only where supported (multiprocessing). Others log advisory. # Build a shallow copy of the config so we never mutate the caller's object. import dataclasses as _dc run_config = _dc.replace(config) # shallow copy; all scalar fields are copied if backend.get_name().startswith("multiprocessing"): effective_timeout = min(config.per_shard_timeout, suggested_timeout) if effective_timeout != config.per_shard_timeout: run_logger.info( f"Applying tighter per_shard_timeout={effective_timeout}s based on shard cost" ) run_config.per_shard_timeout = effective_timeout else: run_logger.warning( f"Backend '{backend.get_name()}' does not enforce per_shard_timeout; " f"suggested={suggested_timeout}s (cap={config.per_shard_timeout}s)" ) # T047: Log shard progress start run_logger.info( f"Starting CMC sampling: {len(shards)} shards, " f"{config.num_chains} chains, {config.num_warmup}+{config.num_samples} samples" f"{' (adaptive per-shard)' if config.adaptive_sampling else ''}" ) mcmc_samples = backend.run( model=xpcs_model, model_kwargs=model_kwargs, config=run_config, # use copy with effective timeout, not caller's config shards=shards, initial_values=initial_values, parameter_space=parameter_space, analysis_mode=analysis_mode, progress_bar=progress_bar, ) run_logger.info(f"CMC sampling completed: all {len(shards)} shards finished") stats_warmup = 0.0 # Not tracked for parallel stats_total = time.perf_counter() - start_time else: # Single-shard: run directly mcmc_samples, stats = run_nuts_sampling( model=xpcs_model, model_kwargs=model_kwargs, config=config, initial_values=initial_values, parameter_space=parameter_space, n_phi=prepared.n_phi, analysis_mode=analysis_mode, progress_bar=progress_bar, per_angle_mode=effective_per_angle_mode, ) stats_warmup = stats.warmup_time stats_total = stats.total_time # ========================================================================= # 6. Create result # ========================================================================= from homodyne.optimization.cmc.sampler import SamplingStats # Single-shard: use adapted warmup from plan (may differ from config default) # Multi-shard: use config default (workers adapt independently) if stats is not None: if stats.plan is not None: actual_n_warmup = stats.plan.n_warmup actual_plan = stats.plan else: run_logger.warning( "SamplingPlan invariant violated: stats without plan. " "Using config defaults." ) actual_n_warmup = config.num_warmup actual_plan = None # Reuse stats.num_divergent (already computed accurately in run_nuts_sampling) num_divergent = stats.num_divergent else: # Multi-shard: use median adapted warmup from workers if available, # otherwise fall back to config default. actual_n_warmup = ( mcmc_samples.shard_adapted_n_warmup if mcmc_samples.shard_adapted_n_warmup is not None else config.num_warmup ) actual_plan = None num_divergent = int( mcmc_samples.extra_fields.get("diverging", np.array([0])).sum() ) final_stats = SamplingStats( warmup_time=stats_warmup, sampling_time=stats_total - stats_warmup, total_time=stats_total, num_divergent=num_divergent, plan=actual_plan, ) result = CMCResult.from_mcmc_samples( mcmc_samples=mcmc_samples, stats=final_stats, analysis_mode=analysis_mode, n_warmup=actual_n_warmup, min_ess=config.min_ess, ) # ========================================================================= # 7. Log summary # ========================================================================= summary = summarize_diagnostics( r_hat=result.r_hat, ess_bulk=result.ess_bulk, divergences=result.divergences, n_samples=result.n_samples, n_chains=result.n_chains, num_shards=result.num_shards, ) run_logger.info(f"CMC complete: {result.convergence_status}") run_logger.info(summary) # Log parameter estimates stats_dict = result.get_posterior_stats() for name in get_param_names_in_order(prepared.n_phi, analysis_mode)[:5]: if name in stats_dict: s = stats_dict[name] run_logger.info( f" {name}: {s['mean']:.4g} +/- {s['std']:.4g} " f"(R-hat={s['r_hat']:.3f}, ESS={s['ess_bulk']:.0f})" ) # ========================================================================= # 8. Precision analysis (Jan 2026): Compare CMC posteriors to NLSQ # ========================================================================= if nlsq_result is not None: # nlsq_values, nlsq_uncertainties already extracted in step 2e above precision_analysis = compute_precision_analysis( cmc_result=stats_dict, nlsq_result=nlsq_values, nlsq_uncertainties=nlsq_uncertainties, ) # Log comprehensive precision report log_precision_analysis(precision_analysis, log_fn=run_logger.info) # Warn if CMC significantly disagrees with NLSQ high_z_params = [ (p, m.get("z_score", 0)) for p, m in precision_analysis.items() if m.get("z_score", 0) > 3 ] if high_z_params: run_logger.warning( "CMC-NLSQ disagreement (z > 3): " + ", ".join(f"{p} (z={z:.1f})" for p, z in high_z_params) ) total_time = time.perf_counter() - start_time run_logger.info(f"Total execution time: {total_time:.1f}s") # Log runtime comparison if we had an estimate if estimated_runtime > 0: _log_runtime_comparison(run_logger, estimated_runtime, total_time) return result
[docs] def run_cmc_analysis( data: np.ndarray, t1: np.ndarray, t2: np.ndarray, phi: np.ndarray, q: float, L: float, analysis_mode: str, config: CMCConfig, parameter_space: ParameterSpace, initial_values: dict[str, float] | None = None, dt: float | None = None, ) -> CMCResult: """Simplified interface for CMC analysis. This is a convenience wrapper around fit_mcmc_jax() that takes a CMCConfig object directly instead of a dict. Parameters ---------- data, t1, t2, phi : np.ndarray Data arrays. q, L : float Physics parameters. analysis_mode : str Analysis mode. config : CMCConfig CMC configuration object. parameter_space : ParameterSpace Parameter space. initial_values : dict[str, float] | None Initial values. dt : float | None Time step (None infers from pooled time arrays). Returns ------- CMCResult Analysis result. """ return fit_mcmc_jax( data=data, t1=t1, t2=t2, phi=phi, q=q, L=L, analysis_mode=analysis_mode, method="mcmc", cmc_config=config.to_dict(), initial_values=initial_values, parameter_space=parameter_space, dt=dt, )