Source code for homodyne.optimization.cmc.sampler

"""NUTS sampler wrapper for CMC analysis.

This module provides utilities for running NumPyro NUTS sampling
with proper initialization and progress tracking.
"""

from __future__ import annotations

import time
from collections.abc import Callable
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any

import jax
import numpy as np
from numpyro.infer import MCMC, NUTS
from numpyro.infer.initialization import init_to_median, init_to_value

from homodyne.optimization.cmc.config import CMCConfig
from homodyne.optimization.cmc.priors import (
    build_init_values_dict,
    get_param_names_in_order,
)
from homodyne.optimization.cmc.scaling import (
    compute_scaling_factors,
    transform_initial_values_to_z,
)
from homodyne.utils.logging import get_logger, with_context

# Divergence rate thresholds for NUTS convergence diagnostics.
# P2-2: Centralized constants to ensure consistency across sampler paths.
DIVERGENCE_RATE_TARGET = 0.05  # Below this: acceptable
DIVERGENCE_RATE_HIGH = 0.10  # Above this: posterior may be biased
DIVERGENCE_RATE_CRITICAL = 0.30  # Above this: posterior likely unreliable


def _subset_model_kwargs_for_preflight(
    model_kwargs: dict[str, Any],
    *,
    max_points: int = 512,
) -> dict[str, Any]:
    """Build a reduced model_kwargs dict for fast preflight diagnostics."""
    data = model_kwargs.get("data")
    if data is None:
        return dict(model_kwargs)

    if not hasattr(data, "shape"):
        return dict(model_kwargs)

    n = int(data.shape[0])
    if n <= max_points:
        return dict(model_kwargs)

    # Pick evenly spaced points to cover the full time range.
    idx = np.linspace(0, n - 1, num=max_points, dtype=np.int64)
    reduced = dict(model_kwargs)
    for key in ("data", "t1", "t2", "phi_indices"):
        arr = model_kwargs.get(key)
        if arr is not None:
            reduced[key] = arr[idx]
    # Drop pre-computed shard_grid: its idx1/idx2 are aligned with the full
    # shard (length N), not the preflight subset (length max_points).  The
    # model functions fall back to the legacy compute_g1_total path when
    # shard_grid is absent.
    reduced.pop("shard_grid", None)
    return reduced


def _preflight_log_density(
    *,
    model: Callable,
    model_kwargs: dict[str, Any],
    params: dict[str, Any],
    run_logger,
    max_points: int = 512,
) -> None:
    """Compute initial log density and basic finiteness diagnostics.

    This is intended to catch the common failure modes behind near-zero
    acceptance (NaN/-inf log prob, non-finite deterministics) before spending
    hours running many CMC shards.
    """
    try:
        from numpyro import handlers
        from numpyro.infer.util import log_density

        subset_kwargs = _subset_model_kwargs_for_preflight(
            model_kwargs, max_points=max_points
        )

        seeded = handlers.seed(model, jax.random.PRNGKey(0))
        log_joint, trace = log_density(seeded, (), subset_kwargs, params)
        log_joint_val = float(np.asarray(log_joint))

        n_nonfinite_log_prob = 0
        n_total_log_prob = 0
        for site in trace.values():
            if site.get("type") != "sample":
                continue

            fn = site.get("fn")
            value = site.get("value")
            if fn is None or value is None:
                continue

            try:
                log_prob = fn.log_prob(value)
            except Exception:  # nosec B112
                continue

            log_prob_np = np.asarray(log_prob)
            n_total_log_prob += log_prob_np.size
            n_nonfinite_log_prob += int(np.sum(~np.isfinite(log_prob_np)))

        n_issues = None
        if "n_numerical_issues" in trace:
            try:
                n_issues = int(np.asarray(trace["n_numerical_issues"]["value"]))
            except Exception:
                n_issues = None

        run_logger.info(
            "Preflight log_density: "
            f"log_joint={log_joint_val:.4g}, "
            f"nonfinite_log_prob={n_nonfinite_log_prob}/{n_total_log_prob}, "
            f"n_numerical_issues={n_issues}"
        )

        if not np.isfinite(log_joint_val) or n_nonfinite_log_prob > 0:
            raise RuntimeError(
                "Preflight detected non-finite log density/log_prob at initialization. "
                "This typically leads to 0% NUTS acceptance; check bounds, initial values, "
                "and numerical stability of the physics model."
            )

    except RuntimeError:
        raise
    except Exception as e:
        # If the preflight itself fails, keep going but make it loud.
        run_logger.warning(f"Preflight diagnostics failed: {e}")


def _compute_mcmc_safe_d0(
    initial_values: dict[str, float] | None,
    q: float,
    dt: float,
    time_grid: np.ndarray | None,
    logger_inst,
    *,
    target_g1: float = 0.5,
    g1_threshold: float = 0.1,
) -> dict[str, float] | None:
    """Check if initial D0 causes g1→0 and compute MCMC-safe adjustment.

    When D0 is very large (or alpha very negative), the diffusion integral
    can become enormous, causing g1 = exp(-integral) → 0. This creates
    a flat likelihood surface with vanishing gradients, causing NUTS to
    reject all proposals (0% acceptance rate).

    This function detects this condition and computes a scaled D0 that
    produces g1 ≈ target_g1 at a typical time lag, ensuring gradients
    are alive for MCMC exploration.

    Parameters
    ----------
    initial_values : dict[str, float] | None
        Initial parameter values containing D0, alpha, D_offset.
    q : float
        Wavevector magnitude.
    dt : float
        Time step.
    time_grid : np.ndarray | None
        Time grid for integration.
    logger_inst
        Logger instance for warnings.
    target_g1 : float
        Target g1 value for adjusted D0 (default 0.5).
    g1_threshold : float
        Threshold below which D0 is adjusted (default 0.1).

    Returns
    -------
    dict[str, float] | None
        Adjusted initial values if D0 was scaled, None otherwise.
    """
    if initial_values is None:
        return None

    # Get diffusion parameters
    d0 = initial_values.get("D0")
    alpha = initial_values.get("alpha")
    d_offset = initial_values.get("D_offset")

    if d0 is None or alpha is None or d_offset is None:
        return None

    if time_grid is None or len(time_grid) < 2:
        return None

    # Safety checks
    if not np.isfinite(d0) or not np.isfinite(alpha) or not np.isfinite(d_offset):
        return None

    try:
        # Compute D(t) on the time grid
        # D(t) = D0 * t^alpha + D_offset
        epsilon = 1e-10
        time_safe = np.asarray(time_grid) + epsilon
        D_grid = d0 * (time_safe**alpha) + d_offset
        D_grid = np.maximum(D_grid, 1e-10)

        # Compute trapezoidal cumsum (without dt scaling)
        if len(D_grid) > 1:
            trap_avg = 0.5 * (D_grid[:-1] + D_grid[1:])
            cumsum = np.concatenate([[0.0], np.cumsum(trap_avg)])
        else:
            cumsum = np.cumsum(D_grid)

        # Estimate integral at typical time lag (1/4 to 3/4 of range)
        n = len(cumsum)
        idx_low = n // 4
        idx_high = 3 * n // 4
        integral_estimate = abs(cumsum[idx_high] - cumsum[idx_low])

        # Compute g1 = exp(-q^2 * dt * integral)
        # Physics: g1 = exp(-q² ∫D(t)dt). The homodyne 0.5 factor appears on g1²,
        # not in the log of g1 directly. Using the full q² gives a correct estimate
        # of the ISF decay rate for the safety guard.
        prefactor = q**2 * dt
        log_g1 = -prefactor * integral_estimate
        log_g1_clipped = max(log_g1, -700.0)  # Prevent underflow
        g1_estimate = np.exp(log_g1_clipped)

        logger_inst.debug(
            f"[MCMC-SAFE] g1 estimate: D0={d0:.4g}, alpha={alpha:.4g}, "
            f"integral={integral_estimate:.4g}, log_g1={log_g1:.4g}, g1={g1_estimate:.4g}"
        )

        # If g1 is too small, compute scaled D0
        if g1_estimate < g1_threshold:
            # For g1 = target_g1:
            # log(target_g1) = -prefactor * target_integral
            # target_integral = -log(target_g1) / prefactor
            target_log_g1 = np.log(target_g1)
            target_integral = -target_log_g1 / prefactor

            # Scale factor: how much smaller should the integral be?
            if integral_estimate > 0:
                scale_factor = target_integral / integral_estimate
            else:
                scale_factor = 0.01  # Fallback

            # Apply scaling to D0 (approximately linear for moderate adjustments)
            # Also adjust D_offset proportionally for consistency
            new_d0 = d0 * scale_factor
            new_d_offset = d_offset * scale_factor

            # Ensure new values are within reasonable range
            new_d0 = max(new_d0, 1.0)  # Minimum D0
            new_d_offset = max(new_d_offset, -1e6)  # Allow negative but bound

            logger_inst.warning(
                f"MCMC-SAFE ADJUSTMENT: Initial D0={d0:.4g} causes g1~{g1_estimate:.2e} (vanishing gradients). "
                f"Scaling D0 to {new_d0:.4g} (x{scale_factor:.4f}) for MCMC exploration stability. "
                f"The sampler can still converge to optimal values if supported by likelihood."
            )

            # Return adjusted values
            adjusted = dict(initial_values)
            adjusted["D0"] = new_d0
            adjusted["D_offset"] = new_d_offset
            return adjusted

        return None  # No adjustment needed

    except Exception as e:
        logger_inst.debug(f"[MCMC-SAFE] Check failed: {e}")
        return None


if TYPE_CHECKING:
    from homodyne.config.parameter_space import ParameterSpace

logger = get_logger(__name__)


def _summarize_inverse_mass_matrix(inv_mass: Any) -> str:
    """Return a compact summary of the adapted inverse mass matrix."""

    def _one(mat: Any) -> str:
        if isinstance(mat, dict):
            keys = list(mat.keys())
            if not keys:
                return "dict(empty)"
            first = mat[keys[0]]
            return f"dict(keys={len(keys)}) first[{keys[0]}]: {_one(first)}"

        try:
            arr = np.asarray(mat)
        except Exception:
            return f"type={type(mat).__name__}"

        if arr.ndim == 0:
            # Could be an object scalar (e.g., dict) depending on upstream types.
            try:
                return f"scalar={float(arr):.3g}"
            except Exception:
                return f"scalar(type={type(arr.item()).__name__})"

        if arr.ndim == 1:
            diag = arr
            diag = diag[np.isfinite(diag)]
            if diag.size == 0:
                return f"diag(dim={arr.size}) all-nonfinite"
            dmin = float(np.min(diag))
            dmax = float(np.max(diag))
            cond = float(dmax / dmin) if dmin > 0 else float("inf")
            return f"diag(dim={arr.size}) min={dmin:.3g} max={dmax:.3g} cond~{cond:.3g}"

        if arr.ndim == 2 and arr.shape[0] == arr.shape[1]:
            diag = np.diag(arr)
            diag = diag[np.isfinite(diag)]
            if diag.size == 0:
                return f"dense(dim={arr.shape[0]}) diag all-nonfinite"
            dmin = float(np.min(diag))
            dmax = float(np.max(diag))
            try:
                cond = float(np.linalg.cond(arr))
            except Exception:
                cond = float("nan")
            return (
                f"dense(dim={arr.shape[0]}) diag[min={dmin:.3g}, max={dmax:.3g}] "
                f"cond={cond:.3g}"
            )

        # Per-chain dense matrices: (n_chains, dim, dim)
        if arr.ndim == 3 and arr.shape[1] == arr.shape[2]:
            n_chains = arr.shape[0]
            dim = arr.shape[1]
            # summarize first two chains
            parts = []
            for i in range(min(n_chains, 2)):
                parts.append(_one(arr[i]))
            more = "" if n_chains <= 2 else f" (+{n_chains - 2} more)"
            return f"per-chain dense(dim={dim})[{', '.join(parts)}]{more}"

        return f"array(shape={arr.shape}, ndim={arr.ndim})"

    if isinstance(inv_mass, (list, tuple)):
        parts = [_one(m) for m in inv_mass[:2]]
        more = "" if len(inv_mass) <= 2 else f" (+{len(inv_mass) - 2} more)"
        return f"per-chain[{', '.join(parts)}]{more}"

    return _one(inv_mass)


def _extract_adapt_states(last_state: Any) -> list[Any]:
    """Extract per-chain adapt_state objects from a NumPyro MCMC last_state."""
    if last_state is None:
        return []

    if hasattr(last_state, "adapt_state"):
        return [last_state.adapt_state]

    if isinstance(last_state, (list, tuple)):
        out: list[Any] = []
        for item in last_state:
            if hasattr(item, "adapt_state"):
                out.append(item.adapt_state)
        return out

    # NumPyro may omit adapt_state (e.g., API differences or failed adaptation).
    return []


def _log_array_stats(
    run_logger,
    *,
    name: str,
    arr: Any,
) -> None:
    try:
        a = np.asarray(arr)
    except Exception:
        return

    if a.size == 0:
        return

    finite = np.isfinite(a)
    if not np.any(finite):
        run_logger.info(f"{name} stats: all non-finite, shape={a.shape}")
        return

    run_logger.info(
        f"{name} stats: "
        f"min={float(np.min(a[finite])):.3g}, "
        f"median={float(np.median(a[finite])):.3g}, "
        f"max={float(np.max(a[finite])):.3g}, "
        f"mean={float(np.mean(a[finite])):.3g}, "
        f"std={float(np.std(a[finite])):.3g}, "
        f"finite={float(np.mean(finite)):.1%}, shape={a.shape}"
    )


def _extract_step_sizes(adapt_states: list[Any]) -> list[float]:
    """Extract step_size values from NumPyro adapt_state objects."""
    step_sizes: list[float] = []
    for adapt_state in adapt_states:
        if adapt_state is None:
            continue

        if hasattr(adapt_state, "step_size"):
            try:
                step_sizes.append(float(adapt_state.step_size))
                continue
            except Exception:  # noqa: S110 - Fallback for adapt_state variants
                pass

        if isinstance(adapt_state, dict) and "step_size" in adapt_state:
            try:
                step_sizes.append(float(adapt_state["step_size"]))
            except Exception:  # noqa: S110 - Fallback for adapt_state variants
                pass

    return step_sizes


[docs] @dataclass class SamplingStats: """Statistics from MCMC sampling. Attributes ---------- warmup_time : float Time spent in warmup phase (seconds). sampling_time : float Time spent in sampling phase (seconds). total_time : float Total sampling time (seconds). num_divergent : int Number of divergent transitions. accept_prob : float Mean acceptance probability. step_size : float Final step size. step_size_min : float Minimum adapted step size across chains (if available). step_size_max : float Maximum adapted step size across chains (if available). inverse_mass_matrix_summary : str | None Compact summary of the adapted inverse mass matrix (if available). tree_depth : float Mean tree depth. """ warmup_time: float = 0.0 sampling_time: float = 0.0 total_time: float = 0.0 num_divergent: int = 0 accept_prob: float = 0.0 step_size: float = 0.0 step_size_min: float | None = None step_size_max: float | None = None inverse_mass_matrix_summary: str | None = None tree_depth: float = 0.0 plan: SamplingPlan | None = None
[docs] @dataclass(frozen=True) class SamplingPlan: """Adapted MCMC sampling counts for a single shard. Captures the actual warmup/sample counts after adaptive scaling, which may differ from CMCConfig defaults for small shards. Use SamplingPlan.from_config() instead of accessing config.num_warmup / config.num_samples in hot paths. """ n_warmup: int n_samples: int n_chains: int shard_size: int n_params: int was_adapted: bool
[docs] @classmethod def from_config( cls, config: CMCConfig, shard_size: int, n_params: int ) -> SamplingPlan: n_warmup, n_samples = config.get_adaptive_sample_counts( shard_size=shard_size, n_params=n_params ) return cls( n_warmup=n_warmup, n_samples=n_samples, n_chains=config.num_chains, shard_size=shard_size, n_params=n_params, was_adapted=( n_warmup != config.num_warmup or n_samples != config.num_samples ), )
@property def total_samples(self) -> int: return self.n_samples * self.n_chains
[docs] @dataclass class MCMCSamples: """Container for MCMC samples. Attributes ---------- samples : dict[str, np.ndarray] Parameter samples, shape (n_chains, n_samples) per parameter. param_names : list[str] Parameter names in sampling order. n_chains : int Number of chains. n_samples : int Number of samples per chain. extra_fields : dict[str, Any] Additional MCMC info (divergences, energy, etc.). num_shards : int Number of shards combined (1 for single shard, >1 for CMC). Used for correct divergence rate calculation in CMC. """ samples: dict[str, np.ndarray] param_names: list[str] n_chains: int n_samples: int extra_fields: dict[str, Any] = field(default_factory=dict) num_shards: int = 1 shard_adapted_n_warmup: int | None = None bimodal_consensus: Any = ( None # BimodalConsensusResult when mode-aware consensus used )
[docs] def create_init_strategy( initial_values: dict[str, float] | None, param_names: list[str], use_init_to_value: bool = True, z_space_values: dict[str, float] | None = None, ) -> Callable: """Create initialization strategy for NUTS. Parameters ---------- initial_values : dict[str, float] | None Initial values from config (original space). param_names : list[str] Expected parameter names in order. use_init_to_value : bool If True, use init_to_value when values provided. z_space_values : dict[str, float] | None Initial values in z-space (for scaled model). If provided, these are used directly as {name}_z values. Returns ------- Callable NumPyro initialization function. """ # For scaled model, use z-space values if z_space_values is not None and use_init_to_value: if z_space_values: logger.debug( f"Using init_to_value (z-space) for {len(z_space_values)} params: " f"{list(z_space_values.keys())[:5]}..." ) return init_to_value(values=z_space_values) # For unscaled model, use original values if initial_values is not None and use_init_to_value: # Filter to only parameters we're sampling (exclude deterministics) init_dict = {} for name in param_names: if name in initial_values: init_dict[name] = initial_values[name] if init_dict: logger.debug( f"Using init_to_value for {len(init_dict)} params: {list(init_dict.keys())}" ) return init_to_value(values=init_dict) # Fallback to median initialization logger.debug("Using init_to_median (no initial values)") return init_to_median()
[docs] def run_nuts_sampling( model: Callable, model_kwargs: dict[str, Any], config: CMCConfig, initial_values: dict[str, float] | None, parameter_space: ParameterSpace, n_phi: int, analysis_mode: str, rng_key: jax.random.PRNGKey | None = None, progress_bar: bool = True, per_angle_mode: str = "individual", ) -> tuple[MCMCSamples, SamplingStats]: """Run NUTS sampling with configuration. Parameters ---------- model : Callable NumPyro model function. model_kwargs : dict[str, Any] Keyword arguments to pass to model. config : CMCConfig CMC configuration. initial_values : dict[str, float] | None Initial parameter values from config. parameter_space : ParameterSpace Parameter space for building init values. n_phi : int Number of phi angles. analysis_mode : str Analysis mode. rng_key : jax.random.PRNGKey | None Random key. If None, creates from seed. progress_bar : bool Whether to show progress bar. per_angle_mode : str Per-angle scaling mode: "individual", "auto", "constant", or "constant_averaged". Controls which parameters are sampled vs fixed. Returns ------- tuple[MCMCSamples, SamplingStats] Samples and timing statistics. """ run_logger = with_context(logger, run=getattr(config, "run_id", None)) # Get parameter names in correct order param_names = get_param_names_in_order(n_phi, analysis_mode, per_angle_mode) # Add sigma (noise parameter) param_names_with_sigma = param_names + ["sigma"] # Build full init values dict if needed # Extract data arrays from model_kwargs for data-driven estimation c2_data = model_kwargs.get("data") t1_data = model_kwargs.get("t1") t2_data = model_kwargs.get("t2") phi_indices = model_kwargs.get("phi_indices") # Convert JAX arrays to numpy for estimation (if needed) if c2_data is not None and hasattr(c2_data, "__array__"): c2_data = np.asarray(c2_data) if t1_data is not None and hasattr(t1_data, "__array__"): t1_data = np.asarray(t1_data) if t2_data is not None and hasattr(t2_data, "__array__"): t2_data = np.asarray(t2_data) if phi_indices is not None and hasattr(phi_indices, "__array__"): phi_indices = np.asarray(phi_indices) # ========================================================================= # MCMC-SAFE D0 CHECK: Detect and fix vanishing gradient regime # ========================================================================= # When D0 is very large (or alpha very negative), g1 → 0 everywhere, # causing vanishing gradients and 0% NUTS acceptance rate. # This check detects that condition and scales D0 to ensure gradients are alive. q = model_kwargs.get("q", 0.01) dt = model_kwargs.get("dt", 0.1) time_grid = model_kwargs.get("time_grid") if time_grid is not None and hasattr(time_grid, "__array__"): time_grid_np = np.asarray(time_grid) else: time_grid_np = time_grid adjusted_init = _compute_mcmc_safe_d0( initial_values=initial_values, q=q, dt=dt, time_grid=time_grid_np, logger_inst=run_logger, ) # Use adjusted values if D0 was scaled effective_init_values = ( adjusted_init if adjusted_init is not None else initial_values ) full_init = build_init_values_dict( n_phi=n_phi, analysis_mode=analysis_mode, initial_values=effective_init_values, parameter_space=parameter_space, c2_data=c2_data, t1=t1_data, t2=t2_data, phi_indices=phi_indices, per_angle_mode=per_angle_mode, ) # ========================================================================= # GRADIENT BALANCING: Transform initial values to z-space for scaled model # ========================================================================= # The scaled model (xpcs_model_scaled) samples in normalized z-space where # z ~ Normal(0, 1). We need to transform our initial values to this space # for proper initialization with init_to_value. scalings = compute_scaling_factors(parameter_space, n_phi, analysis_mode) z_space_init = transform_initial_values_to_z(full_init, scalings) # Log scaling transformation info run_logger.info( f"Gradient balancing: {len(scalings)} params transformed to unit scale. " f"Sample scale range: {min(s.scale for s in scalings.values()):.2e} to " f"{max(s.scale for s in scalings.values()):.2e}" ) # ========================================================================= # REPARAMETERIZED INIT VALUES: Add log_D_ref, D_offset_ratio, log_gamma_ref # ========================================================================= # When using the reparameterized model, the sampled parameters are different # from the z-space params. Add init values for the reparam parameters. reparam_config = model_kwargs.get("reparam_config") if reparam_config is not None: t_ref_init = model_kwargs.get("t_ref", getattr(reparam_config, "t_ref", 1.0)) # Compute reparameterized init values from physics init values D0_init = full_init.get("D0", 1e4) alpha_init = full_init.get("alpha", -0.5) D_offset_init = full_init.get("D_offset", 1e3) if getattr(reparam_config, "enable_d_ref", False): D_ref_init = D0_init * (t_ref_init**alpha_init) D_ref_init = max(D_ref_init, 1e-10) z_space_init["log_D_ref"] = float(np.log(D_ref_init)) # D_offset_ratio = D_offset / D_ref (linear, handles negative D_offset). # Clamp to (-1+eps, inf) to match the TruncatedNormal prior floor in model.py: # ratio <= -1 means D_ref + D_offset <= 0 (non-physical at t_ref). raw_ratio = D_offset_init / D_ref_init z_space_init["D_offset_ratio"] = float(max(raw_ratio, -1.0 + 1e-4)) if ( getattr(reparam_config, "enable_gamma_ref", False) and analysis_mode == "laminar_flow" ): gamma_dot_t0_init = full_init.get("gamma_dot_t0", 1e-3) beta_init = full_init.get("beta", -0.3) gamma_ref_init = gamma_dot_t0_init * (t_ref_init**beta_init) gamma_ref_init = max(gamma_ref_init, 1e-20) z_space_init["log_gamma_ref"] = float(np.log(gamma_ref_init)) run_logger.info( f"Reparameterized init: t_ref={t_ref_init:.4g}, " + ", ".join( f"{k}={z_space_init[k]:.4g}" for k in ["log_D_ref", "D_offset_ratio", "log_gamma_ref"] if k in z_space_init ) ) # Create init strategy for NUTS kernel. # P1-R5-01: NLSQ-informed models ("auto"/"constant_averaged" per_angle_mode) # sample in original parameter space — their NumPyro sites are named "D0", # "alpha", etc., NOT "D0_z", "alpha_z". Passing z_space_values with z-space # site names causes init_to_value to silently ignore all entries (no site name # match), falling back to init_to_median and discarding the NLSQ warm-start. # For these modes, use the original-space full_init directly. nlsq_prior_config = model_kwargs.get("nlsq_prior_config") if nlsq_prior_config is not None and per_angle_mode in ( "auto", "constant_averaged", ): # Original-space initialization for NLSQ-informed models init_strategy = create_init_strategy(full_init, param_names_with_sigma) else: # Z-space initialization for scaled models init_strategy = create_init_strategy( full_init, param_names_with_sigma, z_space_values=z_space_init ) # ========================================================================= # PREFLIGHT: Validate initial log density and finiteness # ========================================================================= # This catches common causes of 0% acceptance (NaNs/-inf log prob) before # spending wall-clock hours running many CMC shards. sigma_init = float(max(model_kwargs.get("noise_scale", 0.1), 1e-6)) # P1-5 / P1-R5-01: NLSQ-informed models (averaged, constant_averaged) sample # parameters with original-space names ("D0", "alpha", ...) NOT z-space names # ("D0_z"). Both the preflight and init_strategy must use matching names. # nlsq_prior_config was read above when building init_strategy. if nlsq_prior_config is not None and per_angle_mode in ( "auto", "constant_averaged", ): # Use original-space names for NLSQ-informed models preflight_params = dict(full_init) else: preflight_params = dict(z_space_init) preflight_params["sigma"] = sigma_init _preflight_log_density( model=model, model_kwargs=model_kwargs, params=preflight_params, run_logger=run_logger, ) # Create NUTS kernel # GRADIENT BALANCING FIX (Dec 2025): Use dense_mass=True to learn # cross-correlations between parameters with vastly different scales. # Without this, the 10^6:1 gradient imbalance between D0 (~10^4) and # gamma_dot_t0 (~10^-3) causes 0% acceptance rate because no single # step size ε works for all dimensions. Dense mass matrix allows NUTS # to adapt per-dimension and learn covariance structure during warmup. # # CONVERGENCE FIX (Jan 2026): Elevate target_accept_prob for laminar_flow. # The 28.4% divergence rate in 3-angle laminar_flow was caused by step size # adaptation settling on values too large for the complex posterior geometry. # Higher target_accept_prob forces smaller steps, reducing divergences. effective_target_accept = config.target_accept_prob if analysis_mode == "laminar_flow" and config.target_accept_prob < 0.9: effective_target_accept = 0.9 run_logger.info( f"Elevating target_accept_prob from {config.target_accept_prob} to {effective_target_accept} " f"for laminar_flow mode (reduces divergences in complex posterior)" ) kernel = NUTS( model, init_strategy=init_strategy, target_accept_prob=effective_target_accept, dense_mass=config.dense_mass, max_tree_depth=config.max_tree_depth, ) # Get shard size for adaptive sampling data = model_kwargs.get("data") shard_size = len(data) if data is not None else 10000 # Build SamplingPlan: captures adapted warmup/samples for this shard. # Profiling showed 1310s for 50 points with 500/1500 defaults. # Adaptive scaling reduces overhead for small datasets. plan = SamplingPlan.from_config( config, shard_size=shard_size, n_params=len(param_names_with_sigma) ) num_warmup, num_samples = plan.n_warmup, plan.n_samples # Determine effective chain method with auto-fallback for small shards effective_chain_method = config.chain_method if effective_chain_method == "parallel" and shard_size < 500: run_logger.warning( f"Shard size {shard_size:,} < 500: falling back to sequential " f"chains (parallel overhead exceeds benefit for small shards)" ) effective_chain_method = "sequential" # Create MCMC runner with adaptive sample counts mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=config.num_chains, chain_method=effective_chain_method, progress_bar=progress_bar, ) # Create RNG key from config seed (defaults to 42 for reproducibility) if rng_key is None: rng_key = jax.random.PRNGKey(config.seed) # Run sampling with timing adaptive_note = "" if config.adaptive_sampling and ( num_warmup != config.num_warmup or num_samples != config.num_samples ): adaptive_note = f" (adaptive: {shard_size:,} pts)" run_logger.info( f"Starting NUTS sampling: {config.num_chains} chains, " f"{num_warmup} warmup, {num_samples} samples{adaptive_note}" ) start_time = time.perf_counter() run_logger.info("NUTS phase: JIT compile + sampling started (may take minutes)...") # JAX profiler setup (Feb 2026): Capture XLA-level performance data # py-spy can only profile Python code; XLA runs native code invisible to py-spy. # When enabled, trace XLA operations to identify HLO graph bottlenecks. profile_context = None if config.enable_jax_profiling: import os os.makedirs(config.jax_profile_dir, exist_ok=True) run_logger.info(f"JAX profiling enabled, output: {config.jax_profile_dir}") profile_context = jax.profiler.trace(config.jax_profile_dir) profile_context.__enter__() try: # Request only essential extra fields to minimize extraction overhead. # Previously we requested 7 fields which caused 25-45 minute extraction times # due to JAX lazy evaluation materializing large intermediate arrays. # NOTE: "adapt_state.inverse_mass_matrix" is intentionally omitted here. # Requesting it stores the full mass matrix at every warmup+sample step, # producing a (n_chains, n_warmup+n_samples, n_params, n_params) tensor that # wastes 5-180 MB per shard (scales with n_params²). The final adapted mass # matrix is already extracted from mcmc.last_state via _extract_adapt_states. mcmc.run( rng_key, extra_fields=( "accept_prob", "diverging", "num_steps", "potential_energy", "adapt_state.step_size", ), **model_kwargs, ) except Exception as e: run_logger.error(f"MCMC sampling failed: {e}") raise RuntimeError(f"MCMC sampling failed: {e}") from e finally: if profile_context is not None: profile_context.__exit__(None, None, None) run_logger.info(f"JAX profile saved to {config.jax_profile_dir}") # Force JAX to complete all pending computations before timing extraction. # JAX uses lazy evaluation, so without this the actual computation happens # during device_get(), causing misleading timing and 25-45 min "extraction" times. last_state = getattr(mcmc, "last_state", None) if last_state is not None: jax.block_until_ready(last_state) total_time = time.perf_counter() - start_time run_logger.info(f"NUTS finished in {total_time:.1f}s") # Extract samples - should be fast now that computation is complete t_extract = time.perf_counter() run_logger.info("Extracting samples + extra_fields...") samples = mcmc.get_samples(group_by_chain=True) # Use block_until_ready before device_get to ensure computation is complete jax.block_until_ready(samples) samples = jax.device_get(samples) # Convert to numpy and proper format samples_np: dict[str, np.ndarray] = {} for name, arr in samples.items(): samples_np[name] = np.asarray(arr) # Get extra fields (divergences, etc.) extra = mcmc.get_extra_fields(group_by_chain=True) jax.block_until_ready(extra) extra = jax.device_get(extra) extra_fields = {k: np.asarray(v) for k, v in extra.items()} run_logger.info( f"Extraction complete in {time.perf_counter() - t_extract:.2f}s " f"(samples={len(samples_np)}, extra_fields={len(extra_fields)})" ) # Compute statistics num_divergent = 0 if "diverging" in extra_fields: num_divergent = int(np.sum(extra_fields["diverging"])) # CONVERGENCE CHECK (Jan 2026): Early divergence rate detection # High divergence rates indicate NUTS is struggling with the posterior geometry. # The 28.4% divergence rate in the 3-angle failure case signals unreliable posteriors. # Use actual chain count from samples array (may differ from config if init fails). _first_sample = next(iter(samples_np.values()), None) _actual_chains_for_div = ( _first_sample.shape[0] if _first_sample is not None else config.num_chains ) total_samples = num_samples * _actual_chains_for_div if total_samples > 0: divergence_rate = num_divergent / total_samples if divergence_rate > DIVERGENCE_RATE_CRITICAL: run_logger.error( f"CRITICAL: Divergence rate {divergence_rate:.1%} ({num_divergent}/{total_samples}) " f"exceeds {DIVERGENCE_RATE_CRITICAL:.0%} threshold. Posterior is likely unreliable. Consider:\n" " 1. Reducing shard size (smaller max_points_per_shard)\n" " 2. Using NLSQ warm-start for better initial values\n" " 3. Widening priors or fixing problematic parameters" ) elif divergence_rate > DIVERGENCE_RATE_HIGH: run_logger.warning( f"High divergence rate: {divergence_rate:.1%} ({num_divergent}/{total_samples}). " f"Posterior may be biased. Target: <{DIVERGENCE_RATE_TARGET:.0%}" ) elif divergence_rate > DIVERGENCE_RATE_TARGET: run_logger.info( f"Elevated divergence rate: {divergence_rate:.1%} ({num_divergent}/{total_samples}). " f"Acceptable but monitor closely." ) accept_prob = float("nan") accept_prob_arr = None if "accept_prob" in extra_fields: accept_prob_arr = np.asarray(extra_fields["accept_prob"]) if accept_prob_arr.size: accept_prob = float(np.nanmean(accept_prob_arr)) # Adaptation diagnostics (step size, mass matrix) step_size = 0.0 inv_mass_summary = None last_state = getattr(mcmc, "last_state", None) adapt_states = _extract_adapt_states(last_state) step_sizes = _extract_step_sizes(adapt_states) if step_sizes: step_size = float(np.nanmedian(step_sizes)) step_size_min = float(np.nanmin(step_sizes)) step_size_max = float(np.nanmax(step_sizes)) else: step_size_min = None step_size_max = None # If we couldn't extract step_size from last_state, try extra_fields # (available across NumPyro versions when explicitly requested). if step_size == 0.0 and "adapt_state.step_size" in extra_fields: try: ss = np.asarray(extra_fields["adapt_state.step_size"]).reshape(-1) ss = ss[np.isfinite(ss)] if ss.size: step_size = float(np.median(ss)) except Exception: # noqa: S110 - Fallback for step_size extraction pass inv_mass = None if adapt_states: a0 = adapt_states[0] inv_mass = getattr(a0, "inverse_mass_matrix", None) if inv_mass is None and isinstance(a0, dict): inv_mass = a0.get("inverse_mass_matrix") if inv_mass is not None: inv_mass_summary = _summarize_inverse_mass_matrix(inv_mass) if step_sizes: run_logger.info( "Adapted step_size stats: " f"min={float(np.nanmin(step_sizes)):.3g}, " f"median={float(np.nanmedian(step_sizes)):.3g}, " f"max={float(np.nanmax(step_sizes)):.3g}" ) elif step_size > 0: run_logger.info(f"Adapted step_size~={step_size:.3g} (from extra_fields)") if inv_mass_summary is not None: run_logger.info(f"Adapted inverse_mass_matrix: {inv_mass_summary}") # Always log basic accept/energy diagnostics when available. if accept_prob_arr is not None and accept_prob_arr.size: run_logger.info( "accept_prob stats: " f"mean={float(np.nanmean(accept_prob_arr)):.3g}, " f"min={float(np.nanmin(accept_prob_arr)):.3g}, " f"median={float(np.nanmedian(accept_prob_arr)):.3g}, " f"max={float(np.nanmax(accept_prob_arr)):.3g}, " f"frac<1e-12={float(np.nanmean(accept_prob_arr < 1e-12)):.1%}, " f"shape={accept_prob_arr.shape}" ) # Per-chain stats are often the easiest way to spot a single stuck chain. if accept_prob_arr.ndim >= 2: for i in range(min(accept_prob_arr.shape[0], 8)): a = np.asarray(accept_prob_arr[i]).reshape(-1) if a.size: run_logger.info( f"accept_prob chain[{i}] mean={float(np.nanmean(a)):.3g} " f"min={float(np.nanmin(a)):.3g} median={float(np.nanmedian(a)):.3g} " f"max={float(np.nanmax(a)):.3g}" ) if "diverging" in extra_fields: div = np.asarray(extra_fields["diverging"]) run_logger.info(f"diverging total={int(np.sum(div))} shape={div.shape}") if div.ndim >= 2: for i in range(min(div.shape[0], 8)): run_logger.info(f"diverging chain[{i}]={int(np.sum(div[i]))}") # Step count stats help identify stiffness/underflow that forces tiny step sizes. if "num_steps" in extra_fields: _log_array_stats(run_logger, name="num_steps", arr=extra_fields["num_steps"]) if "mean_accept_prob" in extra_fields: _log_array_stats( run_logger, name="mean_accept_prob", arr=extra_fields["mean_accept_prob"] ) for energy_key in ("potential_energy", "energy"): if energy_key in extra_fields: _log_array_stats(run_logger, name=energy_key, arr=extra_fields[energy_key]) # Critical warning for zero acceptance rate if np.isfinite(accept_prob) and accept_prob < 0.001: run_logger.warning( "CRITICAL: Acceptance rate is essentially 0% - all proposals rejected! " "This indicates severe sampling problems. Possible causes:\n" " 1. Initial values are outside prior support or at boundaries\n" " 2. Likelihood returns -inf due to numerical issues (NaN/overflow)\n" " 3. Prior is too narrow for the data\n" " 4. Step size adaptation failed during warmup\n" "Consider: checking initial values, widening priors, or running NLSQ first." ) if accept_prob_arr is not None and accept_prob_arr.size: finite = np.isfinite(accept_prob_arr) run_logger.warning( "accept_prob stats: " f"min={float(np.nanmin(accept_prob_arr)):.3g}, " f"median={float(np.nanmedian(accept_prob_arr)):.3g}, " f"max={float(np.nanmax(accept_prob_arr)):.3g}, " f"frac<1e-12={float(np.mean(accept_prob_arr < 1e-12)):.1%}, " f"finite={float(np.mean(finite)):.1%}, shape={accept_prob_arr.shape}" ) else: run_logger.warning( f"accept_prob array missing/empty; extra_fields keys={sorted(extra_fields.keys())}" ) for energy_key in ("potential_energy", "energy"): if energy_key in extra_fields: e = np.asarray(extra_fields[energy_key]) finite = np.isfinite(e) if np.any(finite): run_logger.warning( f"{energy_key} stats: " f"min={float(np.min(e[finite])):.3g}, " f"median={float(np.median(e[finite])):.3g}, " f"max={float(np.max(e[finite])):.3g}, " f"finite={float(np.mean(finite)):.1%}" ) else: run_logger.warning(f"{energy_key} all non-finite") else: run_logger.warning( f"{energy_key} not present in extra_fields (keys={sorted(extra_fields.keys())})" ) # Check for numerical issues exposed by the model. # NOTE: numpyro stores deterministics in `get_samples`, not `get_extra_fields`. if "n_numerical_issues" in samples_np: try: n_issues_total = float(np.sum(samples_np["n_numerical_issues"])) if n_issues_total > 0: total_evals = num_samples * config.num_chains issue_rate = n_issues_total / max(total_evals, 1) run_logger.warning( f"Numerical issues detected: {n_issues_total:.0f} NaN/Inf occurrences " f"({issue_rate:.1%} of evaluations). " "This may indicate parameter combinations causing overflow in physics model." ) except Exception: # noqa: S110 - Fallback for step_size extraction pass # Estimate warmup vs sampling time (rough estimate) # Use actual num_warmup/num_samples (may differ from config if adaptive sampling) warmup_ratio = num_warmup / (num_warmup + num_samples) warmup_time = total_time * warmup_ratio sampling_time = total_time * (1 - warmup_ratio) # Compute mean tree depth from NUTS num_steps (tree_depth = log2(num_steps)) mean_tree_depth = 0.0 if "num_steps" in extra_fields: num_steps_arr = np.asarray(extra_fields["num_steps"]) finite_steps = num_steps_arr[np.isfinite(num_steps_arr) & (num_steps_arr > 0)] if finite_steps.size > 0: mean_tree_depth = float(np.mean(np.log2(finite_steps))) stats = SamplingStats( warmup_time=warmup_time, sampling_time=sampling_time, total_time=total_time, num_divergent=num_divergent, accept_prob=accept_prob, step_size=step_size, step_size_min=step_size_min, step_size_max=step_size_max, inverse_mass_matrix_summary=inv_mass_summary, plan=plan, tree_depth=mean_tree_depth, ) run_logger.info( f"Sampling complete in {total_time:.1f}s, " f"{num_divergent} divergences, " f"accept_prob={accept_prob:.3f}" ) # Create MCMCSamples object # Derive actual chain count from first sample array (not from config) first_param_samples = next(iter(samples_np.values()), None) actual_n_chains = ( first_param_samples.shape[0] if first_param_samples is not None else config.num_chains ) mcmc_samples = MCMCSamples( samples=samples_np, param_names=[k for k in samples_np.keys() if k != "obs"], n_chains=actual_n_chains, n_samples=num_samples, # Use actual samples count (may be adaptive) extra_fields=extra_fields, ) return mcmc_samples, stats
[docs] def run_nuts_with_retry( model: Callable, model_kwargs: dict[str, Any], config: CMCConfig, initial_values: dict[str, float] | None, parameter_space: ParameterSpace, n_phi: int, analysis_mode: str, max_retries: int = 3, rng_key: jax.random.PRNGKey | None = None, per_angle_mode: str = "individual", ) -> tuple[MCMCSamples, SamplingStats]: """Run NUTS sampling with automatic retry on failure. Parameters ---------- model : Callable NumPyro model function. model_kwargs : dict[str, Any] Model arguments. config : CMCConfig Configuration. initial_values : dict[str, float] | None Initial values. parameter_space : ParameterSpace Parameter space. n_phi : int Number of phi angles. analysis_mode : str Analysis mode. max_retries : int Maximum number of retry attempts. rng_key : jax.random.PRNGKey | None Random key. Returns ------- tuple[MCMCSamples, SamplingStats] Samples and statistics. Raises ------ RuntimeError If all retries fail. """ if rng_key is None: rng_key = jax.random.PRNGKey(config.seed) last_error = None run_logger = with_context(logger, run=getattr(config, "run_id", None)) # Track current target_accept_prob for adaptive escalation current_target_accept_prob = config.target_accept_prob for attempt in range(max_retries): attempt_num = attempt + 1 attempt_logger = with_context(run_logger, attempt=attempt_num) attempt_start = time.perf_counter() # Adaptive divergence threshold: stricter on first attempt divergence_threshold = ( DIVERGENCE_RATE_TARGET if attempt == 0 else DIVERGENCE_RATE_HIGH ) attempt_logger.info( f"Attempt {attempt_num}/{max_retries}: starting NUTS " f"(chains={config.num_chains}, samples={config.num_samples}, " f"target_accept={current_target_accept_prob:.2f}, div_threshold={divergence_threshold:.0%})" ) try: # Use different RNG key for each attempt attempt_key = jax.random.fold_in(rng_key, attempt) # Create config with potentially escalated target_accept_prob attempt_config = replace( config, target_accept_prob=current_target_accept_prob ) samples, stats = run_nuts_sampling( model=model, model_kwargs=model_kwargs, config=attempt_config, initial_values=initial_values, parameter_space=parameter_space, n_phi=n_phi, analysis_mode=analysis_mode, rng_key=attempt_key, progress_bar=attempt == 0, # Only show progress on first attempt per_angle_mode=per_angle_mode, ) # Check for excessive divergences (adaptive threshold) # Use samples.n_samples (actual adapted count) instead of # config.num_samples which ignores adaptive sampling reduction. total_retry_samples = samples.n_samples * samples.n_chains divergence_rate = ( stats.num_divergent / total_retry_samples if total_retry_samples > 0 else 1.0 ) duration = time.perf_counter() - attempt_start if divergence_rate > divergence_threshold: attempt_logger.warning( f"Attempt {attempt_num}/{max_retries}: divergence_rate={divergence_rate:.1%} " f"> threshold {divergence_threshold:.0%}, retrying with smaller step sizes..." ) # Escalate target_accept_prob for next retry (smaller step sizes) current_target_accept_prob = min( 0.95, current_target_accept_prob + 0.05 ) last_error = RuntimeError( f"High divergence rate: {divergence_rate:.1%}" ) continue attempt_logger.info( f"Attempt {attempt_num}/{max_retries} succeeded in {duration:.2f}s " f"(divergences={stats.num_divergent}, accept_prob={stats.accept_prob:.3f})" ) return samples, stats except Exception as e: duration = time.perf_counter() - attempt_start attempt_logger.warning( f"Attempt {attempt_num}/{max_retries} failed after {duration:.2f}s: {e}" ) last_error = e run_logger.error(f"MCMC sampling failed after {max_retries} attempts: {last_error}") # All retries failed raise RuntimeError( f"MCMC sampling failed after {max_retries} attempts: {last_error}" )