"""Base class for CMC execution backends.
This module defines the abstract interface for CMC backends
and provides a factory function for selecting backends.
"""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from homodyne.utils.logging import get_logger
if TYPE_CHECKING:
from homodyne.optimization.cmc.config import CMCConfig
from homodyne.optimization.cmc.data_prep import PreparedData
from homodyne.optimization.cmc.diagnostics import BimodalConsensusResult
from homodyne.optimization.cmc.sampler import MCMCSamples
logger = get_logger(__name__)
[docs]
class CMCBackend(ABC):
    """Interface that every CMC execution backend implements.

    A backend is responsible for executing MCMC sampling over data shards
    (possibly in parallel) and for merging the per-shard results into one
    set of samples.  Concrete subclasses must provide :meth:`run` and
    :meth:`get_name`; :meth:`is_available` may be overridden when a backend
    depends on optional resources.
    """

    @abstractmethod
    def run(
        self,
        model: Callable,
        model_kwargs: dict[str, Any],
        config: CMCConfig,
        shards: list[PreparedData] | None = None,
    ) -> MCMCSamples:
        """Execute MCMC sampling, optionally spread across shards.

        Parameters
        ----------
        model : Callable
            The NumPyro model function to sample from.
        model_kwargs : dict[str, Any]
            Keyword arguments shared by every invocation of ``model``.
        config : CMCConfig
            Configuration object controlling the CMC run.
        shards : list[PreparedData] | None
            Data shards to distribute.  ``None`` means "no sharding":
            the run is single-threaded on the full data set.

        Returns
        -------
        MCMCSamples
            The samples after combining the output of all shards.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name identifying this backend."""

    def is_available(self) -> bool:
        """Report whether this backend can be used.

        The base implementation is unconditional; subclasses override it
        when availability depends on the runtime environment.

        Returns
        -------
        bool
            ``True`` when the backend is usable.
        """
        return True
[docs]
def select_backend(
    config: CMCConfig,
) -> CMCBackend:
    """Instantiate the backend named by ``config.backend_name``.

    Name resolution:

    - ``"multiprocessing"``: the multiprocessing backend.
    - ``"auto"``: resolves to the multiprocessing backend (CPU default).
    - ``"jax"``: legacy alias; logs a deprecation warning and resolves to
      the multiprocessing backend (not pjit, which handles shards one at a
      time).
    - ``"pjit"`` / ``"pbs"``: the corresponding backend; if importing or
      constructing it raises ``ImportError`` a warning is logged and the
      multiprocessing backend is returned instead.

    Parameters
    ----------
    config : CMCConfig
        CMC configuration; only its ``backend_name`` attribute is read.

    Returns
    -------
    CMCBackend
        A freshly constructed backend instance.

    Raises
    ------
    ValueError
        If ``config.backend_name`` is none of the names listed above.
    """

    def _multiprocessing_backend() -> CMCBackend:
        # Imported lazily so that merely selecting another backend does not
        # pull in the multiprocessing implementation.
        from homodyne.optimization.cmc.backends.multiprocessing import (
            MultiprocessingBackend,
        )

        return MultiprocessingBackend()

    requested = config.backend_name

    if requested == "jax":
        # Backward compatibility: the alias maps to multiprocessing rather
        # than pjit because the pjit backend processes shards sequentially.
        logger.warning(
            "CMC backend 'jax' is deprecated; mapping to 'multiprocessing' "
            "for parallel execution. Set backend_config.name to "
            "'multiprocessing' or 'auto' instead."
        )

    if requested in ("auto", "jax", "multiprocessing"):
        return _multiprocessing_backend()

    if requested == "pjit":
        try:
            from homodyne.optimization.cmc.backends.pjit import PjitBackend

            return PjitBackend()
        except ImportError:
            logger.warning(
                "pjit backend not available, falling back to multiprocessing"
            )
            return _multiprocessing_backend()

    if requested == "pbs":
        try:
            from homodyne.optimization.cmc.backends.pbs import PBSBackend

            return PBSBackend()
        except ImportError:
            logger.warning("PBS backend not available, falling back to multiprocessing")
            return _multiprocessing_backend()

    raise ValueError(f"Unknown backend: {requested}")
[docs]
def combine_shard_samples(
    shard_samples: list[MCMCSamples],
    method: str = "weighted_gaussian",
    chunk_size: int = 500,
) -> MCMCSamples:
    """Combine samples from multiple shards.

    Three code paths exist, chosen from the shard count ``K`` and ``method``:

    1. ``K <= chunk_size``: one call to ``_combine_shard_chunk``.
    2. ``K > chunk_size`` with ``"robust_consensus_mc"``, ``"consensus_mc"``
       or ``"auto"``: *moment accumulation*.  Per-shard (mean, variance)
       summaries are collected chunk by chunk, degenerate shards are
       filtered, and a single Gaussian draw is made from the
       precision-weighted moments.  No intermediate synthetic samples are
       drawn, which avoids the precision-multiplication artefact that
       arises when recursive combination re-applies precision weighting to
       synthetically drawn intermediate samples (P1-R6-01).
    3. ``K > chunk_size`` with any other (legacy) method: chunked
       recursion, combining each chunk with ``_combine_shard_chunk`` and
       then combining the intermediate results.

    Memory notes:
        - Sizing example: 13 params x 4 chains x 1500 float64 samples is
          roughly 0.6 MB per shard.
        - Moment accumulation stores two floats per (shard, parameter).
        - The caller's ``shard_samples`` list stays alive for the whole
          call; chunking bounds only the temporaries created here.

    Parameters
    ----------
    shard_samples : list[MCMCSamples]
        Samples from each shard.  Must not be empty.
    method : str
        Combination method: "robust_consensus_mc" (recommended),
        "consensus_mc", "weighted_gaussian", "simple_average", or "auto".
        The default, "weighted_gaussian", is a deprecated legacy method kept
        for backward compatibility.
    chunk_size : int
        Number of shards processed per chunk in the hierarchical paths.
        Must be >= 1 whenever more than one shard is supplied.

    Returns
    -------
    MCMCSamples
        Combined samples.  A single input shard is returned unchanged.

    Raises
    ------
    ValueError
        If ``shard_samples`` is empty, or if ``chunk_size`` is smaller
        than 1 while more than one shard is supplied.
    """
    import gc

    n_shards_in = len(shard_samples)
    if n_shards_in == 0:
        raise ValueError("combine_shard_samples requires at least one shard")
    if n_shards_in == 1:
        return shard_samples[0]
    if chunk_size < 1:
        # Zero would divide by zero below; a negative value makes the chunk
        # loop empty and would silently yield placeholder samples.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # For large shard counts, use moment-accumulation to limit memory.
    # P1-R6-01: Do NOT draw intermediate samples and re-combine them.
    # Recursive precision-weighting on synthetic draws inflates precision
    # by the number of hierarchical levels, over-narrowing the final posterior.
    if n_shards_in > chunk_size and method in (
        "robust_consensus_mc",
        "consensus_mc",
        "auto",
    ):
        import numpy as np

        logger.info(
            f"Hierarchical combination (moment-accumulation): {len(shard_samples)} shards "
            f"in chunks of {chunk_size}"
        )
        # Shape metadata is taken from the first shard.
        param_names = shard_samples[0].param_names
        n_chains = shard_samples[0].n_chains
        n_samples = shard_samples[0].n_samples
        n_chunks = (n_shards_in + chunk_size - 1) // chunk_size

        # Pass 1: collect per-shard summaries (O(K * n_params) floats).
        shard_stats: dict[str, list[tuple[float, float]]] = {
            name: [] for name in param_names
        }
        n_excluded: dict[str, int] = dict.fromkeys(param_names, 0)
        for chunk_start in range(0, n_shards_in, chunk_size):
            chunk = shard_samples[chunk_start : chunk_start + chunk_size]
            chunk_idx = chunk_start // chunk_size
            for name in param_names:
                for s in chunk:
                    samples_flat = s.samples[name].flatten()
                    if not np.all(np.isfinite(samples_flat)):
                        # A shard with any NaN/inf is skipped for this
                        # parameter only.
                        n_excluded[name] += 1
                        continue
                    shard_mean = float(np.mean(samples_flat))
                    shard_var = float(np.var(samples_flat, ddof=1))
                    shard_stats[name].append((shard_mean, shard_var))
            del chunk
            gc.collect()
            logger.debug(f"Accumulated chunk {chunk_idx + 1}/{n_chunks}")

        # Pass 2: filter degenerate shards, then precision-weight.
        rng = np.random.default_rng(42)  # fixed seed => reproducible draws
        combined_samples_dict: dict[str, np.ndarray] = {}
        for name in param_names:
            stats = shard_stats[name]
            if not stats:
                # NOTE(review): this substitutes standard-normal placeholder
                # draws, whereas _combine_shard_chunk raises ValueError in
                # the same situation.  Kept as-is for compatibility.
                logger.warning(f"Hierarchical CMC: all shards excluded for '{name}'")
                combined_samples_dict[name] = rng.normal(
                    loc=0.0, scale=1.0, size=(n_chains, n_samples)
                )
                continue
            if n_excluded[name] > 0:
                logger.warning(
                    f"Hierarchical CMC: {n_excluded[name]} non-finite shards "
                    f"excluded for '{name}'"
                )
            means_arr = np.array([m for m, _ in stats])
            vars_arr = np.array([v for _, v in stats])

            # Degenerate-shard exclusion: shards with variance < 1e-6 * median
            # indicate stuck chains that would dominate precision weighting.
            if len(vars_arr) >= 3:
                median_var = np.median(vars_arr)
                if median_var > 0:
                    degenerate_mask = vars_arr < (median_var * 1e-6)
                    n_degenerate = int(np.sum(degenerate_mask))
                    if 0 < n_degenerate < len(vars_arr):
                        logger.warning(
                            f"Hierarchical CMC: {n_degenerate} degenerate "
                            f"shard(s) for '{name}' (var < 1e-6 * median); "
                            f"excluding"
                        )
                        keep = ~degenerate_mask
                        means_arr = means_arr[keep]
                        vars_arr = vars_arr[keep]

            # Precision-weighted combination on filtered data.  Variances
            # are floored at 1e-10 and the combined variance at 1e-12.
            precisions = 1.0 / np.maximum(vars_arr, 1e-10)
            prec_sum = float(np.sum(precisions))
            combined_variance = max(1.0 / prec_sum, 1e-12) if prec_sum > 0 else 1.0
            combined_mean = (
                float(np.sum(precisions * means_arr) / prec_sum)
                if prec_sum > 0
                else 0.0
            )
            combined_std = np.sqrt(combined_variance)
            combined_samples_dict[name] = rng.normal(
                loc=combined_mean,
                scale=combined_std,
                size=(n_chains, n_samples),
            )

        # Extra fields are merged from the first chunk only (metadata).
        combined_extra: dict = {}
        first_chunk = shard_samples[: min(chunk_size, n_shards_in)]
        for key in first_chunk[0].extra_fields:
            all_extra = [
                s.extra_fields[key] for s in first_chunk if key in s.extra_fields
            ]
            if not all_extra:
                continue
            try:
                if all_extra[0].ndim == 0:
                    combined_extra[key] = np.stack(all_extra, axis=0)
                else:
                    combined_extra[key] = np.concatenate(all_extra, axis=0)
            except (ValueError, TypeError, AttributeError) as e:
                # AttributeError covers fields without ``.ndim`` (e.g. plain
                # Python scalars or None); fall back to the first value.
                logger.warning(f"Failed to combine extra field '{key}': {e}")
                combined_extra[key] = all_extra[0]

        from homodyne.optimization.cmc.sampler import MCMCSamples

        return MCMCSamples(
            samples=combined_samples_dict,
            param_names=param_names,
            n_chains=n_chains,
            n_samples=n_samples,
            extra_fields=combined_extra,
            num_shards=sum(getattr(s, "num_shards", 1) for s in shard_samples),
        )

    # For deprecated methods with large K, fall back to chunked recursion
    # (these methods don't support moment accumulation).
    if n_shards_in > chunk_size:
        # With one-shard chunks every chunk is returned unchanged, so the
        # intermediate list would be as long as the input and the recursion
        # would never terminate.  Two shards per chunk is the smallest size
        # that strictly shrinks the list at every level.
        step = max(chunk_size, 2)
        logger.info(
            f"Hierarchical combination (chunked recursion): {len(shard_samples)} shards "
            f"in chunks of {step} (method={method!r})"
        )
        intermediate_results = []
        n_chunks = (n_shards_in + step - 1) // step
        for i in range(0, n_shards_in, step):
            chunk = shard_samples[i : i + step]
            chunk_idx = i // step
            # chunk_seed makes each chunk's synthetic draws independent.
            chunk_result = _combine_shard_chunk(chunk, method, chunk_seed=chunk_idx)
            intermediate_results.append(chunk_result)
            del chunk
            gc.collect()
            logger.debug(f"Combined chunk {chunk_idx + 1}/{n_chunks}")
        return combine_shard_samples(intermediate_results, method, chunk_size)

    return _combine_shard_chunk(shard_samples, method)
def _finite_shard_moments(
    shard_samples: list[MCMCSamples],
    name: str,
    label: str,
) -> tuple[Any, Any]:
    """Return arrays of per-shard (means, variances) for parameter ``name``.

    P0-1: a shard containing any non-finite draw for ``name`` is skipped
    (with a warning prefixed by ``label``) to prevent NaN propagation.
    Raises ``ValueError`` when every shard is skipped.
    """
    import numpy as np

    means = []
    variances = []
    for shard in shard_samples:
        draws = shard.samples[name].flatten()
        finite = np.isfinite(draws)
        if not np.all(finite):
            n_bad = int(np.sum(~finite))
            logger.warning(
                f"{label}: shard has {n_bad}/{len(draws)} non-finite "
                f"samples for '{name}'; excluding shard for this parameter"
            )
            continue
        means.append(np.mean(draws))
        # ddof=1 (unbiased estimator): negligible for large shards, but
        # ddof=0 would underestimate the width for small shards.
        variances.append(np.var(draws, ddof=1))
    if not means:
        raise ValueError(
            f"All shards excluded for parameter '{name}' due to non-finite samples. "
            "Cannot combine posteriors - check NUTS divergence diagnostics."
        )
    return np.array(means), np.array(variances)


def _drop_degenerate_shards(
    means_arr: Any,
    vars_arr: Any,
    name: str,
    label: str,
) -> tuple[Any, Any]:
    """Drop shards whose variance is below 1e-6 x the median variance.

    Near-zero variance indicates a chain stuck at its init value; such a
    shard would get precision ~1e10 and dominate the combination.  Nothing
    is dropped with fewer than 3 shards, when the median is not positive,
    or when every shard would be dropped.
    """
    import numpy as np

    if len(vars_arr) >= 3:
        median_var = np.median(vars_arr)
        if median_var > 0:
            degenerate_mask = vars_arr < (median_var * 1e-6)
            n_degenerate = int(np.sum(degenerate_mask))
            if 0 < n_degenerate < len(vars_arr):
                logger.warning(
                    f"{label}: {n_degenerate} degenerate shard(s) "
                    f"for '{name}' (variance < 1e-6 * median); excluding"
                )
                keep = ~degenerate_mask
                return means_arr[keep], vars_arr[keep]
    return means_arr, vars_arr


def _precision_weighted_moments(means: Any, variances: Any) -> tuple[Any, Any]:
    """Return the precision-weighted (mean, variance) of Gaussian shards.

    Each variance is floored at 1e-10 before inversion.  P2-3: the combined
    variance is floored at 1e-12 to prevent point-mass posteriors.
    """
    precisions = [1.0 / max(v, 1e-10) for v in variances]
    combined_precision = sum(precisions)
    combined_variance = max(1.0 / combined_precision, 1e-12)
    weighted_mean_sum = sum(p * m for p, m in zip(precisions, means, strict=False))
    return combined_variance * weighted_mean_sum, combined_variance


def _combine_shard_chunk(
    shard_samples: list[MCMCSamples],
    method: str,
    chunk_seed: int = 0,
) -> MCMCSamples:
    """Combine a chunk of shard samples (internal helper).

    Parameters
    ----------
    shard_samples : list[MCMCSamples]
        Samples from each shard in the chunk.  A single shard is returned
        unchanged.  Chain/sample counts and parameter names are taken from
        the first shard.
    method : str
        Combination method:
        - "consensus_mc": Correct Consensus Monte Carlo (precision-weighted means)
        - "robust_consensus_mc" / "auto": Robust CMC with outlier handling
        - "simple_average": Simple element-wise averaging (deprecated)
        - "weighted_gaussian": Legacy element-wise weighted averaging
          (deprecated).  Any unrecognised string also takes this branch.
    chunk_seed : int
        Offset added to the base RNG seed (42) by the two consensus methods
        (P2-R5-02), so that different chunks of a hierarchical combination
        draw independent samples.

    Returns
    -------
    MCMCSamples
        Combined samples for this chunk.

    Raises
    ------
    ValueError
        For the consensus methods, if every shard has non-finite samples
        for some parameter.

    Notes
    -----
    The "consensus_mc" method implements the Consensus Monte Carlo algorithm
    (Scott et al., 2016):
    1. For each shard s, compute posterior mean μ_s and variance σ²_s
    2. Combined precision: 1/σ² = Σ_s (1/σ²_s)
    3. Combined mean: μ = σ² × Σ_s (μ_s / σ²_s)
    4. Generate new samples from N(μ, σ²)

    The "robust_consensus_mc" method extends this with:
    - Outlier-shard detection via a MAD-based modified z-score (threshold 3.5)
    - Winsorization of variances to the 5th-95th percentile range
    - Exclusion of degenerate (near-zero variance) shards

    Both deprecated methods emit a ``DeprecationWarning``.
    """
    import numpy as np

    from homodyne.optimization.cmc.sampler import MCMCSamples

    if len(shard_samples) == 1:
        return shard_samples[0]

    param_names = shard_samples[0].param_names
    n_chains = shard_samples[0].n_chains
    n_samples = shard_samples[0].n_samples

    combined_samples: dict[str, np.ndarray] = {}

    if method in ("robust_consensus_mc", "auto"):
        # "auto" resolves to robust_consensus_mc.
        label = "Robust CMC"
        rng = np.random.default_rng(42 + chunk_seed)
        for name in param_names:
            means_arr, vars_arr = _finite_shard_moments(shard_samples, name, label)

            # Outlier detection on the shard means via the median absolute
            # deviation (more robust than the standard deviation).
            median_mean = np.median(means_arr)
            mad = np.median(np.abs(means_arr - median_mean))
            if mad > 0:
                modified_z = 0.6745 * np.abs(means_arr - median_mean) / mad
                inlier_mask = modified_z < 3.5
            else:
                # MAD == 0 (all means identical): keep everything.
                inlier_mask = np.ones(len(means_arr), dtype=bool)

            n_inliers = int(np.sum(inlier_mask))
            if n_inliers < 3:
                # Too few inliers for robust statistics: use all shards.
                logger.warning(
                    f"Robust CMC: Only {n_inliers} inliers for {name}, "
                    "falling back to standard combination"
                )
                inlier_mask = np.ones(len(means_arr), dtype=bool)
                # Keep the count in sync with the mask so the exclusion log
                # below does not report shards that were in fact retained.
                n_inliers = len(means_arr)

            filtered_means, filtered_vars = _drop_degenerate_shards(
                means_arr[inlier_mask], vars_arr[inlier_mask], name, label
            )

            # Winsorize extreme variances (cap at 5th and 95th percentiles).
            if len(filtered_vars) >= 5:
                var_low, var_high = np.percentile(filtered_vars, [5, 95])
                filtered_vars = np.clip(filtered_vars, var_low, var_high)

            combined_mean, combined_variance = _precision_weighted_moments(
                filtered_means, filtered_vars
            )
            combined_samples[name] = rng.normal(
                loc=combined_mean,
                scale=np.sqrt(combined_variance),
                size=(n_chains, n_samples),
            )

            n_excluded = len(means_arr) - n_inliers
            if n_excluded > 0:
                logger.debug(
                    f"Robust CMC: {name} excluded {n_excluded}/{len(means_arr)} "
                    f"outlier shards (MAD-based detection)"
                )

    elif method == "consensus_mc":
        label = "Consensus MC"
        rng = np.random.default_rng(42 + chunk_seed)  # deterministic + chunk-unique
        for name in param_names:
            means_arr, vars_arr = _finite_shard_moments(shard_samples, name, label)
            means_arr, vars_arr = _drop_degenerate_shards(
                means_arr, vars_arr, name, label
            )
            combined_mean, combined_variance = _precision_weighted_moments(
                means_arr, vars_arr
            )
            # One draw of shape (n_chains, n_samples) per parameter.
            combined_samples[name] = rng.normal(
                loc=combined_mean,
                scale=np.sqrt(combined_variance),
                size=(n_chains, n_samples),
            )

    elif method == "simple_average":
        warnings.warn(
            "combination_method='simple_average' is deprecated since v2.12.0 "
            "and will be removed in v3.0. Use 'consensus_mc' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Legacy: simple element-wise average across shards (deprecated).
        for name in param_names:
            all_shard_samples = [s.samples[name] for s in shard_samples]
            # nanmean ignores NaN draws (e.g. from divergent transitions).
            combined_samples[name] = np.nanmean(all_shard_samples, axis=0)

    else:  # weighted_gaussian (legacy default)
        warnings.warn(
            "combination_method='weighted_gaussian' is deprecated since v2.12.0 "
            "and will be removed in v3.0. Use 'consensus_mc' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Legacy: element-wise weighted averaging (deprecated).
        # WARNING: mathematically incorrect; kept for backward compatibility.
        for name in param_names:
            all_shard_samples = [s.samples[name] for s in shard_samples]
            variances = [float(np.nanvar(s)) for s in all_shard_samples]
            precisions = [1.0 / max(v, 1e-10) for v in variances]
            total_precision = sum(precisions)
            weights = [p / total_precision for p in precisions]
            combined_samples[name] = sum(
                w * s for w, s in zip(weights, all_shard_samples, strict=False)
            )

    # Combine extra fields; the key set is taken from the first shard.
    combined_extra: dict[str, Any] = {}
    for key in shard_samples[0].extra_fields:
        all_extra = [
            s.extra_fields[key] for s in shard_samples if key in s.extra_fields
        ]
        if not all_extra:
            continue
        try:
            if all_extra[0].ndim == 0:
                # Scalar (zero-dimensional) fields: stack into a 1D array.
                combined_extra[key] = np.stack(all_extra, axis=0)
            else:
                # Array fields: concatenate along axis 0.
                combined_extra[key] = np.concatenate(all_extra, axis=0)
        except Exception as e:
            # Best-effort: incompatible fields fall back to the first value.
            logger.warning(f"Failed to combine extra field '{key}': {e}")
            combined_extra[key] = all_extra[0]

    # Track total shards so that results that are themselves combinations
    # keep counting their constituent shards.
    total_shards = sum(getattr(s, "num_shards", 1) for s in shard_samples)

    return MCMCSamples(
        samples=combined_samples,
        param_names=param_names,
        n_chains=n_chains,
        n_samples=n_samples,
        extra_fields=combined_extra,
        num_shards=total_shards,
    )
[docs]
def combine_shard_samples_bimodal(
    shard_samples: list[MCMCSamples],
    cluster_assignments: tuple[list[int], list[int]],
    bimodal_detections: list[dict[str, Any]],
    modal_params: list[str],
    co_occurrence: dict[str, Any],
    method: str = "consensus_mc",
    chunk_seed: int = 0,
) -> tuple[MCMCSamples, BimodalConsensusResult]:
    """Combine shard samples using mode-aware consensus.

    A Gaussian consensus is computed separately for the lower and the upper
    cluster of shards.  For bimodal shards, per-component GMM statistics are
    used instead of full-posterior statistics to avoid density-trough
    corruption.  The primary samples are a mixture of draws from the two
    consensus Gaussians, in proportion to the cluster sizes.

    Parameters
    ----------
    shard_samples : list[MCMCSamples]
        All successful shard samples; cluster entries index into this list.
    cluster_assignments : tuple[list[int], list[int]]
        (lower_cluster_shards, upper_cluster_shards) from cluster_shard_modes().
        Bimodal shards may appear in both lists.  At least one of the two
        lists must be non-empty.
    bimodal_detections : list[dict[str, Any]]
        Per-detection records with "shard", "param", "mode1", "mode2",
        "std1", "std2", "weights".
    modal_params : list[str]
        Parameters that triggered bimodal detection.
    co_occurrence : dict[str, Any]
        Cross-parameter co-occurrence info; passed through to the result.
    method : str
        Base combination method for non-modal params.  Accepted for
        interface compatibility; the current implementation does not read it.
    chunk_seed : int
        Offset added to the base RNG seed (42).

    Returns
    -------
    tuple[MCMCSamples, BimodalConsensusResult]
        (combined_samples, bimodal_result) where combined_samples has
        mixture-drawn primary samples and bimodal_result has per-mode details.

    Raises
    ------
    ValueError
        If both cluster lists are empty (mode weights would be undefined).
    """
    import numpy as np

    from homodyne.optimization.cmc.diagnostics import (
        BimodalConsensusResult,
        ModeCluster,
    )
    from homodyne.optimization.cmc.sampler import MCMCSamples

    cluster_lower, cluster_upper = cluster_assignments
    if not cluster_lower and not cluster_upper:
        # Without this guard the weight normalisation below divides by zero.
        raise ValueError(
            "cluster_assignments must contain at least one shard in the "
            "lower or upper cluster"
        )

    n_total = len(shard_samples)
    param_names = shard_samples[0].param_names
    n_chains = shard_samples[0].n_chains
    n_samples = shard_samples[0].n_samples

    rng = np.random.default_rng(42 + chunk_seed)

    # Index bimodal detections by (shard, param) for fast lookup.
    bimodal_index: dict[tuple[int, str], dict[str, Any]] = {
        (det["shard"], det["param"]): det for det in bimodal_detections
    }
    modal_set = set(modal_params)

    def _consensus_for_cluster(
        cluster_shards: list[int],
        is_lower: bool,
    ) -> dict[str, tuple[float, float]]:
        """Compute consensus (mean, std) for each param in a cluster.

        Returns dict of {param: (combined_mean, combined_std)}.
        """
        result: dict[str, tuple[float, float]] = {}
        for name in param_names:
            shard_means: list[float] = []
            shard_variances: list[float] = []

            for shard_idx in cluster_shards:
                key = (shard_idx, name)
                if name in modal_set and key in bimodal_index:
                    # Bimodal shard + modal param: use the statistics of the
                    # GMM component belonging to this cluster (components
                    # ordered by mean).
                    det = bimodal_index[key]
                    lo, hi = sorted(
                        [(det["mode1"], det["std1"]), (det["mode2"], det["std2"])],
                        key=lambda x: x[0],
                    )
                    comp_mean, comp_std = lo if is_lower else hi
                    shard_means.append(comp_mean)
                    shard_variances.append(comp_std**2)
                    continue

                # Unimodal shard or non-modal param: use full posterior.
                samples = shard_samples[shard_idx].samples[name].flatten()
                finite_mask = np.isfinite(samples)
                if np.all(finite_mask):
                    shard_means.append(float(np.mean(samples)))
                    shard_variances.append(float(np.var(samples, ddof=1)))
                    continue

                # Some draws are NaN/inf: fall back to the finite subset.
                n_bad = int(np.sum(~finite_mask))
                logger.warning(
                    f"Bimodal CMC: shard {shard_idx} has "
                    f"{n_bad}/{len(samples)} non-finite samples "
                    f"for '{name}'; using finite subset"
                )
                finite_samples = samples[finite_mask]
                if finite_samples.size == 0:
                    continue
                shard_means.append(float(np.mean(finite_samples)))
                shard_variances.append(
                    float(
                        np.var(finite_samples, ddof=1)
                        if finite_samples.size > 1
                        else 0.0
                    )
                )

            # Exclude degenerate shards (near-zero variance) before consensus.
            if len(shard_variances) >= 3:
                median_var = float(np.median(shard_variances))
                degenerate_threshold = 1e-6 * median_var if median_var > 0 else 1e-30
                valid_mask = [v >= degenerate_threshold for v in shard_variances]
                # Filter only if something survives.  If every shard is
                # degenerate, dropping them all would make the consensus
                # mean a fabricated 0.0.
                if any(valid_mask):
                    shard_means = [
                        m for m, ok in zip(shard_means, valid_mask, strict=True) if ok
                    ]
                    shard_variances = [
                        v
                        for v, ok in zip(shard_variances, valid_mask, strict=True)
                        if ok
                    ]

            if len(shard_means) < 3:
                # Too few shards: use simple mean, with the spread of the
                # shard means as the std.
                combined_mean = float(np.mean(shard_means)) if shard_means else 0.0
                combined_std = (
                    float(np.std(shard_means)) if len(shard_means) > 1 else 1e-6
                )
            else:
                # Precision-weighted consensus.
                precisions = [1.0 / max(v, 1e-10) for v in shard_variances]
                combined_precision = sum(precisions)
                # Same 1e-12 floor as the other combiners in this module
                # (P2-3) to prevent point-mass posteriors.
                combined_variance = max(1.0 / combined_precision, 1e-12)
                weighted_mean_sum = sum(
                    p * m for p, m in zip(precisions, shard_means, strict=False)
                )
                combined_mean = combined_variance * weighted_mean_sum
                combined_std = float(np.sqrt(combined_variance))

            result[name] = (float(combined_mean), float(combined_std))
        return result

    # Run per-mode consensus.
    lower_stats = _consensus_for_cluster(cluster_lower, is_lower=True)
    upper_stats = _consensus_for_cluster(cluster_upper, is_lower=False)

    # Mode weights from cluster sizes.  Bimodal shards appear in both lists,
    # so the raw weights can sum to more than 1 and are normalised.
    unique_shards = set(cluster_lower) | set(cluster_upper)
    w_lower = len(cluster_lower) / max(len(unique_shards), 1)
    w_upper = len(cluster_upper) / max(len(unique_shards), 1)
    total_w = w_lower + w_upper
    w_lower /= total_w
    w_upper /= total_w

    # Split the per-chain sample budget between the two modes.
    n_lower_samples = int(round(w_lower * n_samples))
    n_upper_samples = n_samples - n_lower_samples

    lower_samples: dict[str, np.ndarray] = {}
    upper_samples: dict[str, np.ndarray] = {}
    combined_samples: dict[str, np.ndarray] = {}

    for name in param_names:
        lo_mean, lo_std = lower_stats[name]
        up_mean, up_std = upper_stats[name]

        lower_samples[name] = rng.normal(
            loc=lo_mean,
            scale=max(lo_std, 1e-10),
            size=(n_chains, n_lower_samples),
        )
        upper_samples[name] = rng.normal(
            loc=up_mean,
            scale=max(up_std, 1e-10),
            size=(n_chains, n_upper_samples),
        )

        # Mixture-draw: concatenate and shuffle within each chain.
        mixed = np.concatenate([lower_samples[name], upper_samples[name]], axis=1)
        for c in range(n_chains):
            rng.shuffle(mixed[c])
        combined_samples[name] = mixed

    # Build ModeCluster objects with independent draws from each mode's
    # consensus Gaussian.  These are separate from the mixture-drawn primary
    # samples above and give a full per-mode sample set.
    mode_lower = ModeCluster(
        mean={n: lower_stats[n][0] for n in param_names},
        std={n: lower_stats[n][1] for n in param_names},
        weight=w_lower,
        n_shards=len(cluster_lower),
        samples={
            n: rng.normal(
                loc=lower_stats[n][0],
                scale=max(lower_stats[n][1], 1e-10),
                size=(n_chains, n_samples),
            )
            for n in param_names
        },
    )
    mode_upper = ModeCluster(
        mean={n: upper_stats[n][0] for n in param_names},
        std={n: upper_stats[n][1] for n in param_names},
        weight=w_upper,
        n_shards=len(cluster_upper),
        samples={
            n: rng.normal(
                loc=upper_stats[n][0],
                scale=max(upper_stats[n][1], 1e-10),
                size=(n_chains, n_samples),
            )
            for n in param_names
        },
    )

    bimodal_result = BimodalConsensusResult(
        modes=[mode_lower, mode_upper],
        modal_params=modal_params,
        co_occurrence=co_occurrence,
    )

    # Combine extra fields from all shards; key set from the first shard.
    combined_extra: dict[str, Any] = {}
    for key in shard_samples[0].extra_fields:
        all_extra = [
            s.extra_fields[key] for s in shard_samples if key in s.extra_fields
        ]
        if not all_extra:
            continue
        try:
            if all_extra[0].ndim == 0:
                combined_extra[key] = np.stack(all_extra, axis=0)
            else:
                combined_extra[key] = np.concatenate(all_extra, axis=0)
        except Exception as e:
            # Best-effort: incompatible fields fall back to the first value.
            logger.warning(f"Failed to combine extra field '{key}': {e}")
            combined_extra[key] = all_extra[0]

    combined = MCMCSamples(
        samples=combined_samples,
        param_names=param_names,
        n_chains=n_chains,
        n_samples=n_samples,
        extra_fields=combined_extra,
        num_shards=n_total,
    )

    return combined, bimodal_result