# Source code for homodyne.viz.mcmc_dashboard

"""MCMC Summary Dashboard Visualization.

Provides a comprehensive multi-panel CMC summary dashboard combining
KL divergence, convergence diagnostics, trace plots, and posterior histograms.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

from homodyne.utils.logging import get_logger
from homodyne.utils.path_validation import PathValidationError, validate_plot_save_path

# Module-level logger for this visualization module, obtained from homodyne's
# logging utilities; used to report saved / failed dashboard output.
logger = get_logger(__name__)


# Exceptions a single panel may raise without aborting the whole dashboard.
_PANEL_ERRORS = (ValueError, TypeError, KeyError, IndexError)

# Canonical parameter order per analysis mode; other modes get generic names
# matching the ``param_{i}`` keys used in the per-shard diagnostics dicts.
_MODE_PARAM_NAMES: dict[str, tuple[str, ...]] = {
    "static": ("D0", "alpha", "D_offset"),
    "laminar_flow": (
        "D0",
        "alpha",
        "D_offset",
        "gamma_dot_t0",
        "beta",
        "gamma_dot_t_offset",
        "phi0",
    ),
}

_KL_THRESHOLD = 2.0  # cells above this are drawn with white annotation text
_KL_ANNOTATE_MAX_SHARDS = 20  # beyond this, per-cell text is illegible
_ESS_THRESHOLD = 400  # reference line drawn on the ESS panel


def plot_cmc_summary_dashboard(
    result: Any,  # MCMCResult type
    figsize: tuple[float, float] = (16, 12),
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
) -> Figure:
    """Create comprehensive multi-panel CMC summary dashboard.

    Combines the main CMC diagnostics into a single 3x2 figure:

    - Panel 1 (top left): pairwise shard KL divergence matrix
    - Panel 2 (top right): ESS distribution across shards, per parameter
    - Panel 3 (middle row): trace plots for up to the first two parameters
    - Panel 4 (bottom row): posterior histograms for up to the first two
      parameters

    Parameters
    ----------
    result : MCMCResult
        CMC result object. Attributes read: ``is_cmc_result()``,
        ``cmc_diagnostics``, ``per_shard_diagnostics``, ``analysis_mode``,
        ``mean_params``, ``samples_params`` and ``num_shards``.
    figsize : tuple, default=(16, 12)
        Figure size (width, height)
    show : bool, default=False
        If True, display the figure interactively
    save_path : str or Path, optional
        If provided, save figure to this path (validated with
        ``validate_plot_save_path``)
    dpi : int, default=150
        DPI for saved figure

    Returns
    -------
    Figure
        Matplotlib figure object. When ``save_path`` is given and ``show`` is
        False the figure is closed before being returned.

    Raises
    ------
    ValueError
        If result is not a CMC result

    Examples
    --------
    >>> plot_cmc_summary_dashboard(cmc_result, save_path='cmc_summary.png')

    Notes
    -----
    This is the primary diagnostic tool for CMC results. It provides a
    comprehensive overview of convergence, agreement between shards, and
    posterior quality in a single figure.

    Panels whose input data is missing show a placeholder message. A panel
    that raises ValueError/TypeError/KeyError/IndexError shows the error text
    in its axes and the failure is logged, so one malformed diagnostic never
    aborts the whole dashboard. Save failures are logged, not raised.
    """
    is_cmc = result.is_cmc_result() if hasattr(result, "is_cmc_result") else False
    if not is_cmc:
        raise ValueError("Summary dashboard is only available for CMC results")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

    n_params_total = len(result.mean_params) if result.mean_params is not None else 0
    # One naming scheme shared by every panel (previously each panel differed).
    param_names = _dashboard_param_names(result.analysis_mode, n_params_total)

    # Panel 1: KL divergence matrix (top left)
    ax_kl = fig.add_subplot(gs[0, 0])
    try:
        _plot_kl_panel(ax_kl, result)
    except _PANEL_ERRORS as exc:
        _report_panel_error(ax_kl, "Error plotting KL matrix:", exc)

    # Panel 2: convergence diagnostics (top right)
    ax_conv = fig.add_subplot(gs[0, 1])
    try:
        _plot_ess_panel(ax_conv, result, param_names)
    except _PANEL_ERRORS as exc:
        _report_panel_error(ax_conv, "Error plotting convergence:", exc)

    n_shown = min(2, n_params_total)

    # Panel 3: trace plots for up to 2 parameters (middle row)
    for i in range(n_shown):
        ax_trace = fig.add_subplot(gs[1, i])
        try:
            _plot_trace_panel(ax_trace, result, i, _param_label(param_names, i))
        except _PANEL_ERRORS as exc:
            _report_panel_error(ax_trace, "Error:", exc)

    # Panel 4: posterior histograms for up to 2 parameters (bottom row)
    for i in range(n_shown):
        ax_hist = fig.add_subplot(gs[2, i])
        try:
            _plot_posterior_panel(ax_hist, result, i, _param_label(param_names, i))
        except _PANEL_ERRORS as exc:
            _report_panel_error(ax_hist, "Error:", exc)

    fig.suptitle(
        f"CMC Summary Dashboard ({result.num_shards} shards, {result.analysis_mode})",
        fontsize=14,
        fontweight="bold",
    )

    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"CMC summary dashboard saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save CMC summary dashboard: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        # Free the figure when it was only rendered for saving.
        plt.close(fig)

    return fig


def _axes_message(ax: Any, message: str) -> None:
    """Write *message* centred in *ax*, using axes-fraction coordinates."""
    ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)


def _report_panel_error(ax: Any, prefix: str, exc: Exception) -> None:
    """Log a failed panel and render the error text inside its axes."""
    logger.warning(f"CMC summary dashboard: {prefix} {exc}")
    _axes_message(ax, f"{prefix}\n{exc}")


def _dashboard_param_names(analysis_mode: Any, n_params: int) -> list[str]:
    """Return display names: canonical for known modes, else ``param_{i}``."""
    canonical = _MODE_PARAM_NAMES.get(analysis_mode)
    if canonical is not None:
        return list(canonical)
    return [f"param_{i}" for i in range(n_params)]


def _param_label(names: list[str], index: int) -> str:
    """Return the display name for parameter *index*, with a generic fallback."""
    return names[index] if index < len(names) else f"param_{index}"


def _mpl_major_minor() -> tuple[int, int]:
    """Return matplotlib's (major, minor) version, tolerating rc/dev suffixes."""
    import re

    match = re.match(r"(\d+)\.(\d+)", matplotlib.__version__)
    return (int(match[1]), int(match[2])) if match else (0, 0)


def _plot_kl_panel(ax: Any, result: Any) -> None:
    """Draw the pairwise shard KL-divergence heatmap (panel 1)."""
    diagnostics = result.cmc_diagnostics
    if diagnostics is None or "kl_matrix" not in diagnostics:
        _axes_message(ax, "KL matrix not available")
        return

    # dtype=float maps missing (None) entries to NaN instead of an object array.
    kl_matrix = np.asarray(diagnostics["kl_matrix"], dtype=float)
    num_shards = kl_matrix.shape[0]

    # Scale the colour map on finite values only: np.nanmax would return inf
    # if any cell is infinite, which breaks imshow's normalisation.
    finite_vals = kl_matrix[np.isfinite(kl_matrix)]
    kl_max = float(finite_vals.max()) if finite_vals.size else 0.0

    im = ax.imshow(
        kl_matrix,
        cmap="coolwarm",
        aspect="auto",
        vmin=0,
        vmax=max(_KL_THRESHOLD * 1.5, kl_max),
    )
    # Attach to this axes' own figure rather than pyplot's "current" figure.
    ax.figure.colorbar(im, ax=ax, label="KL Divergence")

    # Per-cell text is O(n^2) and illegible for many shards: skip the loop.
    if num_shards <= _KL_ANNOTATE_MAX_SHARDS:
        for (row, col), kl_val in np.ndenumerate(kl_matrix):
            ax.text(
                col,
                row,
                f"{kl_val:.2f}",
                ha="center",
                va="center",
                color="white" if kl_val > _KL_THRESHOLD else "black",
                fontsize=7,
            )

    shard_labels = [f"S{i}" for i in range(num_shards)]
    ax.set_xticks(np.arange(num_shards))
    ax.set_yticks(np.arange(num_shards))
    ax.set_xticklabels(shard_labels, fontsize=8)
    ax.set_yticklabels(shard_labels, fontsize=8)
    ax.set_xlabel("Shard Index", fontsize=9)
    ax.set_ylabel("Shard Index", fontsize=9)
    ax.set_title("KL Divergence Matrix", fontsize=10, fontweight="bold")


def _plot_ess_panel(ax: Any, result: Any, param_names: list[str]) -> None:
    """Draw per-parameter ESS box plots across shards (panel 2)."""
    shard_diagnostics = result.per_shard_diagnostics
    num_params = len(param_names)

    ess_rows: list[list[Any]] = []
    if shard_diagnostics is not None and num_params > 0:
        for shard_diag in shard_diagnostics:
            if "ess" in shard_diag and shard_diag["ess"] is not None:
                ess = shard_diag["ess"]
                ess_rows.append(
                    [ess.get(f"param_{i}", np.nan) for i in range(num_params)]
                )

    if not ess_rows:
        _axes_message(ax, "Convergence diagnostics not available")
        return

    # dtype=float turns stored None values into NaN; non-finite values are then
    # dropped per parameter so one missing shard does not blank the whole box.
    ess_matrix = np.array(ess_rows, dtype=float)
    per_param = [column[np.isfinite(column)] for column in ess_matrix.T]

    bp_kwargs: dict[str, Any] = {
        "positions": np.arange(num_params),
        "patch_artist": True,
    }
    # Matplotlib 3.9 renamed boxplot's ``labels`` argument to ``tick_labels``.
    label_key = "tick_labels" if _mpl_major_minor() >= (3, 9) else "labels"
    bp_kwargs[label_key] = param_names

    bp = ax.boxplot(per_param, **bp_kwargs)
    for patch in bp["boxes"]:
        patch.set_facecolor("lightblue")

    ax.axhline(
        y=_ESS_THRESHOLD,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"ESS threshold ({_ESS_THRESHOLD})",
    )
    ax.set_ylabel("Effective Sample Size", fontsize=9)
    ax.set_title("ESS Distribution Across Shards", fontsize=10, fontweight="bold")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="y")


def _plot_trace_panel(ax: Any, result: Any, index: int, label: str) -> None:
    """Overlay every shard's trace of parameter *index* (panel 3)."""
    shard_diagnostics = result.per_shard_diagnostics
    if shard_diagnostics is None:
        _axes_message(ax, "Trace data not available")
        return

    colors = matplotlib.colormaps["tab10"](np.linspace(0, 1, len(shard_diagnostics)))
    trace_key = f"param_{index}"
    for color, shard_diag in zip(colors, shard_diagnostics):
        if "trace_data" not in shard_diag or trace_key not in shard_diag["trace_data"]:
            continue
        trace = np.array(shard_diag["trace_data"][trace_key])
        if trace.ndim == 2:
            trace = trace[0, :]  # Use first chain only
        ax.plot(trace, color=color, alpha=0.6, linewidth=0.5)

    ax.set_xlabel("Sample Index", fontsize=9)
    ax.set_ylabel(label, fontsize=9)
    ax.set_title(f"{label} Trace", fontsize=10, fontweight="bold")
    ax.grid(True, alpha=0.3)


def _plot_posterior_panel(ax: Any, result: Any, index: int, label: str) -> None:
    """Draw the pooled posterior histogram of parameter *index* (panel 4)."""
    if result.samples_params is None:
        _axes_message(ax, "Posterior samples not available")
        return

    ax.hist(
        result.samples_params[:, index],
        bins=30,
        alpha=0.7,
        color="steelblue",
        density=True,
    )

    mean_val = result.mean_params[index]
    ax.axvline(
        mean_val,
        color="red",
        linestyle="--",
        linewidth=2,
        # General format: ".2f" rendered small parameters (e.g. shear rates)
        # as 0.00.
        label=f"Mean: {mean_val:.4g}",
    )
    ax.set_xlabel(label, fontsize=9)
    ax.set_ylabel("Density", fontsize=9)
    ax.set_title(f"{label} Posterior", fontsize=10, fontweight="bold")
    ax.legend(fontsize=8)