Source code for homodyne.viz.mcmc_arviz

"""ArviZ-Based MCMC Plotting Functions.

Provides ArviZ-powered trace, posterior, and pair plots for MCMC/CMC results.
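When ArviZ is not installed, trace plots fall back to the custom implementation
in ``homodyne.viz.mcmc_diagnostics``, while posterior and pair plots return a
placeholder figure explaining that ArviZ is missing.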
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from homodyne.utils.logging import get_logger
from homodyne.utils.path_validation import PathValidationError, validate_plot_save_path

# Module-level logger obtained from the project's logging helper, keyed on this
# module's name; used for warnings about missing ArviZ/samples and save results.
logger = get_logger(__name__)


def _create_empty_figure(message: str) -> Figure:
    """Build an 8x6 placeholder figure that shows only ``message``.

    The text is centred in axes coordinates at font size 14, and both tick
    sets are cleared so the result reads as a notice rather than a chart.
    The figure is left open; callers decide whether to close it.
    """
    figure, axis = plt.subplots(figsize=(8, 6))
    axis.text(
        0.5,
        0.5,
        message,
        transform=axis.transAxes,
        ha="center",
        va="center",
        fontsize=14,
    )
    # Strip both axes' ticks so no meaningless scale is drawn.
    for clear_ticks in (axis.set_xticks, axis.set_yticks):
        clear_ticks([])
    return figure


def plot_arviz_trace(
    result: Any,
    var_names: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
    **kwargs: Any,
) -> Figure:
    """Plot MCMC trace plots using ArviZ.

    Creates trace plots showing parameter evolution and posterior
    distributions side by side for each parameter.

    If ArviZ cannot be imported, this delegates to
    ``homodyne.viz.mcmc_diagnostics.plot_trace_plots`` (forwarding only
    ``show``, ``save_path`` and ``dpi``) and returns whatever it returns.

    Parameters
    ----------
    result : MCMCResult
        MCMC or CMC result object containing samples. It is accessed through
        ``samples_params``, ``to_arviz()``, ``is_cmc_result()`` and, for CMC
        results, ``num_shards``.
    var_names : list of str, optional
        Parameter names to plot. If None, plots all parameters.
    figsize : tuple, optional
        Figure size (width, height). If None, auto-calculated.
    show : bool, default=False
        If True, display the figure interactively.
    save_path : str or Path, optional
        If provided, save figure to this path after validation with
        ``validate_plot_save_path``. Validation/save errors
        (``PathValidationError``/``ValueError``) are logged, not raised.
    dpi : int, default=150
        DPI for saved figure.
    **kwargs
        Additional arguments passed to az.plot_trace().

    Returns
    -------
    Figure
        Matplotlib figure object. When ``result.samples_params`` is None a
        placeholder figure reading "No samples available" is returned (and
        closed unless ``show`` is True).

    Examples
    --------
    >>> plot_arviz_trace(result, var_names=["D0", "alpha", "D_offset"])
    >>> plot_arviz_trace(result, save_path="traces.png")
    """
    try:
        import arviz as az
    except ImportError:
        logger.warning("ArviZ not available. Falling back to custom trace plots.")
        from homodyne.viz.mcmc_diagnostics import plot_trace_plots

        return plot_trace_plots(result, show=show, save_path=save_path, dpi=dpi)

    if result.samples_params is None:
        logger.warning("No parameter samples available for trace plots")
        fig = _create_empty_figure("No samples available")
        if not show:
            plt.close(fig)
        return fig

    # Convert to ArviZ InferenceData
    idata = result.to_arviz()

    # Plot traces; every returned axes belongs to the same figure.
    axes = az.plot_trace(idata, var_names=var_names, figsize=figsize, **kwargs)
    fig = axes.ravel()[0].figure

    # Add title
    title = "MCMC Trace Plots (ArviZ)"
    if result.is_cmc_result():
        title += f" - CMC ({result.num_shards} shards)"
    fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
    # Lay out the figure actually drawn on. plt.tight_layout() targets
    # pyplot's *current* figure, which is a different one when the caller
    # supplies pre-existing axes through **kwargs.
    fig.tight_layout()

    # Save or show
    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"ArviZ trace plots saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save ArviZ trace plots: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        # Saved-only figures are released from pyplot's registry; with neither
        # show nor save_path the figure stays open for the caller to use.
        plt.close(fig)

    return fig  # type: ignore[no-any-return]
def plot_arviz_posterior(
    result: Any,
    var_names: list[str] | None = None,
    hdi_prob: float = 0.95,
    figsize: tuple[float, float] | None = None,
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
    **kwargs: Any,
) -> Figure:
    """Plot posterior distributions with highest-density intervals using ArviZ.

    Parameters
    ----------
    result : MCMCResult
        MCMC or CMC result object containing samples. It is accessed through
        ``samples_params``, ``to_arviz()``, ``is_cmc_result()`` and, for CMC
        results, ``num_shards``.
    var_names : list of str, optional
        Parameter names to plot. If None, plots all parameters.
    hdi_prob : float, default=0.95
        Highest density interval probability (e.g., 0.95 for 95% HDI). The
        same value is shown as a percentage in the figure title.
    figsize : tuple, optional
        Figure size (width, height). If None, auto-calculated.
    show : bool, default=False
        If True, display the figure interactively.
    save_path : str or Path, optional
        If provided, save figure to this path after validation with
        ``validate_plot_save_path``. Validation/save errors
        (``PathValidationError``/``ValueError``) are logged, not raised.
    dpi : int, default=150
        DPI for saved figure.
    **kwargs
        Additional arguments passed to az.plot_posterior().

    Returns
    -------
    Figure
        Matplotlib figure object. If ArviZ is not installed, or
        ``result.samples_params`` is None, a placeholder figure with an
        explanatory message is returned instead (closed unless ``show``).

    Examples
    --------
    >>> plot_arviz_posterior(result, var_names=["D0", "alpha", "D_offset"])
    >>> plot_arviz_posterior(result, hdi_prob=0.90)  # 90% CI
    """
    try:
        import arviz as az
    except ImportError:
        logger.warning("ArviZ not available. Cannot create posterior plots.")
        fig = _create_empty_figure("ArviZ not installed")
        if not show:
            plt.close(fig)
        return fig

    if result.samples_params is None:
        logger.warning("No parameter samples available for posterior plots")
        fig = _create_empty_figure("No samples available")
        if not show:
            plt.close(fig)
        return fig

    # Convert to ArviZ InferenceData
    idata = result.to_arviz()

    # Plot posteriors with HDI
    axes = az.plot_posterior(
        idata, var_names=var_names, hdi_prob=hdi_prob, figsize=figsize, **kwargs
    )

    # plot_posterior may return a single Axes (one panel) or an array of axes.
    fig = axes.ravel()[0].figure if hasattr(axes, "ravel") else axes.figure

    # Format with ":g" instead of int(): float products such as
    # 0.57 * 100 == 56.99999999999999 would be truncated to "56% HDI",
    # and fractional levels like 0.975 would lose their decimal.
    title = f"Posterior Distributions ({hdi_prob * 100:g}% HDI)"
    if result.is_cmc_result():
        title += f" - CMC ({result.num_shards} shards)"
    fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
    # Lay out the figure actually drawn on, not pyplot's current figure
    # (they differ when the caller passes ``ax`` through **kwargs).
    fig.tight_layout()

    # Save or show
    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"ArviZ posterior plots saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save ArviZ posterior plots: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        # Saved-only figures are released from pyplot's registry; with neither
        # show nor save_path the figure stays open for the caller to use.
        plt.close(fig)

    return fig  # type: ignore[no-any-return]
def plot_arviz_pair(
    result: Any,
    var_names: list[str] | None = None,
    figsize: tuple[float, float] | None = None,
    show: bool = False,
    save_path: str | Path | None = None,
    dpi: int = 150,
    **kwargs: Any,
) -> Figure:
    """Plot pair plots showing parameter correlations using ArviZ.

    Parameters
    ----------
    result : MCMCResult
        MCMC or CMC result object containing samples. It is accessed through
        ``samples_params``, ``to_arviz()``, ``analysis_mode``,
        ``is_cmc_result()`` and, for CMC results, ``num_shards``.
    var_names : list of str, optional
        Parameter names to plot. If None and ``result.analysis_mode`` is
        "static" or "laminar_flow", a fixed set of physical parameters is
        used (excluding per-angle scaling); for any other mode all
        parameters are plotted.
    figsize : tuple, optional
        Figure size (width, height). If None, auto-calculated.
    show : bool, default=False
        If True, display the figure interactively.
    save_path : str or Path, optional
        If provided, save figure to this path after validation with
        ``validate_plot_save_path``. Validation/save errors
        (``PathValidationError``/``ValueError``) are logged, not raised.
    dpi : int, default=150
        DPI for saved figure.
    **kwargs
        Additional arguments passed to az.plot_pair(). ``kind`` defaults to
        "kde" and ``marginals`` to True; both may be overridden here.

    Returns
    -------
    Figure
        Matplotlib figure object. If ArviZ is not installed, or
        ``result.samples_params`` is None, a placeholder figure with an
        explanatory message is returned instead (closed unless ``show``).

    Examples
    --------
    >>> plot_arviz_pair(result, var_names=["D0", "alpha", "D_offset"])
    >>> plot_arviz_pair(result)  # Auto-selects physical parameters
    >>> plot_arviz_pair(result, kind="scatter")  # Override the KDE default
    """
    try:
        import arviz as az
    except ImportError:
        logger.warning("ArviZ not available. Cannot create pair plots.")
        fig = _create_empty_figure("ArviZ not installed")
        if not show:
            plt.close(fig)
        return fig

    if result.samples_params is None:
        logger.warning("No parameter samples available for pair plots")
        fig = _create_empty_figure("No samples available")
        if not show:
            plt.close(fig)
        return fig

    # Convert to ArviZ InferenceData
    idata = result.to_arviz()

    # Default to physical parameters only (exclude per-angle scaling).
    # NOTE(review): these names must match the variables produced by
    # result.to_arviz() for each mode — confirm against that conversion.
    if var_names is None:
        if result.analysis_mode == "static":
            var_names = ["D0", "alpha", "D_offset"]
        elif result.analysis_mode == "laminar_flow":
            var_names = [
                "D0",
                "alpha",
                "D_offset",
                "gamma_dot_t0",
                "beta",
                "gamma_dot_t_offset",
                "phi0",
            ]

    # Apply the pair-plot defaults without forbidding overrides: passing
    # kind=/marginals= both literally and via **kwargs raised
    # "TypeError: got multiple values for keyword argument". ``kwargs`` is a
    # fresh dict per call, so mutating it is safe.
    kwargs.setdefault("kind", "kde")
    kwargs.setdefault("marginals", True)

    axes = az.plot_pair(idata, var_names=var_names, figsize=figsize, **kwargs)

    # plot_pair may return an array of axes, a single Axes, or something else;
    # fall back to pyplot's current figure in the last case.
    if hasattr(axes, "ravel"):
        fig = axes.ravel()[0].figure
    elif hasattr(axes, "figure"):
        fig = axes.figure
    else:
        fig = plt.gcf()

    # Add title
    title = "Parameter Correlations"
    if result.is_cmc_result():
        title += f" - CMC ({result.num_shards} shards)"
    fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
    # Lay out the figure actually drawn on, not pyplot's current figure
    # (they differ when the caller passes ``ax`` through **kwargs).
    fig.tight_layout()

    # Save or show
    if save_path is not None:
        try:
            validated_path = validate_plot_save_path(save_path)
            if validated_path is not None:
                fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
                logger.info(f"ArviZ pair plots saved to {validated_path.name}")
        except (PathValidationError, ValueError) as e:
            logger.warning(f"Could not save ArviZ pair plots: {e}")

    if show:
        plt.show()
    elif save_path is not None:
        # Saved-only figures are released from pyplot's registry; with neither
        # show nor save_path the figure stays open for the caller to use.
        plt.close(fig)

    return fig