Source code for homodyne.viz.mcmc_diagnostics

"""MCMC Convergence Diagnostics Visualization.

Provides trace plots, KL divergence matrix heatmaps, convergence
diagnostics (R-hat, ESS), and the energy-based BFMI statistic for NUTS and
CMC results.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

from homodyne.utils.logging import get_logger
from homodyne.utils.path_validation import PathValidationError, validate_plot_save_path
from homodyne.viz.mcmc_arviz import _create_empty_figure

# Module-level logger; used to report missing sample data and failed figure
# saves instead of raising from the plotting functions.
logger = get_logger(__name__)


def _plot_shard_traces(
    axes: Any,
    shard_diagnostics: Any,
    param_names: list[str],
    num_params_to_plot: int,
) -> None:
    """Overlay per-shard traces read from each shard's ``trace_data`` mapping.

    Each shard dict may hold ``shard_diag["trace_data"]["param_<idx>"]``; a
    2D trace is treated as (num_chains, num_samples) and drawn one line per
    chain, with a single legend entry per shard. Shards without trace data for
    a parameter are skipped.
    """
    num_shards = len(shard_diagnostics)
    colors = matplotlib.colormaps["tab10"](np.linspace(0, 1, num_shards))

    for param_idx in range(num_params_to_plot):
        ax = axes[param_idx]
        trace_key = f"param_{param_idx}"
        name = param_names[param_idx]

        for shard_idx, shard_diag in enumerate(shard_diagnostics):
            trace_data = shard_diag["trace_data"] if "trace_data" in shard_diag else None
            if trace_data is None or trace_key not in trace_data:
                continue
            trace = np.asarray(trace_data[trace_key])
            color = colors[shard_idx]
            if trace.ndim == 2:
                # Multi-chain: one line per chain, legend label only on chain 0
                # (matplotlib skips empty labels when building the legend).
                for chain_idx, chain_trace in enumerate(trace):
                    ax.plot(
                        chain_trace,
                        color=color,
                        alpha=0.7,
                        linewidth=0.5,
                        label=f"Shard {shard_idx}" if chain_idx == 0 else "",
                    )
            else:
                ax.plot(
                    trace,
                    color=color,
                    alpha=0.7,
                    linewidth=0.8,
                    label=f"Shard {shard_idx}",
                )

        ax.set_xlabel("Sample Index")
        ax.set_ylabel(name)
        ax.set_title(f"{name} Trace (CMC)")
        ax.grid(True, alpha=0.3)
        # Legend on the first subplot only (to avoid clutter), and only when
        # something labelled was drawn -- an empty legend() emits a warning.
        if (
            param_idx == 0
            and num_shards <= 10
            and ax.get_legend_handles_labels()[0]
        ):
            ax.legend(loc="upper right", fontsize=8, ncol=2)


def _plot_pooled_traces(
    axes: Any,
    samples: np.ndarray,
    param_names: list[str],
    num_params_to_plot: int,
) -> None:
    """Plot one trace per parameter, concatenating chains end to end."""
    if samples.ndim <= 1:
        # Single parameter, single chain.
        samples = samples.reshape(-1, 1)
    elif samples.ndim >= 3:
        # (num_chains, num_samples, num_params) -> (num_chains * num_samples,
        # num_params); any further leading axes are flattened the same way.
        samples = samples.reshape(-1, samples.shape[-1])

    for param_idx in range(num_params_to_plot):
        ax = axes[param_idx]
        name = param_names[param_idx]
        ax.plot(samples[:, param_idx], linewidth=0.5, alpha=0.8, color="steelblue")
        ax.set_xlabel("Sample Index")
        ax.set_ylabel(name)
        ax.set_title(f"{name} Trace")
        ax.grid(True, alpha=0.3)


def plot_trace_plots(
    result: Any,  # MCMCResult type
    param_names: list[str] | None = None,
    max_params: int = 9,
    figsize: tuple[float, float] | None = None,
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
) -> Figure:
    """Plot MCMC trace plots for convergence visualization.

    Creates a grid of trace plots showing parameter evolution across MCMC
    samples. For CMC results that carry per-shard diagnostics, overlays traces
    from multiple shards with different colors.

    Parameters
    ----------
    result : MCMCResult
        MCMC or CMC result object containing samples (``samples_params``).
    param_names : list of str, optional
        Parameter names to plot. If None, names are chosen from
        ``result.analysis_mode`` ("static" / "laminar_flow"), falling back to
        ``param_0, param_1, ...``. The caller's list is never modified; when
        it is too short it is padded with ``param_<i>`` on an internal copy.
    max_params : int, default=9
        Maximum number of parameters to plot (to avoid cluttered figures).
    figsize : tuple, optional
        Figure size (width, height). If None, auto-calculated from the number
        of plotted parameters.
    show : bool, default=False
        If True, display the figure interactively.
    save_path : str or Path, optional
        If provided, save figure to this path (after path validation; a
        rejected path is logged, not raised).
    dpi : int, default=150
        DPI for saved figure.

    Returns
    -------
    Figure
        Matplotlib figure object. When ``show=False`` and ``save_path`` is
        None the caller owns the figure and must close it.

    Examples
    --------
    >>> # Standard NUTS result
    >>> plot_trace_plots(nuts_result, param_names=['D0', 'alpha', 'D_offset'])

    >>> # CMC result with multiple shards
    >>> plot_trace_plots(cmc_result, save_path='traces_cmc.png')

    Notes
    -----
    - For NUTS: Single trace line per parameter (chains concatenated)
    - For CMC: Multiple colored lines (one per shard)
    - X-axis: Sample index (after warmup)
    - Y-axis: Parameter value
    - Good mixing: trace should look like "hairy caterpillar"
    - Poor mixing: trace shows trends or gets stuck
    """
    if getattr(result, "samples_params", None) is None:
        logger.warning("No parameter samples available for trace plots")
        fig = _create_empty_figure("No samples available")
        if not show:
            plt.close(fig)
        return fig

    samples = np.asarray(result.samples_params)
    num_params = samples.shape[-1] if samples.ndim >= 2 else 1
    num_params_to_plot = min(num_params, max_params)

    # A zero-column grid would divide by zero in the layout computation below.
    if num_params_to_plot <= 0:
        logger.warning("No parameters to plot in trace plots")
        fig = _create_empty_figure("No parameters to plot")
        if not show:
            plt.close(fig)
        return fig

    if param_names is None:
        analysis_mode = getattr(result, "analysis_mode", None)
        if analysis_mode == "static":
            param_names = ["D0", "alpha", "D_offset"]
        elif analysis_mode == "laminar_flow":
            param_names = [
                "D0",
                "alpha",
                "D_offset",
                "gamma_dot_t0",
                "beta",
                "gamma_dot_t_offset",
                "phi0",
            ]
        else:
            param_names = [f"param_{i}" for i in range(num_params)]

    # Work on a copy so padding never mutates a list owned by the caller.
    param_names = list(param_names)
    if len(param_names) < num_params_to_plot:
        param_names.extend(
            f"param_{i}" for i in range(len(param_names), num_params_to_plot)
        )

    # Grid layout: at most three columns.
    ncols = min(3, num_params_to_plot)
    nrows = (num_params_to_plot + ncols - 1) // ncols
    if figsize is None:
        figsize = (5 * ncols, 3 * nrows)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    is_cmc = result.is_cmc_result() if hasattr(result, "is_cmc_result") else False
    shard_diagnostics = (
        getattr(result, "per_shard_diagnostics", None) if is_cmc else None
    )

    # An empty shard list has nothing to overlay; fall back to pooled samples.
    if shard_diagnostics is not None and len(shard_diagnostics) > 0:
        _plot_shard_traces(axes, shard_diagnostics, param_names, num_params_to_plot)
    else:
        _plot_pooled_traces(axes, samples, param_names, num_params_to_plot)

    for ax in axes[num_params_to_plot:]:
        ax.set_visible(False)

    title = "MCMC Trace Plots"
    if is_cmc:
        title += f" (CMC: {result.num_shards} shards)"
    fig.suptitle(title, fontsize=14, fontweight="bold", y=0.995)
    fig.tight_layout()

    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"Trace plots saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save trace plots: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        plt.close(fig)
    # Note: when show=False and save_path=None, caller owns the figure and must close it.

    return fig
def plot_kl_divergence_matrix(
    result: Any,  # MCMCResult type
    figsize: tuple[float, float] = (8, 7),
    cmap: str = "coolwarm",
    threshold: float = 2.0,
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
) -> Figure:
    """Plot KL divergence matrix heatmap for CMC results.

    Visualizes pairwise KL divergence between shards to assess posterior
    agreement. Cells above ``threshold`` (default 2.0) are outlined in red as
    a sign that shards found different posteriors.

    Parameters
    ----------
    result : MCMCResult
        CMC result object whose ``cmc_diagnostics["kl_matrix"]`` holds the
        square (num_shards, num_shards) KL matrix.
    figsize : tuple, default=(8, 7)
        Figure size (width, height).
    cmap : str, default='coolwarm'
        Matplotlib colormap name (coolwarm shows cool=low, warm=high KL).
    threshold : float, default=2.0
        KL divergence threshold to highlight (standard: 2.0).
    show : bool, default=False
        If True, display the figure interactively.
    save_path : str or Path, optional
        If provided, save figure to this path (after path validation; a
        rejected path is logged, not raised).
    dpi : int, default=150
        DPI for saved figure.

    Returns
    -------
    Figure
        Matplotlib figure object.

    Raises
    ------
    ValueError
        If result is not a CMC result, the KL matrix is not available, or the
        KL matrix is not a non-empty square 2D array.

    Examples
    --------
    >>> plot_kl_divergence_matrix(cmc_result, threshold=2.0, save_path='kl_matrix.png')

    Notes
    -----
    - Diagonal elements are expected to be 0.0 (self-divergence)
    - The matrix is expected to be symmetric (averaged KL)
    - Values < 0.5: Excellent agreement
    - Values 0.5-2.0: Acceptable agreement
    - Values > 2.0: Poor agreement (possible multimodality)
    - Non-finite cells (inf) are drawn at the top of the color scale; the
      scale itself is derived from finite values only.
    """
    is_cmc = result.is_cmc_result() if hasattr(result, "is_cmc_result") else False
    if not is_cmc:
        raise ValueError("KL divergence matrix is only available for CMC results")

    diagnostics = getattr(result, "cmc_diagnostics", None)
    if diagnostics is None or "kl_matrix" not in diagnostics:
        raise ValueError("KL divergence matrix not found in CMC diagnostics")

    # float dtype turns missing (None) entries into NaN instead of an object array.
    kl_matrix = np.asarray(diagnostics["kl_matrix"], dtype=float)
    if (
        kl_matrix.ndim != 2
        or kl_matrix.shape[0] == 0
        or kl_matrix.shape[0] != kl_matrix.shape[1]
    ):
        raise ValueError(
            "KL divergence matrix must be a non-empty square 2D array, "
            f"got shape {kl_matrix.shape}"
        )
    num_shards = kl_matrix.shape[0]

    fig, ax = plt.subplots(figsize=figsize)

    # Scale the colormap from finite values only: a single +inf cell would
    # otherwise make vmax infinite and break the color normalization.
    finite_values = kl_matrix[np.isfinite(kl_matrix)]
    kl_max = float(finite_values.max()) if finite_values.size else 0.0
    vmax = max(threshold * 1.5, kl_max)

    im = ax.imshow(
        kl_matrix,
        cmap=cmap,
        aspect="auto",
        vmin=0,
        vmax=vmax,
    )

    cbar = fig.colorbar(im, ax=ax, label="KL Divergence")
    # Mark the threshold on the colorbar (its y-axis is in KL units).
    cbar.ax.axhline(
        y=threshold,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Threshold ({threshold})",
    )

    # For large matrices (>20 shards) per-cell text would be illegible and the
    # O(n^2) ax.text() calls slow rendering; keep only the red-border
    # highlights for problematic off-diagonal cells, readable at any scale.
    max_annotated_shards = 20
    annotate = num_shards <= max_annotated_shards

    for i in range(num_shards):
        for j in range(num_shards):
            kl_val = kl_matrix[i, j]
            if annotate:
                # White text on the warm (high-KL) end of the colormap.
                text_color = "white" if kl_val > threshold else "black"
                ax.text(
                    j,
                    i,
                    f"{kl_val:.2f}",
                    ha="center",
                    va="center",
                    color=text_color,
                    fontsize=9,
                )
            # Highlight problematic off-diagonal shard pairs (KL > threshold).
            if kl_val > threshold and i != j:
                ax.add_patch(
                    plt.Rectangle(
                        (j - 0.5, i - 0.5),
                        1,
                        1,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                    )
                )

    shard_labels = [f"S{i}" for i in range(num_shards)]
    ax.set_xticks(np.arange(num_shards))
    ax.set_yticks(np.arange(num_shards))
    ax.set_xticklabels(shard_labels)
    ax.set_yticklabels(shard_labels)
    ax.set_xlabel("Shard Index")
    ax.set_ylabel("Shard Index")
    ax.set_title(
        f"Between-Shard KL Divergence Matrix\n({num_shards} shards, threshold={threshold})",
        fontweight="bold",
    )

    # Cell grid via minor ticks at the cell boundaries.
    ax.set_xticks(np.arange(num_shards) - 0.5, minor=True)
    ax.set_yticks(np.arange(num_shards) - 0.5, minor=True)
    ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5)

    fig.tight_layout()

    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"KL divergence matrix saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save KL divergence matrix: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        plt.close(fig)

    return fig
def plot_convergence_diagnostics(
    result: Any,  # MCMCResult type
    metrics: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    rhat_threshold: float = 1.1,
    ess_threshold: float = 400.0,
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
) -> Figure:
    """Plot convergence diagnostics (R-hat and ESS) for MCMC results.

    Visualizes convergence metrics to assess MCMC sampling quality. For CMC
    results with per-shard diagnostics, shows per-shard values.

    Parameters
    ----------
    result : MCMCResult
        MCMC or CMC result object with convergence diagnostics.
    metrics : list of str, optional
        Metrics to plot, one subplot each. Options: 'rhat', 'ess'. Defaults to
        ['rhat', 'ess']. Unknown names get a placeholder subplot.
    figsize : tuple, optional
        Figure size (width, height). If None, auto-calculated.
    rhat_threshold : float, default=1.1
        R-hat threshold for convergence (standard: 1.1).
    ess_threshold : float, default=400.0
        ESS threshold for adequate sampling (standard: 400).
    show : bool, default=False
        If True, display the figure interactively.
    save_path : str or Path, optional
        If provided, save figure to this path (after path validation; a
        rejected path is logged, not raised).
    dpi : int, default=150
        DPI for saved figure.

    Returns
    -------
    Figure
        Matplotlib figure object.

    Raises
    ------
    ValueError
        If ``metrics`` is an empty list.

    Examples
    --------
    >>> plot_convergence_diagnostics(result, metrics=['rhat', 'ess'])

    >>> # CMC result with per-shard diagnostics
    >>> plot_convergence_diagnostics(cmc_result, save_path='convergence.png')

    Notes
    -----
    - R-hat < rhat_threshold (default 1.1): Converged (green bars)
    - R-hat >= rhat_threshold: Not converged (red bars)
    - ESS > ess_threshold (default 400): Adequate sampling (green bars)
    - ESS <= ess_threshold: Poor sampling efficiency (orange bars)
    """
    if metrics is None:
        metrics = ["rhat", "ess"]
    if not metrics:
        # plt.subplots(0, 1) would fail with a less helpful message.
        raise ValueError("metrics must contain at least one of 'rhat', 'ess'")

    is_cmc = result.is_cmc_result() if hasattr(result, "is_cmc_result") else False

    num_metrics = len(metrics)
    if figsize is None:
        figsize = (10, 4 * num_metrics)

    fig, axes = plt.subplots(num_metrics, 1, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Guard against None mean_params from failed CMC runs.
    mean_params = getattr(result, "mean_params", None)
    num_params = len(mean_params) if mean_params is not None else 0

    analysis_mode = getattr(result, "analysis_mode", None)
    if analysis_mode == "static":
        param_names = ["D0", "alpha", "D_offset"][:num_params]
    elif analysis_mode == "laminar_flow":
        param_names = [
            "D0",
            "alpha",
            "D_offset",
            "gamma_dot_t0",
            "beta",
            "gamma_dot_t_offset",
            "phi0",
        ][:num_params]
    else:
        param_names = [f"param_{i}" for i in range(num_params)]
    # Pad with generic names so parameters beyond the mode's known list are
    # still shown instead of being silently dropped.
    param_names += [f"param_{i}" for i in range(len(param_names), num_params)]

    for ax, metric in zip(axes, metrics):
        if metric == "rhat":
            _plot_rhat(ax, result, param_names, rhat_threshold, is_cmc)
        elif metric == "ess":
            _plot_ess(ax, result, param_names, ess_threshold, is_cmc)
        else:
            logger.warning(f"Unknown metric: {metric}")
            ax.text(
                0.5,
                0.5,
                f"Unknown metric: {metric}",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )

    title = "MCMC Convergence Diagnostics"
    if is_cmc:
        title += f" (CMC: {result.num_shards} shards)"
    fig.suptitle(title, fontsize=14, fontweight="bold", y=0.995)
    fig.tight_layout()

    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"Convergence diagnostics saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save convergence diagnostics: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        plt.close(fig)

    return fig
# Grouped per-shard bar charts show at most this many shards per parameter;
# more bars per tick become unreadable.
_MAX_SHARDS_IN_BARS = 10


def _lookup_metric(values: Any, param_name: str) -> Any:
    """Return the metric for *param_name*, trying several key spellings.

    Tries the name as given, lower-cased, and with underscores removed;
    returns NaN when none of them is present in *values*.
    """
    for key in (param_name, param_name.lower(), param_name.replace("_", "")):
        if key in values:
            return values[key]
    return np.nan


def _collect_shard_metric(
    shard_diagnostics: Any, metric_key: str, num_params: int
) -> np.ndarray:
    """Build a (num_shards, num_params) matrix of one per-shard metric.

    Reads ``shard_diag[metric_key]["param_<idx>"]`` for every shard; shards
    without the metric and parameters without a value stay NaN.
    """
    matrix = np.full((len(shard_diagnostics), num_params), np.nan)
    for shard_idx, shard_diag in enumerate(shard_diagnostics):
        if metric_key not in shard_diag or shard_diag[metric_key] is None:
            continue
        shard_values = shard_diag[metric_key]
        for param_idx in range(num_params):
            param_key = f"param_{param_idx}"
            if param_key in shard_values:
                matrix[shard_idx, param_idx] = shard_values[param_key]
    return matrix


def _draw_shard_bars(
    ax: Any, matrix: np.ndarray, is_good: Any, bad_color: str
) -> np.ndarray:
    """Draw grouped bars (one group per parameter, one bar per shard).

    Only the first ``_MAX_SHARDS_IN_BARS`` shards are drawn. Bars for which
    ``is_good(value)`` is true are green, the rest use *bad_color*. Requires
    at least one shard row. Returns the tick positions of the groups.
    """
    x = np.arange(matrix.shape[1])
    num_drawn = min(matrix.shape[0], _MAX_SHARDS_IN_BARS)
    width = 0.8 / num_drawn
    for shard_idx in range(num_drawn):
        # Bars are centre-aligned, so offsets from -(k-1)/2 to +(k-1)/2 widths
        # centre each group of k bars on its tick. (Using the total shard count
        # here pushed groups onto neighbouring ticks when >10 shards existed.)
        offset = (shard_idx - (num_drawn - 1) / 2) * width
        values = matrix[shard_idx]
        colors = ["green" if is_good(v) else bad_color for v in values]
        ax.bar(
            x + offset,
            values,
            width,
            label=f"Shard {shard_idx}",
            color=colors,
            alpha=0.7,
        )
    return x


def _decorate_metric_axes(
    ax: Any,
    x: np.ndarray,
    param_names: list[str],
    threshold: float,
    ylabel: str,
    title: str,
    *,
    per_shard: bool,
) -> None:
    """Add the threshold line, labels, ticks, legend and grid to a metric plot."""
    ax.axhline(
        y=threshold,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Threshold ({threshold})",
    )
    ax.set_xlabel("Parameter")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    if per_shard:
        ax.legend(loc="upper right", fontsize=8, ncol=2)
    else:
        ax.legend()
    ax.grid(True, alpha=0.3, axis="y")


def _shard_diagnostics_or_none(result: Any, is_cmc: bool) -> Any:
    """Return the non-empty per-shard diagnostics list of a CMC result, else None."""
    if not is_cmc:
        return None
    shard_diagnostics = getattr(result, "per_shard_diagnostics", None)
    if shard_diagnostics is None or len(shard_diagnostics) == 0:
        # An empty list has nothing to plot per shard (and would divide by zero).
        return None
    return shard_diagnostics


def _plot_rhat(
    ax: Any, result: Any, param_names: list[str], threshold: float, is_cmc: bool
) -> None:
    """Plot R-hat diagnostics on *ax*.

    CMC results with per-shard diagnostics get grouped per-shard bars (keys
    ``"param_<idx>"`` in each shard's ``"rhat"`` mapping); otherwise the
    combined ``result.r_hat`` mapping is looked up by parameter name. Bars
    below *threshold* are green, others red.
    """
    shard_diagnostics = _shard_diagnostics_or_none(result, is_cmc)
    if shard_diagnostics is not None:
        rhat_matrix = _collect_shard_metric(shard_diagnostics, "rhat", len(param_names))
        x = _draw_shard_bars(ax, rhat_matrix, lambda v: v < threshold, "red")
        _decorate_metric_axes(
            ax,
            x,
            param_names,
            threshold,
            "R-hat",
            "R-hat Convergence Diagnostic (per shard)",
            per_shard=True,
        )
        return

    r_hat = getattr(result, "r_hat", None)
    if r_hat is None:
        ax.text(
            0.5,
            0.5,
            "R-hat not available\n(requires multiple chains)",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return

    rhat_values = [_lookup_metric(r_hat, name) for name in param_names]
    x = np.arange(len(param_names))
    colors = ["green" if v < threshold else "red" for v in rhat_values]
    ax.bar(x, rhat_values, color=colors, alpha=0.7)
    _decorate_metric_axes(
        ax,
        x,
        param_names,
        threshold,
        "R-hat",
        "R-hat Convergence Diagnostic",
        per_shard=False,
    )


def _plot_ess(
    ax: Any, result: Any, param_names: list[str], threshold: float, is_cmc: bool
) -> None:
    """Plot effective-sample-size diagnostics on *ax*.

    CMC results with per-shard diagnostics get grouped per-shard bars (keys
    ``"param_<idx>"`` in each shard's ``"ess"`` mapping); otherwise the
    combined ``result.effective_sample_size`` mapping is looked up by
    parameter name. Bars above *threshold* are green, others orange.
    """
    shard_diagnostics = _shard_diagnostics_or_none(result, is_cmc)
    if shard_diagnostics is not None:
        ess_matrix = _collect_shard_metric(shard_diagnostics, "ess", len(param_names))
        x = _draw_shard_bars(ax, ess_matrix, lambda v: v > threshold, "orange")
        _decorate_metric_axes(
            ax,
            x,
            param_names,
            threshold,
            "Effective Sample Size (ESS)",
            "ESS Diagnostic (per shard)",
            per_shard=True,
        )
        return

    ess = getattr(result, "effective_sample_size", None)
    if ess is None:
        ax.text(
            0.5,
            0.5,
            "ESS not available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return

    ess_values = [_lookup_metric(ess, name) for name in param_names]
    x = np.arange(len(param_names))
    colors = ["green" if v > threshold else "orange" for v in ess_values]
    ax.bar(x, ess_values, color=colors, alpha=0.7)
    _decorate_metric_axes(
        ax,
        x,
        param_names,
        threshold,
        "Effective Sample Size (ESS)",
        "ESS Diagnostic",
        per_shard=False,
    )
[docs] def compute_bfmi( energy: np.ndarray, *, per_chain: bool = False, ) -> float | np.ndarray: """Compute Bayesian Fraction of Missing Information (BFMI). BFMI measures how well the kinetic energy in HMC/NUTS matches the marginal energy distribution. Low BFMI (< 0.3) indicates the sampler is struggling with the posterior geometry (e.g., funnel shapes, poor mass matrix adaptation). Parameters ---------- energy : np.ndarray Potential energy trace. Shape: (n_samples,) for single chain, or (n_chains, n_samples) for multi-chain. per_chain : bool, default=False If True and energy is multi-chain, return BFMI per chain. If False, return pooled BFMI across all chains. Returns ------- float or np.ndarray BFMI value(s). Range [0, inf), typically [0, 2]. - < 0.3: Poor energy transitions (flag for review) - 0.3 - 1.0: Acceptable - > 1.0: Good energy mixing """ energy = np.asarray(energy, dtype=np.float64) if energy.ndim == 1: energy_diff = np.diff(energy) mean_sq_diff = np.nanmean(energy_diff**2) var_energy = np.nanvar(energy) if var_energy == 0.0: return float("nan") return float(mean_sq_diff / var_energy) if energy.ndim == 2: if per_chain: bfmi_per_chain = np.empty(energy.shape[0]) for i in range(energy.shape[0]): chain_energy = energy[i] energy_diff = np.diff(chain_energy) mean_sq_diff = np.nanmean(energy_diff**2) var_energy = np.nanvar(chain_energy) if var_energy == 0.0: bfmi_per_chain[i] = float("nan") else: bfmi_per_chain[i] = mean_sq_diff / var_energy return bfmi_per_chain else: # Pool all chains all_diffs = [] all_energies = [] for i in range(energy.shape[0]): all_diffs.append(np.diff(energy[i])) all_energies.append(energy[i]) pooled_diffs = np.concatenate(all_diffs) pooled_energies = np.concatenate(all_energies) mean_sq_diff = np.nanmean(pooled_diffs**2) var_energy = np.nanvar(pooled_energies) if var_energy == 0.0: return float("nan") return float(mean_sq_diff / var_energy) raise ValueError(f"Expected 1D or 2D energy array, got shape {energy.shape}")