Source code for homodyne.optimization.cmc.plotting

"""ArviZ diagnostic plots for CMC results.

This module provides the 6 standard ArviZ diagnostic plots, all of which are
produced together by ``generate_diagnostic_plots``:

1. Pair plot (corner plot)
2. Forest plot
3. Energy plot
4. Autocorrelation plot
5. Rank plot
6. ESS plot

A trace plot is also available via ``plot_trace``; it is not produced by
``generate_diagnostic_plots`` and must be called explicitly.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

try:
    import arviz as az

    HAS_ARVIZ = True
except ImportError:
    HAS_ARVIZ = False
    az = None  # type: ignore

import matplotlib.pyplot as plt

from homodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from homodyne.optimization.cmc.results import CMCResult

# Module logger obtained from the project's logging helper, keyed by this
# module's import name.
logger = get_logger(__name__)


def _is_scaling_var(name: str) -> bool:
    """Return True when *name* is a per-angle scaling variable.

    A variable is treated as scaling when it begins with ``"<base>_"`` for
    any ``base`` in ``ParameterRegistry().scaling_names`` (the registry's
    ``is_scaling`` flag decides which bases are listed there).
    """
    # Imported lazily, as in the rest of this module's registry access.
    from homodyne.config.parameter_registry import ParameterRegistry

    # str.startswith accepts a tuple of prefixes; an empty tuple never matches.
    prefixes = tuple(f"{base}_" for base in ParameterRegistry().scaling_names)
    return name.startswith(prefixes)


# Default figure settings shared by the plotting functions in this module.
# plot_energy and plot_ess declare their own (10, 6) figsize default instead.
DEFAULT_FIGSIZE: tuple[int, int] = (12, 10)  # figure size in inches
DEFAULT_DPI: int = 150  # resolution passed to plt.savefig


def generate_diagnostic_plots(
    result: CMCResult,
    output_dir: Path,
    figsize: tuple[int, int] = DEFAULT_FIGSIZE,
    dpi: int = DEFAULT_DPI,
    param_subset: list[str] | None = None,
) -> list[Path]:
    """Generate all 6 ArviZ diagnostic plots.

    Each plot is attempted independently. A failure is logged as a warning
    and does not stop the remaining plots. Any figures the failed attempt
    left open are closed.

    Parameters
    ----------
    result : CMCResult
        CMC result with inference_data.
    output_dir : Path
        Directory to save plots (created, with parents, if missing).
    figsize : tuple[int, int]
        Figure size in inches, passed to every plot.
    dpi : int
        Figure resolution.
    param_subset : list[str] | None
        Subset of parameters to plot. If None, each plot function applies its
        own default variable selection. The energy plot ignores this argument.

    Returns
    -------
    list[Path]
        Paths to the plot files that were saved successfully.

    Raises
    ------
    ImportError
        If ArviZ is not installed.
    """
    if not HAS_ARVIZ:
        raise ImportError("Arviz is required for diagnostic plots")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    idata = result.inference_data
    # None means "let each plot function choose its own variables".
    var_names = param_subset

    # (label used in the failure message, plot function, positional args)
    plot_jobs = (
        ("pair", plot_pair, (idata, output_dir, var_names, figsize, dpi)),
        ("forest", plot_forest, (idata, output_dir, var_names, figsize, dpi)),
        ("energy", plot_energy, (idata, output_dir, figsize, dpi)),
        ("autocorr", plot_autocorr, (idata, output_dir, var_names, figsize, dpi)),
        ("rank", plot_rank, (idata, output_dir, var_names, figsize, dpi)),
        ("ESS", plot_ess, (idata, output_dir, var_names, figsize, dpi)),
    )

    saved_plots: list[Path] = []
    for label, plot_fn, args in plot_jobs:
        figures_before = set(plt.get_fignums())
        try:
            saved_plots.append(plot_fn(*args))
        except Exception as e:
            logger.warning(f"Failed to generate {label} plot: {e}")
            # A plot that raises midway never reaches its own plt.close().
            # Close only the figures this attempt opened, so repeated failures
            # don't accumulate open pyplot figures and unrelated figures
            # are left alone.
            for fig_num in set(plt.get_fignums()) - figures_before:
                plt.close(fig_num)

    logger.info(f"Generated {len(saved_plots)} diagnostic plots in {output_dir}")
    return saved_plots
def plot_pair(
    idata: az.InferenceData,
    output_dir: Path,
    var_names: list[str] | None = None,
    figsize: tuple[int, int] = DEFAULT_FIGSIZE,
    dpi: int = DEFAULT_DPI,
    max_scaling_per_family: int = 3,
) -> Path:
    """Generate pair (corner) plot.

    Shows pairwise parameter correlations and marginal distributions.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Output directory.
    var_names : list[str] | None
        Parameters to include. If None, all physical (non-scaling)
        parameters are used, plus at most ``max_scaling_per_family``
        per-angle scaling variables from each scaling family.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution.
    max_scaling_per_family : int
        Used only when ``var_names`` is None. Maximum number of per-angle
        variables kept for each scaling family. The family is the variable
        name with its trailing ``_<suffix>`` removed, so ``contrast_2``
        belongs to ``contrast``.

    Returns
    -------
    Path
        Path to saved plot.
    """
    # Limit to physical parameters (plus a few scaling terms) for readability.
    if var_names is None:
        physical: list[str] = []
        scaling_kept: list[str] = []
        kept_per_family: dict[str, int] = {}
        for name in idata.posterior.data_vars:
            if not _is_scaling_var(name):
                physical.append(name)
                continue
            # Cap each family separately (e.g. contrast_0..2 and offset_0..2).
            # A flat slice of the scaling list would depend on the posterior's
            # variable order and could return only one family.
            family = name.rsplit("_", 1)[0]
            kept = kept_per_family.get(family, 0)
            if kept < max_scaling_per_family:
                scaling_kept.append(name)
                kept_per_family[family] = kept + 1
        var_names = physical + scaling_kept

    az.plot_pair(
        idata,
        var_names=var_names,
        kind="kde",
        figsize=figsize,
        marginals=True,
    )

    output_path = output_dir / "pair_plot.png"
    plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved pair plot: {output_path}")
    return output_path
def plot_forest(
    idata: az.InferenceData,
    output_dir: Path,
    var_names: list[str] | None = None,
    figsize: tuple[int, int] = DEFAULT_FIGSIZE,
    dpi: int = DEFAULT_DPI,
) -> Path:
    """Render a forest plot of the posterior with 94% HDI intervals.

    Chains are combined into a single interval per parameter.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Directory the PNG is written to.
    var_names : list[str] | None
        Parameters to show. None lets ArviZ include every variable.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution of the saved image.

    Returns
    -------
    Path
        Location of ``forest_plot.png``.
    """
    forest_options = {
        "var_names": var_names,
        "combined": True,
        "hdi_prob": 0.94,
        "figsize": figsize,
    }
    az.plot_forest(idata, **forest_options)

    destination = output_dir / "forest_plot.png"
    plt.savefig(destination, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved forest plot: {destination}")
    return destination
def plot_energy(
    idata: az.InferenceData,
    output_dir: Path,
    figsize: tuple[int, int] = (10, 6),
    dpi: int = DEFAULT_DPI,
) -> Path:
    """Generate energy plot.

    Compares the marginal energy distribution to the energy transition
    distribution. Large differences indicate sampling problems. If
    ``sample_stats`` has neither ``energy`` nor ``potential_energy``, a
    placeholder figure with an explanatory message is saved instead, so the
    output file always exists.

    The caller's ``idata`` is never modified.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Output directory.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution.

    Returns
    -------
    Path
        Path to saved plot (either the energy plot or the placeholder).
    """
    output_path = output_dir / "energy_plot.png"

    # Energy plot requires sample_stats with energy info.
    # ArviZ expects "energy" in sample_stats; NumPyro stores "potential_energy".
    # The mapping is handled in create_inference_data, but we add a defensive
    # fallback here for InferenceData created outside that path.
    plot_data = None
    sample_stats = getattr(idata, "sample_stats", None)
    if sample_stats is not None:
        if hasattr(sample_stats, "energy"):
            plot_data = idata
        elif hasattr(sample_stats, "potential_energy"):
            # Build a separate InferenceData with the renamed variable.
            # Reassigning idata.sample_stats would silently rename a variable
            # in the caller's results.
            # Assumes az.plot_energy only needs the sample_stats group;
            # verify against the installed ArviZ version.
            # NOTE(review): potential energy is not the total Hamiltonian
            # energy that ArviZ's energy plot/BFMI is defined on, so this
            # fallback plot is only a proxy. Confirm that is acceptable.
            plot_data = az.InferenceData(
                sample_stats=sample_stats.rename({"potential_energy": "energy"})
            )

    if plot_data is None:
        # Create minimal figure with message
        _, ax = plt.subplots(figsize=figsize)
        ax.text(
            0.5,
            0.5,
            "Energy plot not available\n(energy/potential_energy not in sample_stats)",
            ha="center",
            va="center",
            fontsize=12,
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis("off")
        plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
        plt.close()
        logger.debug(f"Saved placeholder energy plot (no energy data): {output_path}")
        return output_path

    az.plot_energy(plot_data, figsize=figsize)

    plt.savefig(output_path, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved energy plot: {output_path}")
    return output_path
def plot_autocorr(
    idata: az.InferenceData,
    output_dir: Path,
    var_names: list[str] | None = None,
    figsize: tuple[int, int] = DEFAULT_FIGSIZE,
    dpi: int = DEFAULT_DPI,
) -> Path:
    """Render an autocorrelation plot with chains combined.

    Indicates how many draws it takes for samples to become independent.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Directory the PNG is written to.
    var_names : list[str] | None
        Parameters to show. If None, the non-scaling posterior variables are
        used. If every variable is a scaling term, the first six posterior
        variables are used instead.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution of the saved image.

    Returns
    -------
    Path
        Location of ``autocorr_plot.png``.
    """
    if var_names is None:
        everything = list(idata.posterior.data_vars)
        physical_only = [name for name in everything if not _is_scaling_var(name)]
        # An empty selection falls back to a handful of raw variables.
        var_names = physical_only or everything[:6]

    az.plot_autocorr(idata, var_names=var_names, combined=True, figsize=figsize)

    destination = output_dir / "autocorr_plot.png"
    plt.savefig(destination, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved autocorr plot: {destination}")
    return destination
def plot_rank(
    idata: az.InferenceData,
    output_dir: Path,
    var_names: list[str] | None = None,
    figsize: tuple[int, int] = DEFAULT_FIGSIZE,
    dpi: int = DEFAULT_DPI,
) -> Path:
    """Render a rank plot to expose chain-mixing problems.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Directory the PNG is written to.
    var_names : list[str] | None
        Parameters to show. If None, the non-scaling posterior variables are
        used. When none exist, the first six posterior variables are used.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution of the saved image.

    Returns
    -------
    Path
        Location of ``rank_plot.png``.
    """
    if var_names is None:
        posterior_vars = list(idata.posterior.data_vars)
        selected = [name for name in posterior_vars if not _is_scaling_var(name)]
        var_names = selected if selected else posterior_vars[:6]

    az.plot_rank(idata, var_names=var_names, figsize=figsize)

    rank_path = output_dir / "rank_plot.png"
    plt.savefig(rank_path, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved rank plot: {rank_path}")
    return rank_path
def plot_ess(
    idata: az.InferenceData,
    output_dir: Path,
    var_names: list[str] | None = None,
    figsize: tuple[int, int] = (10, 6),
    dpi: int = DEFAULT_DPI,
) -> Path:
    """Render an ESS evolution plot.

    Shows how the effective sample size grows as more draws are added.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Directory the PNG is written to.
    var_names : list[str] | None
        Parameters to show. If None, the non-scaling posterior variables are
        used. When none exist, the first six posterior variables are used.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution of the saved image.

    Returns
    -------
    Path
        Location of ``ess_plot.png``.
    """
    if var_names is None:
        candidates = list(idata.posterior.data_vars)
        chosen = []
        for candidate in candidates:
            if not _is_scaling_var(candidate):
                chosen.append(candidate)
        if not chosen:
            chosen = candidates[:6]
        var_names = chosen

    az.plot_ess(idata, var_names=var_names, kind="evolution", figsize=figsize)

    ess_path = output_dir / "ess_plot.png"
    plt.savefig(ess_path, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved ESS plot: {ess_path}")
    return ess_path
def plot_trace(
    idata: az.InferenceData,
    output_dir: Path,
    var_names: list[str] | None = None,
    figsize: tuple[int, int] = DEFAULT_FIGSIZE,
    dpi: int = DEFAULT_DPI,
) -> Path:
    """Render a trace plot (bonus diagnostic, not part of the standard six).

    Shows parameter values across sampling iterations.

    Parameters
    ----------
    idata : az.InferenceData
        ArviZ inference data.
    output_dir : Path
        Directory the PNG is written to.
    var_names : list[str] | None
        Parameters to show. If None, the non-scaling posterior variables are
        used. When none exist, the first six posterior variables are used.
    figsize : tuple[int, int]
        Figure size.
    dpi : int
        Resolution of the saved image.

    Returns
    -------
    Path
        Location of ``trace_plot.png``.
    """
    if var_names is None:
        available = list(idata.posterior.data_vars)
        non_scaling = [var for var in available if not _is_scaling_var(var)]
        if non_scaling:
            var_names = non_scaling
        else:
            var_names = available[:6]

    az.plot_trace(idata, var_names=var_names, figsize=figsize)

    trace_path = output_dir / "trace_plot.png"
    plt.savefig(trace_path, dpi=dpi, bbox_inches="tight")
    plt.close()

    logger.debug(f"Saved trace plot: {trace_path}")
    return trace_path