"""MCMC Posterior Comparison Visualization.
Provides per-shard vs combined posterior distribution comparisons for CMC results.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from homodyne.utils.logging import get_logger
from homodyne.utils.path_validation import PathValidationError, validate_plot_save_path
# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
[docs]
def plot_posterior_comparison(
result: Any, # MCMCResult type
param_indices: list[int] | None = None,
figsize: tuple[float, float] | None = None,
bins: int = 30,
show: bool = False,
save_path: str | Path | None = None,
dpi: int = 150,
) -> Figure:
"""Compare per-shard posteriors with combined posterior (CMC only).
Visualizes posterior distributions for each parameter, showing both
per-shard distributions and the combined posterior.
Parameters
----------
result : MCMCResult
CMC result object with per-shard diagnostics
param_indices : list of int, optional
Parameter indices to plot. If None, plots first 6 parameters
figsize : tuple, optional
Figure size (width, height). If None, auto-calculated
bins : int, default=30
Number of bins for histograms
show : bool, default=False
If True, display the figure interactively
save_path : str or Path, optional
If provided, save figure to this path
dpi : int, default=150
DPI for saved figure
Returns
-------
Figure
Matplotlib figure object
Raises
------
ValueError
If result is not a CMC result
Examples
--------
>>> plot_posterior_comparison(cmc_result, param_indices=[0, 1, 2])
Notes
-----
- Light colored lines: Per-shard posteriors
- Bold colored line: Combined posterior
- Good agreement: All distributions overlap
- Poor agreement: Shards show different modes
"""
# Check if CMC result
is_cmc = result.is_cmc_result() if hasattr(result, "is_cmc_result") else False
if not is_cmc:
raise ValueError("Posterior comparison is only available for CMC results")
if result.per_shard_diagnostics is None:
raise ValueError("Per-shard diagnostics not available")
# Extract samples
num_params = (
result.samples_params.shape[-1] if result.samples_params is not None else 0
)
# Select parameters to plot
if param_indices is None:
param_indices = list(range(min(6, num_params)))
num_plots = len(param_indices)
# Calculate figure layout
ncols = min(3, num_plots)
nrows = (num_plots + ncols - 1) // ncols
if figsize is None:
figsize = (5 * ncols, 4 * nrows)
fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
axes = axes.flatten()
# Get parameter names
if result.analysis_mode == "static":
param_names = ["D0", "alpha", "D_offset"]
elif result.analysis_mode == "laminar_flow":
param_names = [
"D0",
"alpha",
"D_offset",
"gamma_dot_t0",
"beta",
"gamma_dot_t_offset",
"phi0",
]
else:
param_names = [f"param_{i}" for i in range(num_params)]
# Plot each parameter
for plot_idx, param_idx in enumerate(param_indices):
ax = axes[plot_idx]
# Extract per-shard samples
num_shards = len(result.per_shard_diagnostics)
colors = matplotlib.colormaps["tab10"](np.linspace(0, 1, num_shards))
for shard_idx, shard_diag in enumerate(result.per_shard_diagnostics):
if "trace_data" in shard_diag:
trace_key = f"param_{param_idx}"
if trace_key in shard_diag["trace_data"]:
trace = np.array(shard_diag["trace_data"][trace_key])
# Flatten if multi-chain
if trace.ndim == 2:
trace = trace.flatten()
# Plot histogram
ax.hist(
trace,
bins=bins,
alpha=0.3,
color=colors[shard_idx],
density=True,
label=f"Shard {shard_idx}" if num_shards <= 10 else "",
)
# Plot combined posterior
if result.samples_params is not None:
combined_samples = result.samples_params[:, param_idx]
ax.hist(
combined_samples,
bins=bins,
alpha=0.5,
color="black",
density=True,
histtype="step",
linewidth=2,
label="Combined",
)
ax.set_xlabel(param_names[param_idx])
ax.set_ylabel("Density")
ax.set_title(f"{param_names[param_idx]} Posterior")
# Add legend for first subplot only
if plot_idx == 0 and num_shards <= 10:
ax.legend(loc="upper right", fontsize=8)
# Hide unused subplots
for idx in range(num_plots, len(axes)):
axes[idx].set_visible(False)
# Add overall title
fig.suptitle(
f"Posterior Comparison (CMC: {result.num_shards} shards)",
fontsize=14,
fontweight="bold",
y=0.995,
)
plt.tight_layout()
# Save or show
if save_path is not None:
try:
validated_path = validate_plot_save_path(save_path)
if validated_path is not None:
fig.savefig(validated_path, dpi=dpi, bbox_inches="tight")
logger.info(f"Posterior comparison saved to {validated_path.name}")
except (PathValidationError, ValueError) as e:
logger.warning(f"Could not save posterior comparison: {e}")
if show:
plt.show()
elif save_path is not None:
plt.close(fig)
return fig