Source code for homodyne.optimization.cmc.diagnostics

"""Convergence diagnostics for CMC analysis.

This module provides functions for computing MCMC convergence diagnostics
including R-hat, effective sample size (ESS), and divergence checks.
"""

from __future__ import annotations

import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

try:
    import arviz as az

    HAS_ARVIZ = True
except ImportError:
    HAS_ARVIZ = False
    az = None  # type: ignore[assignment,unused-ignore]

import numpy as np

try:
    from sklearn.mixture import GaussianMixture

    HAS_SKLEARN = True
except (ImportError, ValueError):
    HAS_SKLEARN = False
    GaussianMixture = None  # type: ignore[assignment,misc,unused-ignore]

from homodyne.utils.logging import get_logger

logger = get_logger(__name__)

# Default convergence thresholds
# DEFAULT_MAX_RHAT: default ``max_rhat`` of check_convergence() and the cut-off
# behind the "all_r_hat_ok" flag in create_diagnostics_dict().
DEFAULT_MAX_RHAT = 1.05
# DEFAULT_MIN_ESS: default ``min_ess`` of check_convergence(), the cut-off behind
# "all_ess_ok" in create_diagnostics_dict(), and the ESS [OK]/[FAIL] marker in
# log_analysis_summary().
DEFAULT_MIN_ESS = 400
# DEFAULT_MAX_DIVERGENCE_RATE: default ``max_divergence_rate`` of
# check_convergence(), expressed as a fraction of all transitions.
DEFAULT_MAX_DIVERGENCE_RATE = 0.05


def compute_r_hat(
    samples: dict[str, np.ndarray],
) -> dict[str, float]:
    """Compute an R-hat convergence statistic for every parameter.

    When ArviZ is importable, ``az.rhat`` is used. If ArviZ is missing or
    its computation raises, the classical Gelman-Rubin (1992) estimate is
    computed by hand instead.

    Parameters
    ----------
    samples : dict[str, np.ndarray]
        Parameter samples, {name: (n_chains, n_samples)}.

    Returns
    -------
    dict[str, float]
        R-hat per parameter. In the ArviZ path every input name is present
        (NaN when missing or non-finite). In the fallback path, arrays that
        are not 2D are skipped and single-chain arrays map to NaN.
    """
    if HAS_ARVIZ:
        try:
            inference_data = az.from_dict({"posterior": samples})
            # Degenerate chains make ArviZ emit RuntimeWarnings; silence them.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                rhat_dataset = az.rhat(inference_data)

            via_arviz: dict[str, float] = {}
            for param in samples:
                if not hasattr(rhat_dataset, param):
                    via_arviz[param] = np.nan
                    continue
                value = float(getattr(rhat_dataset, param).values)
                via_arviz[param] = value if np.isfinite(value) else np.nan
            return via_arviz
        except Exception as e:
            logger.warning(f"ArviZ R-hat computation failed: {e}, using fallback")

    # Manual Gelman-Rubin estimate, started from scratch even if the ArviZ
    # attempt above got part-way through.
    classical: dict[str, float] = {}
    for param, draws in samples.items():
        if draws.ndim != 2:
            logger.warning(
                f"Skipping R-hat for {param}: expected 2D, got {draws.ndim}D"
            )
            continue

        n_chains, n_samples = draws.shape
        if n_chains < 2:
            classical[param] = np.nan
            continue

        # NaN-aware reductions so isolated NaN draws do not poison the result.
        between = n_samples * np.nanvar(np.nanmean(draws, axis=1), ddof=1)
        within = np.nanmean(np.nanvar(draws, axis=1, ddof=1))
        pooled = ((n_samples - 1) * within + between) / n_samples

        ratio = np.sqrt(pooled / within) if within > 0 else np.nan
        classical[param] = float(ratio)

    return classical
def compute_ess(
    samples: dict[str, np.ndarray],
) -> tuple[dict[str, float], dict[str, float]]:
    """Compute bulk and tail effective sample size for every parameter.

    ArviZ's ``az.ess`` is used when available. If ArviZ is missing or the
    computation raises, a crude placeholder of ``arr.size / 10`` is reported
    for both bulk and tail ESS.

    Parameters
    ----------
    samples : dict[str, np.ndarray]
        Parameter samples, {name: (n_chains, n_samples)}.

    Returns
    -------
    tuple[dict[str, float], dict[str, float]]
        (ess_bulk, ess_tail) dictionaries.
    """
    bulk: dict[str, float] = {}
    tail: dict[str, float] = {}

    try:
        if not HAS_ARVIZ:
            raise ImportError("Arviz is required for full ESS computation")

        inference_data = az.from_dict({"posterior": samples})

        # Degenerate chains make ArviZ emit RuntimeWarnings; silence them.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            bulk_dataset = az.ess(inference_data, method="bulk")
            tail_dataset = az.ess(inference_data, method="tail")

        for param in samples:
            for dataset, target in ((bulk_dataset, bulk), (tail_dataset, tail)):
                if hasattr(dataset, param):
                    target[param] = float(getattr(dataset, param).values)
                else:
                    target[param] = np.nan

    except Exception as e:
        logger.warning(f"ArviZ ESS computation failed: {e}, using simple estimate")
        # Rough stand-in that assumes moderate autocorrelation; it overwrites
        # anything the failed ArviZ attempt may already have stored.
        for param, draws in samples.items():
            rough = float(draws.size / 10)
            bulk[param] = rough
            tail[param] = rough

    return bulk, tail
def check_convergence(
    r_hat: dict[str, float],
    ess_bulk: dict[str, float],
    divergences: int,
    n_samples: int,
    n_chains: int,
    max_rhat: float = DEFAULT_MAX_RHAT,
    min_ess: float = DEFAULT_MIN_ESS,
    max_divergence_rate: float = DEFAULT_MAX_DIVERGENCE_RATE,
    num_shards: int = 1,
) -> tuple[str, list[str]]:
    """Classify a run as converged or not and collect warning messages.

    Parameters
    ----------
    r_hat : dict[str, float]
        Per-parameter R-hat values.
    ess_bulk : dict[str, float]
        Per-parameter bulk ESS values.
    divergences : int
        Number of divergent transitions.
    n_samples : int
        Samples per chain.
    n_chains : int
        Number of chains.
    max_rhat : float
        Maximum acceptable R-hat.
    min_ess : float
        Minimum acceptable ESS.
    max_divergence_rate : float
        Maximum acceptable divergence rate.
    num_shards : int
        Number of shards (for CMC). The divergence rate is computed against
        ``num_shards * n_samples * n_chains`` transitions.

    Returns
    -------
    tuple[str, list[str]]
        (status, warnings) where status is
        "converged" | "divergences" | "not_converged".
    """
    messages: list[str] = []

    # R-hat: NaN entries are ignored; with no usable value the check passes.
    finite_rhat = [v for v in r_hat.values() if not np.isnan(v)]
    worst_rhat = max(finite_rhat, default=1.0)
    if worst_rhat > max_rhat:
        offenders = [k for k, v in r_hat.items() if v > max_rhat]
        text = (
            f"R-hat > {max_rhat} for parameters: {offenders} (max={worst_rhat:.3f})"
        )
        logger.warning(text)
        messages.append(text)

    # ESS: NaN entries are ignored; with no usable value the minimum is 0.
    finite_ess = [v for v in ess_bulk.values() if not np.isnan(v)]
    lowest_ess = min(finite_ess, default=0.0)
    if lowest_ess < min_ess:
        offenders = [k for k, v in ess_bulk.items() if v < min_ess]
        messages.append(
            f"ESS < {min_ess} for parameters: {offenders} (min={lowest_ess:.0f})"
        )

    # Divergences are compared against all transitions across shards.
    total_transitions = num_shards * n_samples * n_chains
    divergence_rate = divergences / total_transitions if total_transitions > 0 else 0
    rate_exceeded = divergence_rate > max_divergence_rate
    if rate_exceeded:
        messages.append(
            f"Divergence rate {divergence_rate:.1%} exceeds {max_divergence_rate:.1%} "
            f"({divergences}/{total_transitions} transitions)"
        )

    if divergences > 0 and rate_exceeded:
        return "divergences", messages
    if messages:
        return "not_converged", messages
    return "converged", messages
def create_diagnostics_dict(
    r_hat: dict[str, float],
    ess_bulk: dict[str, float],
    ess_tail: dict[str, float],
    divergences: int,
    convergence_status: str,
    warnings: list[str],
    n_chains: int,
    n_warmup: int,
    n_samples: int,
    warmup_time: float,
    sampling_time: float,
    num_shards: int = 1,
) -> dict[str, Any]:
    """Assemble the diagnostics dictionary used for JSON output.

    Parameters
    ----------
    r_hat : dict[str, float]
        Per-parameter R-hat.
    ess_bulk : dict[str, float]
        Per-parameter bulk ESS.
    ess_tail : dict[str, float]
        Per-parameter tail ESS.
    divergences : int
        Number of divergences.
    convergence_status : str
        Convergence status.
    warnings : list[str]
        Warning messages.
    n_chains : int
        Number of chains.
    n_warmup : int
        Warmup samples.
    n_samples : int
        Posterior samples.
    warmup_time : float
        Warmup time in seconds.
    sampling_time : float
        Sampling time in seconds.
    num_shards : int
        Number of shards combined (default 1). The divergence rate uses
        ``num_shards * n_chains * n_samples`` as its denominator.

    Returns
    -------
    dict[str, Any]
        Diagnostics dictionary with summary values, the sampling
        configuration, timing, and a per-parameter breakdown keyed by the
        names in ``r_hat``.
    """
    # NaN entries are excluded from every summary statistic.
    finite_rhat = [v for v in r_hat.values() if not np.isnan(v)]
    finite_bulk = [v for v in ess_bulk.values() if not np.isnan(v)]
    finite_tail = [v for v in ess_tail.values() if not np.isnan(v)]

    total_transitions = num_shards * n_chains * n_samples
    divergence_rate = divergences / total_transitions if total_transitions > 0 else 0

    per_parameter = {}
    for name in r_hat:
        per_parameter[name] = {
            "r_hat": r_hat.get(name, np.nan),
            "ess_bulk": ess_bulk.get(name, np.nan),
            "ess_tail": ess_tail.get(name, np.nan),
        }

    sampling_config = {
        "n_chains": n_chains,
        "n_warmup": n_warmup,
        "n_samples": n_samples,
    }
    timing = {
        "warmup_seconds": warmup_time,
        "sampling_seconds": sampling_time,
        "total_seconds": warmup_time + sampling_time,
    }

    return {
        "convergence_status": convergence_status,
        "total_divergences": divergences,
        "divergence_rate": divergence_rate,
        "max_r_hat": max(finite_rhat, default=np.nan),
        "min_ess_bulk": min(finite_bulk, default=np.nan),
        "min_ess_tail": min(finite_tail, default=np.nan),
        "all_r_hat_ok": all(v <= DEFAULT_MAX_RHAT for v in finite_rhat),
        "all_ess_ok": all(v >= DEFAULT_MIN_ESS for v in finite_bulk),
        "warnings": warnings,
        "sampling_config": sampling_config,
        "timing": timing,
        "per_parameter": per_parameter,
    }
def summarize_diagnostics(
    r_hat: dict[str, float],
    ess_bulk: dict[str, float],
    divergences: int,
    n_samples: int,
    n_chains: int,
    num_shards: int = 1,
) -> str:
    """Build a one-line, human-readable diagnostics summary.

    Parameters
    ----------
    r_hat : dict[str, float]
        R-hat values.
    ess_bulk : dict[str, float]
        ESS values.
    divergences : int
        Divergence count.
    n_samples : int
        Samples per chain.
    n_chains : int
        Number of chains.
    num_shards : int
        Number of shards (for CMC).

    Returns
    -------
    str
        Summary string with the largest R-hat, the smallest ESS (both NaN
        when no finite value exists) and the divergence count and rate.
    """
    worst_rhat = max((v for v in r_hat.values() if not np.isnan(v)), default=np.nan)
    lowest_ess = min((v for v in ess_bulk.values() if not np.isnan(v)), default=np.nan)

    transitions = num_shards * n_samples * n_chains
    rate = divergences / transitions if transitions > 0 else 0

    parts = [
        f"R-hat(max)={worst_rhat:.3f}",
        f"ESS(min)={lowest_ess:.0f}",
        f"divergences={divergences} ({rate:.1%})",
    ]
    return "Diagnostics: " + ", ".join(parts)
def log_analysis_summary(
    convergence_status: str,
    r_hat: dict[str, float],
    ess_bulk: dict[str, float],
    divergences: int,
    n_samples: int,
    n_chains: int,
    n_shards: int,
    shards_succeeded: int,
    execution_time: float,
) -> None:
    """Log a comprehensive summary at the end of CMC analysis.

    Parameters
    ----------
    convergence_status : str
        Final convergence status.
    r_hat : dict[str, float]
        Per-parameter R-hat values.
    ess_bulk : dict[str, float]
        Per-parameter bulk ESS values.
    divergences : int
        Total divergent transitions.
    n_samples : int
        Samples per chain.
    n_chains : int
        Number of chains.
    n_shards : int
        Total number of shards.
    shards_succeeded : int
        Number of successful shards.
    execution_time : float
        Total execution time in seconds.
    """
    r_hat_values = [v for v in r_hat.values() if not np.isnan(v)]
    ess_values = [v for v in ess_bulk.values() if not np.isnan(v)]

    max_rhat = max(r_hat_values) if r_hat_values else np.nan
    min_ess = min(ess_values) if ess_values else np.nan

    # For CMC, only successful shards contribute divergences to the total.
    # Using n_shards (total) would undercount the rate when some shards failed.
    total_transitions = shards_succeeded * n_samples * n_chains
    div_rate = divergences / total_transitions if total_transitions > 0 else 0
    success_rate = shards_succeeded / n_shards if n_shards > 0 else 0

    # Visual separator for easy identification in logs
    logger.info("=" * 60)
    logger.info("CMC ANALYSIS SUMMARY")
    logger.info("=" * 60)

    # Status with clear indicator
    if convergence_status == "converged":
        logger.info("Status: CONVERGED")
    else:
        logger.error(f"Status: {convergence_status.upper()}")

    # Key metrics. Both OK/FAIL markers use the module-level default
    # thresholds; a NaN value compares False and is therefore shown as FAIL.
    logger.info(f" Shards: {shards_succeeded}/{n_shards} ({success_rate:.0%} success)")
    logger.info(f" Runtime: {execution_time:.1f}s ({execution_time / 60:.1f} min)")
    logger.info(
        f" R-hat (max): {max_rhat:.4f} "
        f"{'[OK]' if max_rhat <= DEFAULT_MAX_RHAT else '[FAIL]'}"
    )
    logger.info(
        f" ESS (min): {min_ess:.0f} "
        f"{'[OK]' if min_ess >= DEFAULT_MIN_ESS else '[FAIL]'}"
    )
    logger.info(f" Divergences: {divergences} ({div_rate:.1%})")

    # Recommendations if there are issues. Pass shards_succeeded (not
    # n_shards) so the divergence rate behind the recommendations matches
    # the rate logged above.
    recommendations = get_convergence_recommendations(
        max_rhat, min_ess, divergences, n_samples, n_chains, shards_succeeded
    )
    if recommendations:
        logger.info("-" * 40)
        logger.info("RECOMMENDATIONS:")
        for rec in recommendations:
            logger.info(f" - {rec}")

    logger.info("=" * 60)
def get_convergence_recommendations(
    max_rhat: float,
    min_ess: float,
    divergences: int,
    n_samples: int,
    n_chains: int,
    num_shards: int = 1,
) -> list[str]:
    """Turn summary diagnostics into concrete tuning advice.

    Parameters
    ----------
    max_rhat : float
        Maximum R-hat value across parameters.
    min_ess : float
        Minimum bulk ESS across parameters.
    divergences : int
        Number of divergent transitions.
    n_samples : int
        Samples per chain.
    n_chains : int
        Number of chains.
    num_shards : int
        Number of shards (for CMC).

    Returns
    -------
    list[str]
        Recommendation strings, at most one each for R-hat, ESS and
        divergences, in that order. Empty when nothing is flagged;
        non-finite ``max_rhat`` / ``min_ess`` are never flagged.
    """
    advice: list[str] = []

    transitions = num_shards * n_samples * n_chains
    rate = divergences / transitions if transitions > 0 else 0

    # R-hat
    if np.isfinite(max_rhat):
        if max_rhat > 1.1:
            advice.append(
                f"HIGH R-HAT ({max_rhat:.3f}): Chains have not mixed. "
                f"Try: increase num_warmup, "
                f"or use more chains (currently {n_chains})."
            )
        elif max_rhat > 1.05:
            advice.append(
                f"MARGINAL R-HAT ({max_rhat:.3f}): Consider increasing num_samples "
                f"or num_warmup for better convergence."
            )

    # ESS
    if np.isfinite(min_ess):
        if min_ess < 100:
            # Scale the current draw count by the shortfall relative to 100.
            suggested = int(100 * n_samples / max(min_ess, 1))
            advice.append(
                f"LOW ESS ({min_ess:.0f}): High autocorrelation in samples. "
                f"Try: increase num_samples (currently {n_samples}) to at least {suggested}."
            )
        elif min_ess < 400:
            advice.append(
                f"MODERATE ESS ({min_ess:.0f}): Consider increasing num_samples "
                f"for more reliable uncertainty estimates."
            )

    # Divergences
    if rate > 0.10:
        advice.append(
            f"HIGH DIVERGENCES ({rate:.1%}): Model geometry issues. "
            "Try: reduce max_points_per_shard, increase target_accept_prob to 0.95, "
            "or check for data outliers."
        )
    elif rate > 0.01:
        advice.append(
            f"MODERATE DIVERGENCES ({rate:.1%}): Some geometry issues. "
            "Consider increasing target_accept_prob to 0.90."
        )

    return advice
# ============================================================================= # PRECISION DIAGNOSTICS (Jan 2026) # =============================================================================
def compute_posterior_contraction(
    posterior_std: float,
    prior_std: float,
) -> float:
    """Compute the Posterior Contraction Ratio (PCR).

    PCR = 1 - (posterior_std / prior_std)

    Reading the value:

    - near 0: the posterior is about as wide as the prior
    - near 0.5: the posterior is half as wide as the prior
    - near 0.9: the posterior is 10% as wide as the prior
    - below 0: the posterior is wider than the prior

    Parameters
    ----------
    posterior_std : float
        Standard deviation of the posterior distribution.
    prior_std : float
        Standard deviation of the prior distribution.

    Returns
    -------
    float
        The contraction ratio, or NaN when either standard deviation is
        non-positive or non-finite.
    """
    # Prior is validated first, then the posterior.
    for width in (prior_std, posterior_std):
        if width <= 0 or not np.isfinite(width):
            return np.nan

    return 1.0 - (posterior_std / prior_std)
[docs] def compute_nlsq_comparison_metrics( cmc_mean: float, cmc_std: float, nlsq_value: float, nlsq_std: float | None = None, ) -> dict[str, float]: """Compute metrics comparing CMC posterior to NLSQ point estimate. Parameters ---------- cmc_mean : float CMC posterior mean. cmc_std : float CMC posterior standard deviation. nlsq_value : float NLSQ point estimate. nlsq_std : float | None NLSQ standard error. If None, only CMC-based metrics computed. Returns ------- dict[str, float] Dictionary with comparison metrics: - z_score: abs(CMC_mean - NLSQ) / CMC_std (should be < 2 for consistency) - uncertainty_ratio: CMC_std / NLSQ_std (should be < 5x ideally) - relative_diff: (CMC_mean - NLSQ) / abs(NLSQ) (percent difference) - coverage: Whether NLSQ falls within CMC 95% CI """ metrics = {} # Z-score: How many CMC standard deviations away is NLSQ? if cmc_std > 0 and np.isfinite(cmc_std): z_score = abs(cmc_mean - nlsq_value) / cmc_std metrics["z_score"] = z_score # Coverage: Does 95% CI contain NLSQ? metrics["coverage_95"] = float(z_score < 1.96) else: metrics["z_score"] = np.nan metrics["coverage_95"] = np.nan # Relative difference (percent) if nlsq_value != 0 and np.isfinite(nlsq_value): metrics["relative_diff"] = (cmc_mean - nlsq_value) / abs(nlsq_value) else: metrics["relative_diff"] = np.nan # Uncertainty ratio (if NLSQ std available) if nlsq_std is not None and nlsq_std > 0 and np.isfinite(nlsq_std): metrics["uncertainty_ratio"] = cmc_std / nlsq_std else: metrics["uncertainty_ratio"] = np.nan return metrics
[docs] def compute_precision_analysis( cmc_result: dict[str, dict], nlsq_result: dict[str, float] | None = None, nlsq_uncertainties: dict[str, float] | None = None, prior_stds: dict[str, float] | None = None, ) -> dict[str, dict[str, float]]: """Compute comprehensive precision analysis for all parameters. Parameters ---------- cmc_result : dict[str, dict] CMC posterior statistics, keyed by parameter name. Each entry should have "mean" and "std" keys. nlsq_result : dict[str, float] | None NLSQ point estimates, keyed by parameter name. nlsq_uncertainties : dict[str, float] | None NLSQ standard errors, keyed by parameter name. prior_stds : dict[str, float] | None Prior standard deviations, keyed by parameter name. Returns ------- dict[str, dict[str, float]] Precision metrics for each parameter. """ analysis = {} for param_name, stats in cmc_result.items(): # Skip non-physical parameters if param_name in ("sigma", "obs", "n_numerical_issues"): continue param_metrics = { "cmc_mean": stats.get("mean", np.nan), "cmc_std": stats.get("std", np.nan), } # Add posterior contraction if prior_std available if prior_stds and param_name in prior_stds: pcr = compute_posterior_contraction( stats.get("std", np.nan), prior_stds[param_name], ) param_metrics["posterior_contraction"] = pcr param_metrics["prior_std"] = prior_stds[param_name] # Add NLSQ comparison if available if nlsq_result and param_name in nlsq_result: nlsq_val = nlsq_result[param_name] nlsq_std = ( nlsq_uncertainties.get(param_name) if nlsq_uncertainties else None ) comparison = compute_nlsq_comparison_metrics( cmc_mean=stats.get("mean", np.nan), cmc_std=stats.get("std", np.nan), nlsq_value=nlsq_val, nlsq_std=nlsq_std, ) param_metrics.update(comparison) param_metrics["nlsq_value"] = nlsq_val if nlsq_std is not None: param_metrics["nlsq_std"] = nlsq_std analysis[param_name] = param_metrics return analysis
def log_precision_analysis(
    analysis: dict[str, dict[str, float]],
    log_fn: Callable[[str], None] | None = None,
    tolerance_pct: float = 20.0,
) -> str:
    """Format, log and return the CMC-vs-NLSQ precision report.

    Parameters
    ----------
    analysis : dict[str, dict[str, float]]
        Output from compute_precision_analysis().
    log_fn : callable | None
        Logging function. If None, uses module logger.
    tolerance_pct : float
        Percent-difference threshold above which a parameter is flagged.
        Default 20%.

    Returns
    -------
    str
        The report text; the same string is passed once to ``log_fn``.
    """
    emit = logger.info if log_fn is None else log_fn

    def _finite(key: str) -> list:
        # Every finite value of ``key`` across parameters, in input order.
        found = []
        for entry in analysis.values():
            value = entry.get(key, np.nan)
            if np.isfinite(value):
                found.append(value)
        return found

    heavy_rule = "=" * 80
    light_rule = "-" * 80
    lines = [heavy_rule, "CMC vs NLSQ PRECISION ANALYSIS", heavy_rule]

    z_scores = _finite("z_score")
    rel_diffs = [abs(value * 100) for value in _finite("relative_diff")]
    unc_ratios = _finite("uncertainty_ratio")
    pcrs = _finite("posterior_contraction")

    # ---- summary section -------------------------------------------------
    lines.append("SUMMARY:")

    if z_scores:
        lines.append(
            f" Z-scores: max={max(z_scores):.2f}, mean={np.mean(z_scores):.2f}"
        )
        n_above_2 = len([z for z in z_scores if z > 2])
        n_above_3 = len([z for z in z_scores if z > 3])
        if n_above_3 > 0:
            lines.append(
                f" CRITICAL: {n_above_3}/{len(z_scores)} params have z > 3 (severe disagreement)"
            )
        elif n_above_2 > 0:
            lines.append(
                f" WARNING: {n_above_2}/{len(z_scores)} params have z > 2 (significant disagreement)"
            )
        else:
            lines.append(" All params have z <= 2 (good agreement)")

    if rel_diffs:
        lines.append(
            f" Percent differences: max={max(rel_diffs):.1f}%, mean={np.mean(rel_diffs):.1f}%"
        )
        n_over = len([d for d in rel_diffs if d > tolerance_pct])
        if n_over > 0:
            lines.append(
                f" WARNING: {n_over}/{len(rel_diffs)} params exceed {tolerance_pct:.0f}% tolerance"
            )
        else:
            lines.append(f" All params within {tolerance_pct:.0f}% tolerance")

    if unc_ratios:
        lines.append(
            f" Uncertainty ratio (CMC/NLSQ): max={max(unc_ratios):.1f}x, median={np.median(unc_ratios):.1f}x"
        )
        # Below 0.5x the CMC is suspiciously precise; above 10x it is much
        # less certain than NLSQ.
        n_tight = len([r for r in unc_ratios if r < 0.5])
        n_loose = len([r for r in unc_ratios if r > 10])
        if n_tight > 0:
            lines.append(
                f" WARNING: {n_tight}/{len(unc_ratios)} params have ratio < 0.5x "
                "(CMC artificially precise - check for shard heterogeneity)"
            )
        if n_loose > 0:
            lines.append(
                f" INFO: {n_loose}/{len(unc_ratios)} params have ratio > 10x (CMC more uncertain)"
            )

    if pcrs:
        lines.append(
            f" Posterior contraction: max={max(pcrs):.2f}, mean={np.mean(pcrs):.2f}"
        )
        n_weak = len([p for p in pcrs if p < 0.3])
        if n_weak > 0:
            lines.append(
                f" INFO: {n_weak}/{len(pcrs)} params have PCR < 0.3 (weak data constraint)"
            )

    # ---- per-parameter table ---------------------------------------------
    lines.append(light_rule)
    lines.append(
        f"{'Parameter':<18} {'CMC Mean':>11} {'CMC Std':>10} {'NLSQ':>11} "
        f"{'Diff%':>7} {'Z':>6} {'Ratio':>7}"
    )
    lines.append(light_rule)

    for param_name in sorted(analysis):
        entry = analysis[param_name]
        cmc_mean = entry.get("cmc_mean", np.nan)
        cmc_std = entry.get("cmc_std", np.nan)
        nlsq_val = entry.get("nlsq_value", np.nan)
        z_score = entry.get("z_score", np.nan)
        rel_diff = entry.get("relative_diff", np.nan)
        unc_ratio = entry.get("uncertainty_ratio", np.nan)

        # Non-finite cells are shown as "N/A".
        cmc_mean_str = "N/A" if not np.isfinite(cmc_mean) else f"{cmc_mean:.4g}"
        cmc_std_str = "N/A" if not np.isfinite(cmc_std) else f"{cmc_std:.4g}"
        nlsq_str = "N/A" if not np.isfinite(nlsq_val) else f"{nlsq_val:.4g}"
        z_str = "N/A" if not np.isfinite(z_score) else f"{z_score:.2f}"
        diff_str = "N/A" if not np.isfinite(rel_diff) else f"{rel_diff * 100:+.1f}%"
        ratio_str = "N/A" if not np.isfinite(unc_ratio) else f"{unc_ratio:.1f}x"

        # Severity marker: z > 3 wins; otherwise any single warning trigger.
        z_known = np.isfinite(z_score)
        if z_known and z_score > 3:
            marker = " [SEVERE]"
        elif (
            (z_known and z_score > 2)
            or (np.isfinite(rel_diff) and abs(rel_diff * 100) > tolerance_pct)
            or (np.isfinite(unc_ratio) and unc_ratio < 0.5)
        ):
            marker = " [WARN]"
        else:
            marker = ""

        lines.append(
            f"{param_name:<18} {cmc_mean_str:>11} {cmc_std_str:>10} "
            f"{nlsq_str:>11} {diff_str:>7} {z_str:>6} {ratio_str:>7}{marker}"
        )

    lines.append(heavy_rule)

    report = "\n".join(lines)
    emit(report)
    return report
# ============================================================================= # BIMODAL DETECTION (Jan 2026) # =============================================================================
@dataclass
class BimodalResult:
    r"""Outcome of the two-component mixture test for one parameter.

    Attributes
    ----------
    is_bimodal : bool
        True when the posterior was judged to have two modes.
    weights : tuple[float, float]
        Mixture weight of each component.
    means : tuple[float, float]
        Mean of each component.
    stds : tuple[float, float]
        Standard deviation of each component.
    separation : float
        Absolute distance between the two component means.
    relative_separation : float
        ``separation`` divided by a scale, ``|mean(means)|``.
    """

    # Verdict
    is_bimodal: bool
    # Per-component mixture parameters
    weights: tuple[float, float]
    means: tuple[float, float]
    stds: tuple[float, float]
    # Distance between the component means, absolute and scaled
    separation: float
    relative_separation: float
@dataclass
class ModeCluster:
    """One mode produced by bimodal consensus combination.

    Attributes
    ----------
    mean : dict[str, float]
        Consensus mean of this mode, per parameter.
    std : dict[str, float]
        Consensus standard deviation of this mode, per parameter.
    weight : float
        Fraction of shards supporting this mode (0-1).
    n_shards : int
        Number of shards assigned to this cluster.
    samples : dict[str, np.ndarray]
        Samples generated from N(mean, std^2), shape (n_chains, n_samples).
    """

    # Per-parameter consensus statistics
    mean: dict[str, float]
    std: dict[str, float]
    # Support for this mode among the shards
    weight: float
    n_shards: int
    # Per-parameter sample arrays
    samples: dict[str, np.ndarray]
@dataclass
class BimodalConsensusResult:
    """Outcome of mode-aware consensus combination.

    Attached to MCMCSamples when bimodal posteriors are detected and
    per-mode consensus replaces the standard combination.

    Attributes
    ----------
    modes : list[ModeCluster]
        Mode clusters (typically 2), each with its own consensus statistics.
    modal_params : list[str]
        Names of the parameters that triggered bimodal detection.
    co_occurrence : dict[str, Any]
        Cross-parameter co-occurrence information (e.g., D0-alpha
        correlation).
    """

    # Per-mode consensus results
    modes: list[ModeCluster]
    # Parameters flagged as bimodal
    modal_params: list[str]
    # Cross-parameter co-occurrence details
    co_occurrence: dict[str, Any]
def _unimodal_result(center: float) -> BimodalResult:
    """Build the non-bimodal placeholder result with both means at ``center``."""
    return BimodalResult(
        is_bimodal=False,
        weights=(1.0, 0.0),
        means=(center, center),
        stds=(0.0, 0.0),
        separation=0.0,
        relative_separation=0.0,
    )


def detect_bimodal(
    samples: np.ndarray,
    min_weight: float = 0.2,
    min_relative_separation: float = 0.5,
) -> BimodalResult:
    """Detect bimodality using 2-component Gaussian Mixture Model.

    Non-finite samples (NaN / inf) are dropped before fitting, because the
    mixture model cannot be fitted to them. Inputs with fewer than two
    finite samples produce a non-bimodal placeholder instead of an error.

    Parameters
    ----------
    samples : np.ndarray
        Posterior samples; any shape, flattened before fitting.
    min_weight : float
        Minimum weight for both components to be considered bimodal.
    min_relative_separation : float
        Minimum separation between means (relative to scale) for bimodality.

    Returns
    -------
    BimodalResult
        Detection result with component details. When sklearn is not
        available, or fewer than two finite samples exist, ``is_bimodal``
        is False, ``weights`` is (1.0, 0.0) and both means share a single
        value: the NaN-aware sample mean without sklearn, the lone finite
        sample when there is exactly one, 0.0 for empty input, and NaN when
        every sample is non-finite.
    """
    if not HAS_SKLEARN:
        logger.debug("sklearn not available, skipping bimodality detection")
        return _unimodal_result(float(np.nanmean(samples)))

    flat = np.asarray(samples).reshape(-1)
    finite = flat[np.isfinite(flat)]
    n_dropped = flat.size - finite.size
    if n_dropped:
        logger.debug(
            f"Dropping {n_dropped} non-finite samples before bimodality detection"
        )

    # GaussianMixture requires at least n_components (2) samples.
    # Return a non-bimodal result for degenerate inputs rather than raising.
    if finite.size < 2:
        if finite.size == 1:
            center = float(finite[0])
        elif flat.size == 0:
            center = 0.0
        else:
            center = float("nan")
        return _unimodal_result(center)

    # Fixed seed keeps repeated calls on the same data reproducible.
    gmm = GaussianMixture(n_components=2, random_state=42, n_init=3)
    gmm.fit(finite.reshape(-1, 1))

    weights = tuple(gmm.weights_.tolist())
    means = tuple(gmm.means_.flatten().tolist())
    stds = tuple(np.sqrt(gmm.covariances_.flatten()).tolist())

    separation = abs(means[0] - means[1])
    # Floor the scale so modes centred near zero do not divide by zero.
    scale = max(abs(np.mean(means)), 1e-10)
    relative_separation = separation / scale

    # Bimodal if: both components significant AND well-separated
    is_bimodal = (
        min(weights) > min_weight and relative_separation > min_relative_separation
    )

    return BimodalResult(
        is_bimodal=is_bimodal,
        weights=weights,
        means=means,
        stds=stds,
        separation=separation,
        relative_separation=relative_separation,
    )
def check_shard_bimodality(
    samples: dict[str, np.ndarray],
    params_to_check: list[str] | None = None,
) -> dict[str, BimodalResult]:
    """Run bimodality detection on several parameters of one shard.

    Parameters
    ----------
    samples : dict[str, np.ndarray]
        Parameter samples from a shard.
    params_to_check : list[str], optional
        Parameters to check. When omitted, "D0", "D_offset",
        "gamma_dot_t0", "beta" and "alpha" are checked.

    Returns
    -------
    dict[str, BimodalResult]
        Mapping from param name to BimodalResult; requested parameters that
        are absent from ``samples`` are left out.
    """
    targets = params_to_check
    if targets is None:
        targets = ["D0", "D_offset", "gamma_dot_t0", "beta", "alpha"]

    # Each parameter's draws are flattened across chains before detection.
    return {
        name: detect_bimodal(samples[name].flatten())
        for name in targets
        if name in samples
    }
def summarize_cross_shard_bimodality(
    bimodal_detections: list[dict[str, Any]],
    n_shards: int,
    consensus_means: dict[str, float] | None = None,
    significance_threshold: float = 0.05,
) -> dict[str, Any]:
    """Roll per-shard bimodal detections up into a cross-shard summary.

    Detections are grouped by parameter. For each sufficiently frequent
    parameter the lower and upper mode populations are summarised, and the
    overlap between D0-bimodal and alpha-bimodal shards is reported.

    Parameters
    ----------
    bimodal_detections : list[dict[str, Any]]
        Per-detection records; the keys "shard", "param", "mode1" and
        "mode2" are read here.
    n_shards : int
        Denominator for the bimodal fraction.
    consensus_means : dict[str, float] | None
        Per-parameter consensus estimate, used to test whether it lies
        strictly between the lower and upper mode means.
    significance_threshold : float
        Minimum bimodal fraction (detections/n_shards) for a parameter to
        appear in "per_param". Default 5%.

    Returns
    -------
    dict[str, Any]
        Summary with keys:

        - "per_param": dict mapping param name to per-parameter stats
        - "co_occurrence": dict with D0-alpha co-occurrence info
        - "n_detections": total detection count
        - "n_shards": total shard count
    """
    if not bimodal_detections or n_shards == 0:
        return {
            "per_param": {},
            "co_occurrence": {},
            "n_detections": 0,
            "n_shards": n_shards,
        }

    # Bucket the detection records by parameter name.
    grouped: dict[str, list[dict[str, Any]]] = {}
    for record in bimodal_detections:
        grouped.setdefault(record["param"], []).append(record)

    per_param: dict[str, dict[str, Any]] = {}
    for param, records in grouped.items():
        fraction = len(records) / n_shards
        if fraction < significance_threshold:
            continue

        # Order each detection's two modes so lows and highs can be pooled.
        ordered = [sorted([rec["mode1"], rec["mode2"]]) for rec in records]
        lower_arr = np.array([pair[0] for pair in ordered])
        upper_arr = np.array([pair[1] for pair in ordered])

        lower_mean = float(np.mean(lower_arr))
        lower_std = float(np.std(lower_arr))
        upper_mean = float(np.mean(upper_arr))
        upper_std = float(np.std(upper_arr))

        # Gap between the two populations in units of their combined spread.
        pooled_std = np.sqrt(lower_std**2 + upper_std**2)
        if pooled_std > 0:
            sep_significance = abs(upper_mean - lower_mean) / pooled_std
        else:
            sep_significance = float("inf")

        # Does the consensus estimate sit strictly between the two modes?
        in_trough = False
        if consensus_means is not None and param in consensus_means:
            estimate = consensus_means[param]
            in_trough = lower_mean < estimate < upper_mean

        per_param[param] = {
            "n_detections": len(records),
            "bimodal_fraction": fraction,
            "lower_mean": lower_mean,
            "lower_std": lower_std,
            "upper_mean": upper_mean,
            "upper_std": upper_std,
            "sep_significance": sep_significance,
            "consensus_in_trough": in_trough,
        }

    # D0-alpha co-occurrence uses all detections, not just significant ones.
    co_occurrence: dict[str, Any] = {}
    d0_shards = {rec["shard"] for rec in grouped.get("D0", [])}
    alpha_shards = {rec["shard"] for rec in grouped.get("alpha", [])}
    if d0_shards:
        shared = d0_shards & alpha_shards
        co_occurrence["d0_alpha_overlap"] = len(shared)
        co_occurrence["d0_alpha_fraction"] = len(shared) / len(d0_shards)
        co_occurrence["d0_bimodal_shards"] = len(d0_shards)
        co_occurrence["alpha_bimodal_shards"] = len(alpha_shards)

    return {
        "per_param": per_param,
        "co_occurrence": co_occurrence,
        "n_detections": len(bimodal_detections),
        "n_shards": n_shards,
    }
def cluster_shard_modes(
    bimodal_detections: list[dict[str, Any]],
    successful_samples: list[Any],
    bimodal_summary: dict[str, Any],
    param_bounds: dict[str, tuple[float, float]],
) -> tuple[list[int], list[int]]:
    """Split shards into a "lower" and an "upper" mode population.

    Every shard is described by its per-parameter sample means over the
    modal parameters, scaled by each parameter's bound range, and assigned
    to the nearer of the two centroids taken from the cross-shard summary.
    Shards with a bimodal detection on a modal parameter join both groups.

    Parameters
    ----------
    bimodal_detections : list[dict[str, Any]]
        Per-detection records; the keys "shard" and "param" are read here.
    successful_samples : list[Any]
        Objects exposing a ``.samples`` dict of parameter arrays; the list
        position is used as the shard index.
    bimodal_summary : dict[str, Any]
        Output from summarize_cross_shard_bimodality().
    param_bounds : dict[str, tuple[float, float]]
        Parameter bounds for range-based normalization, {param: (lo, hi)}.
        Missing parameters fall back to (0.0, 1.0).

    Returns
    -------
    tuple[list[int], list[int]]
        (cluster_0_shards, cluster_1_shards) where cluster_0 is "lower" and
        cluster_1 is "upper". Bimodal shards appear in both lists. With no
        modal parameters, every shard is placed in the lower cluster.
    """
    summary_stats = bimodal_summary.get("per_param", {})
    modal_params = sorted(summary_stats.keys())
    n_shards = len(successful_samples)

    if not modal_params:
        return list(range(n_shards)), []

    # Centroids and normalisation scales, one entry per modal parameter.
    lower_centroid = []
    upper_centroid = []
    ranges = []
    for name in modal_params:
        entry = summary_stats[name]
        lower_centroid.append(entry["lower_mean"])
        upper_centroid.append(entry["upper_mean"])
        lo, hi = param_bounds.get(name, (0.0, 1.0))
        ranges.append(max(abs(hi - lo), 1e-10))  # guard against zero range

    scale = np.array(ranges)
    lower_norm = np.array(lower_centroid) / scale
    upper_norm = np.array(upper_centroid) / scale

    # Shards with at least one detection on a modal parameter.
    bimodal_shards = set()
    for record in bimodal_detections:
        shard_id = record["shard"]
        name = record["param"]
        if name in modal_params:
            bimodal_shards.add(shard_id)

    lower_members: list[int] = []
    upper_members: list[int] = []

    for shard_idx in range(n_shards):
        if shard_idx in bimodal_shards:
            lower_members.append(shard_idx)
            upper_members.append(shard_idx)
            continue

        coords = []
        for pos, name in enumerate(modal_params):
            if name in successful_samples[shard_idx].samples:
                value = float(np.nanmean(successful_samples[shard_idx].samples[name]))
            else:
                # No samples for this parameter: use the centroid midpoint so
                # it does not favour either cluster.
                value = (lower_centroid[pos] + upper_centroid[pos]) / 2
            coords.append(value)

        point = np.array(coords) / scale
        to_lower = np.linalg.norm(point - lower_norm)
        to_upper = np.linalg.norm(point - upper_norm)

        # Ties go to the lower cluster.
        if to_lower <= to_upper:
            lower_members.append(shard_idx)
        else:
            upper_members.append(shard_idx)

    return lower_members, upper_members