Source code for homodyne.io.mcmc_writers

"""MCMC result saving functions for homodyne XPCS analysis.

This module provides functions for creating dictionaries to save MCMC/CMC
optimization results to disk.
Extracted from cli/commands.py for better modularity.
"""

from datetime import datetime
from typing import Any, Literal

import numpy as np

from homodyne.config.parameter_names import get_physical_param_names
from homodyne.io.json_utils import json_safe as _json_safe
from homodyne.utils.logging import get_logger

logger = get_logger(__name__)


def _get_parameter_names(analysis_mode: str) -> list[str]:
    """Get physical parameter names for given analysis mode.

    This is a thin wrapper around get_physical_param_names() that handles
    unknown modes gracefully with a warning instead of raising an exception.

    Parameters
    ----------
    analysis_mode : str
        Analysis mode ("static" or "laminar_flow")

    Returns
    -------
    list[str]
        List of physical parameter names (without contrast/offset)
    """
    mode: Literal["static_isotropic", "laminar_flow"]
    if analysis_mode == "static":
        mode = "static_isotropic"
    elif analysis_mode == "laminar_flow":
        mode = "laminar_flow"
    else:
        logger.warning(f"Unknown analysis mode: {analysis_mode}, assuming static")
        mode = "static_isotropic"

    return get_physical_param_names(mode)


[docs] def create_mcmc_parameters_dict(result: Any) -> dict: """Create parameters dictionary with posterior statistics. Parameters ---------- result : MCMCResult MCMC result with posterior samples and statistics Returns ------- dict Structured parameter dictionary with posterior mean ± std """ diag_summary = getattr(result, "diagnostic_summary", {}) or {} deterministic_params = set(diag_summary.get("deterministic_params") or []) param_dict: dict[str, Any] = { "timestamp": datetime.now().isoformat(), "analysis_mode": getattr(result, "analysis_mode", "unknown"), "method": ( "cmc" if (hasattr(result, "is_cmc_result") and result.is_cmc_result()) else "mcmc" ), "sampling_summary": { "n_samples": getattr(result, "n_samples", 0), "n_warmup": getattr(result, "n_warmup", 0), "n_chains": getattr(result, "n_chains", 1), "total_samples": getattr(result, "n_samples", 0) * getattr(result, "n_chains", 1), "computation_time": getattr(result, "computation_time", 0.0), }, "convergence": {}, "parameters": {}, } # Add convergence diagnostics if available if hasattr(result, "r_hat") and result.r_hat is not None: if isinstance(result.r_hat, dict): r_hat_values = [ v for name, v in result.r_hat.items() if v is not None and name not in deterministic_params ] if r_hat_values: convergence_dict = param_dict["convergence"] convergence_dict["all_chains_converged"] = bool( all(v < 1.1 for v in r_hat_values) ) convergence_dict["min_r_hat"] = _json_safe(float(min(r_hat_values))) convergence_dict["max_r_hat"] = _json_safe(float(max(r_hat_values))) else: r_hat = np.asarray(result.r_hat) finite_rhat = r_hat[np.isfinite(r_hat)] convergence_dict = param_dict["convergence"] convergence_dict["all_chains_converged"] = bool( np.all(finite_rhat < 1.1) if finite_rhat.size > 0 else False ) convergence_dict["min_r_hat"] = _json_safe( float(np.min(finite_rhat)) if finite_rhat.size > 0 else None ) convergence_dict["max_r_hat"] = _json_safe( float(np.max(finite_rhat)) if finite_rhat.size > 0 else None ) if ( hasattr(result, "effective_sample_size") and result.effective_sample_size is not None ): if isinstance(result.effective_sample_size, dict): ess_values = [ v for v in result.effective_sample_size.values() if v is not None and np.isfinite(v) ] if ess_values: convergence_dict = param_dict["convergence"] convergence_dict["min_ess"] = _json_safe(float(min(ess_values))) else: ess = np.asarray(result.effective_sample_size) finite_ess = ess[np.isfinite(ess)] convergence_dict = param_dict["convergence"] convergence_dict["min_ess"] = ( _json_safe(float(np.min(finite_ess))) if finite_ess.size > 0 else None ) if hasattr(result, "acceptance_rate") and result.acceptance_rate is not None: convergence_dict = param_dict["convergence"] convergence_dict["acceptance_rate"] = _json_safe(float(result.acceptance_rate)) # Add scaling parameters (contrast, offset) if hasattr(result, "mean_contrast"): parameters_dict = param_dict["parameters"] parameters_dict["contrast"] = { "mean": _json_safe(float(result.mean_contrast)), "std": _json_safe(float(getattr(result, "std_contrast", 0.0))), } if hasattr(result, "mean_offset"): parameters_dict = param_dict["parameters"] parameters_dict["offset"] = { "mean": _json_safe(float(result.mean_offset)), "std": _json_safe(float(getattr(result, "std_offset", 0.0))), } # Add physical parameters if hasattr(result, "mean_params") and result.mean_params is not None: analysis_mode = getattr(result, "analysis_mode", "static") param_names = _get_parameter_names(analysis_mode) mean_params_obj = result.mean_params std_params_obj = getattr(result, "std_params", None) # CRITICAL FIX (Dec 2025): Check dict FIRST, before as_array. # ParameterStats inherits from dict AND has as_array property. # The as_array returns values in build order (from from_mcmc_samples), # which may not match canonical param_names order from get_physical_param_names(). # Using dict access ensures correct name-to-value mapping regardless of order. if isinstance(mean_params_obj, dict): mean_params_arr = np.array( [mean_params_obj.get(name, np.nan) for name in param_names] ) elif hasattr(mean_params_obj, "as_array"): mean_params_arr = np.asarray(mean_params_obj.as_array) else: mean_params_arr = np.asarray(mean_params_obj) # Same fix for std_params_obj - check dict first if isinstance(std_params_obj, dict): std_params_arr = np.array( [std_params_obj.get(name, 0.0) for name in param_names] ) elif std_params_obj is not None and hasattr(std_params_obj, "as_array"): std_params_arr = np.asarray(std_params_obj.as_array) else: std_params_arr = ( np.asarray(std_params_obj) if std_params_obj is not None else np.zeros_like(mean_params_arr) ) parameters_dict = param_dict["parameters"] for i, name in enumerate(param_names): if i < len(mean_params_arr): parameters_dict[name] = { "mean": _json_safe(float(mean_params_arr[i])), "std": _json_safe(float(std_params_arr[i])) if i < len(std_params_arr) else 0.0, } return param_dict
[docs] def create_mcmc_analysis_dict( result: Any, data: dict[str, Any], method_name: str, ) -> dict: """Create analysis results dictionary for MCMC/CMC. Parameters ---------- result : MCMCResult MCMC result with diagnostics data : dict Experimental data dictionary method_name : str "mcmc" or "cmc" Returns ------- dict Analysis summary dictionary """ # Get dataset dimensions c2_exp = data.get("c2_exp", []) n_angles = len(data.get("phi_angles_list", [])) n_time_points = ( c2_exp.shape[1] * c2_exp.shape[2] if hasattr(c2_exp, "shape") and len(c2_exp.shape) >= 3 else 0 ) total_data_points = c2_exp.size if hasattr(c2_exp, "size") else 0 # Determine sampling quality quality_flag = "unknown" warnings = [] recommendations = [] if hasattr(result, "r_hat") and result.r_hat is not None: if isinstance(result.r_hat, dict): r_hat_values = [ v for v in result.r_hat.values() if v is not None and np.isfinite(v) ] max_r_hat = max(r_hat_values) if r_hat_values else None else: r_hat = np.asarray(result.r_hat) finite_rhat_analysis = r_hat[np.isfinite(r_hat)] max_r_hat = ( float(np.max(finite_rhat_analysis)) if finite_rhat_analysis.size > 0 else None ) if max_r_hat is not None: if max_r_hat < 1.05: quality_flag = "good" elif max_r_hat < 1.1: quality_flag = "acceptable" warnings.append( f"Some parameters have R-hat between 1.05-1.1 (max={max_r_hat:.3f})" ) else: quality_flag = "poor" warnings.append( f"Convergence issues detected (max R-hat={max_r_hat:.3f})" ) recommendations.append("Consider increasing n_warmup or n_samples") # Resolve ESS source: ess_bulk (CMCResult) or effective_sample_size (legacy) ess_source_analysis = None if hasattr(result, "ess_bulk") and result.ess_bulk is not None: ess_source_analysis = result.ess_bulk elif ( hasattr(result, "effective_sample_size") and result.effective_sample_size is not None ): ess_source_analysis = result.effective_sample_size if ess_source_analysis is not None: if isinstance(ess_source_analysis, dict): ess_values = [ v for v in ess_source_analysis.values() if v is not None and np.isfinite(v) ] min_ess = min(ess_values) if ess_values else None else: ess = np.asarray(ess_source_analysis) finite_ess = ess[np.isfinite(ess)] min_ess = float(np.min(finite_ess)) if finite_ess.size > 0 else None if min_ess is not None and min_ess < 400: warnings.append(f"Low effective sample size (min ESS={min_ess:.0f})") recommendations.append( "Consider increasing n_samples for better posterior estimates" ) analysis_dict = { "method": method_name, "timestamp": datetime.now().isoformat(), "analysis_mode": getattr(result, "analysis_mode", "unknown"), "sampling_quality": { "convergence_status": ( "converged" if quality_flag in ["good", "acceptable"] else "not_converged" ), "quality_flag": quality_flag, "warnings": warnings, "recommendations": recommendations, }, "dataset_info": { "n_angles": n_angles, "n_time_points": n_time_points, "total_data_points": total_data_points, "q_value": ( _json_safe(float(data["wavevector_q_list"][0])) if data.get("wavevector_q_list") else 0.0 ), }, "sampling_summary": { "n_samples": getattr(result, "n_samples", 0), "n_warmup": getattr(result, "n_warmup", 0), "n_chains": getattr(result, "n_chains", 1), "execution_time": _json_safe( float( getattr( result, "execution_time", getattr(result, "computation_time", 0.0), ) ) ), }, } # v2.1.0: Add config-driven metadata if available if ( hasattr(result, "parameter_space_metadata") and result.parameter_space_metadata is not None ): analysis_dict["parameter_space"] = result.parameter_space_metadata if ( hasattr(result, "initial_values_metadata") and result.initial_values_metadata is not None ): analysis_dict["initial_values"] = result.initial_values_metadata if ( hasattr(result, "selection_decision_metadata") and result.selection_decision_metadata is not None ): analysis_dict["selection_decision"] = result.selection_decision_metadata return analysis_dict
[docs] def create_mcmc_diagnostics_dict(result: Any) -> dict: """Create diagnostics dictionary for MCMC/CMC. Parameters ---------- result : MCMCResult MCMC result with convergence diagnostics Returns ------- dict Diagnostics dictionary with convergence metrics """ diagnostics_dict: dict[str, Any] = { "convergence": {}, "sampling_efficiency": {}, "posterior_checks": {}, } diag_summary = getattr(result, "diagnostic_summary", {}) or {} deterministic_params = set(diag_summary.get("deterministic_params") or []) per_param_stats = diag_summary.get("per_param_stats") or {} # Convergence diagnostics if hasattr(result, "r_hat") and result.r_hat is not None: if isinstance(result.r_hat, dict): r_hat_values = [ v for v in result.r_hat.values() if v is not None and np.isfinite(v) ] if r_hat_values: diagnostics_dict["convergence"]["all_chains_converged"] = bool( all(v < 1.1 for v in r_hat_values) ) diagnostics_dict["convergence"]["r_hat_threshold"] = 1.1 # Add per-parameter diagnostics # Resolve ESS source: ess_bulk (CMCResult) or effective_sample_size (legacy) ess_dict: dict | None = None if hasattr(result, "ess_bulk") and isinstance(result.ess_bulk, dict): ess_dict = result.ess_bulk elif hasattr(result, "effective_sample_size") and isinstance( result.effective_sample_size, dict ): ess_dict = result.effective_sample_size per_param = [] for param_name, r_hat_val in result.r_hat.items(): ess_val = ess_dict.get(param_name) if ess_dict else None per_param.append( { "name": param_name, "r_hat": _json_safe(float(r_hat_val)) if r_hat_val is not None else None, "ess": _json_safe(float(ess_val)) if ess_val is not None else None, "converged": bool(r_hat_val is not None and r_hat_val < 1.1), "deterministic": param_name in deterministic_params, } ) if per_param: diagnostics_dict["convergence"]["per_parameter_diagnostics"] = per_param else: r_hat = np.asarray(result.r_hat) finite_rhat = r_hat[np.isfinite(r_hat)] diagnostics_dict["convergence"]["all_chains_converged"] = bool( np.all(finite_rhat < 1.1) if finite_rhat.size > 0 else False ) diagnostics_dict["convergence"]["r_hat_threshold"] = 1.1 analysis_mode = getattr(result, "analysis_mode", "static") param_names = _get_parameter_names(analysis_mode) per_param = [] ess_array = ( np.asarray(result.effective_sample_size) if ( hasattr(result, "effective_sample_size") and result.effective_sample_size is not None and not isinstance(result.effective_sample_size, dict) ) else None ) for i, name in enumerate(param_names): if i < len(r_hat): ess_val = ( ess_array[i] if (ess_array is not None and i < len(ess_array)) else 0.0 ) per_param.append( { "name": name, "r_hat": _json_safe(float(r_hat[i])), "ess": _json_safe(float(ess_val)), "converged": bool(r_hat[i] < 1.1), "deterministic": name in deterministic_params, } ) diagnostics_dict["convergence"]["per_parameter_diagnostics"] = per_param has_ess = (hasattr(result, "ess_bulk") and result.ess_bulk is not None) or ( hasattr(result, "effective_sample_size") and result.effective_sample_size is not None ) if has_ess: diagnostics_dict["convergence"]["ess_threshold"] = 400 # Sampling efficiency if hasattr(result, "acceptance_rate") and result.acceptance_rate is not None: diagnostics_dict["sampling_efficiency"]["acceptance_rate"] = _json_safe( float(result.acceptance_rate) ) diagnostics_dict["sampling_efficiency"]["target_acceptance"] = 0.80 if hasattr(result, "divergences"): import math _div = result.divergences # int() raises ValueError on NaN; guard before converting. diagnostics_dict["sampling_efficiency"]["divergences"] = ( int(_div) if math.isfinite(float(_div)) else 0 ) if hasattr(result, "tree_depth_warnings"): import math _tdw = result.tree_depth_warnings diagnostics_dict["sampling_efficiency"]["tree_depth_warnings"] = ( int(_tdw) if (isinstance(_tdw, (int, float)) and math.isfinite(float(_tdw))) else 0 ) # Posterior checks if hasattr(result, "ess") and hasattr(result, "n_samples"): ess = np.asarray(result.ess) total_samples = result.n_samples * getattr(result, "n_chains", 1) if total_samples > 0: ess_ratio = _json_safe(float(np.nanmean(ess) / total_samples)) diagnostics_dict["posterior_checks"]["effective_sample_size_ratio"] = ( ess_ratio ) # Fallback per-parameter diagnostics if "per_parameter_diagnostics" not in diagnostics_dict["convergence"]: param_keys = set(per_param_stats.keys()) if ( hasattr(result, "r_hat") and result.r_hat is not None and isinstance(result.r_hat, dict) ): param_keys.update(result.r_hat.keys()) if ( hasattr(result, "effective_sample_size") and result.effective_sample_size is not None and isinstance(result.effective_sample_size, dict) ): param_keys.update(result.effective_sample_size.keys()) if param_keys: fallback_entries = [] for name in sorted(param_keys): stats = per_param_stats.get(name, {}) r_hat_val = None if hasattr(result, "r_hat") and isinstance(result.r_hat, dict): r_hat_val = result.r_hat.get(name) elif "r_hat" in stats: r_hat_val = stats.get("r_hat") ess_val = None if hasattr(result, "effective_sample_size") and isinstance( result.effective_sample_size, dict ): ess_val = result.effective_sample_size.get(name) elif "ess" in stats: ess_val = stats.get("ess") fallback_entries.append( { "name": name, "r_hat": _json_safe(float(r_hat_val)) if r_hat_val is not None else None, "ess": _json_safe(float(ess_val)) if ess_val is not None else None, "converged": bool( r_hat_val is not None and r_hat_val < diagnostics_dict["convergence"].get( "r_hat_threshold", 1.1 ) ), "deterministic": name in deterministic_params or stats.get("deterministic", False), } ) diagnostics_dict["convergence"]["per_parameter_diagnostics"] = ( fallback_entries ) # CMC-specific diagnostics if hasattr(result, "is_cmc_result") and result.is_cmc_result(): diagnostics_dict["cmc_specific"] = {} if hasattr(result, "per_shard_diagnostics") and result.per_shard_diagnostics: per_shard = result.per_shard_diagnostics acceptance_rates = [] converged_shards = 0 for shard in per_shard: if isinstance(shard, dict): if shard.get("acceptance_rate") is not None: acceptance_rates.append(float(shard["acceptance_rate"])) if shard.get("converged", False): converged_shards += 1 shard_summary: dict[str, Any] = { "num_shards": len(per_shard), "shards_converged": converged_shards, "convergence_rate": ( float(converged_shards / len(per_shard)) if len(per_shard) > 0 else 0.0 ), } if acceptance_rates: shard_summary["acceptance_rate_stats"] = { "mean": _json_safe(float(np.nanmean(acceptance_rates))), "min": _json_safe(float(np.nanmin(acceptance_rates))), "max": _json_safe(float(np.nanmax(acceptance_rates))), "std": _json_safe(float(np.nanstd(acceptance_rates))), } diagnostics_dict["cmc_specific"]["shard_summary"] = shard_summary if hasattr(result, "cmc_diagnostics") and result.cmc_diagnostics: cmc_diag = result.cmc_diagnostics overall_metrics: dict[str, Any] = {} if isinstance(cmc_diag, dict): if "combination_success" in cmc_diag: overall_metrics["combination_success"] = bool( cmc_diag["combination_success"] ) if "n_shards_converged" in cmc_diag: overall_metrics["n_shards_converged"] = int( cmc_diag["n_shards_converged"] ) if "n_shards_total" in cmc_diag: overall_metrics["n_shards_total"] = int(cmc_diag["n_shards_total"]) if "weighted_product_std" in cmc_diag: overall_metrics["weighted_product_std"] = _json_safe( float(cmc_diag["weighted_product_std"]) ) if "combination_time" in cmc_diag: overall_metrics["combination_time"] = _json_safe( float(cmc_diag["combination_time"]) ) if "success_rate" in cmc_diag: overall_metrics["success_rate"] = _json_safe( float(cmc_diag["success_rate"]) ) diagnostics_dict["cmc_specific"]["overall_diagnostics"] = ( overall_metrics ) if hasattr(result, "combination_method") and result.combination_method: diagnostics_dict["cmc_specific"]["combination_method"] = str( result.combination_method ) if hasattr(result, "num_shards") and result.num_shards: diagnostics_dict["cmc_specific"]["num_shards"] = int(result.num_shards) return diagnostics_dict