# Source code for homodyne.optimization.cmc.backends.base

"""Base class for CMC execution backends.

This module defines the abstract interface for CMC backends
and provides a factory function for selecting backends.
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from homodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from homodyne.optimization.cmc.config import CMCConfig
    from homodyne.optimization.cmc.data_prep import PreparedData
    from homodyne.optimization.cmc.diagnostics import BimodalConsensusResult
    from homodyne.optimization.cmc.sampler import MCMCSamples

# Module-level logger from the project's logging helper; the backend selection
# and shard-combination functions in this module report warnings and progress
# through it.
logger = get_logger(__name__)


class CMCBackend(ABC):
    """Interface implemented by every CMC execution backend.

    A backend owns two responsibilities: running MCMC sampling over the data
    shards (in parallel where the backend supports it) and merging the
    per-shard draws into a single set of posterior samples.
    """

    @abstractmethod
    def run(
        self,
        model: Callable,
        model_kwargs: dict[str, Any],
        config: CMCConfig,
        shards: list[PreparedData] | None = None,
    ) -> MCMCSamples:
        """Execute MCMC sampling, possibly spread over several shards.

        Parameters
        ----------
        model : Callable
            The NumPyro model function to sample from.
        model_kwargs : dict[str, Any]
            Keyword arguments shared by every model invocation.
        config : CMCConfig
            Consensus Monte Carlo configuration.
        shards : list[PreparedData] | None
            Data shards to process in parallel. When None, sampling runs
            single-threaded over the full data set.

        Returns
        -------
        MCMCSamples
            The merged samples from all shards.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name identifying this backend."""

    def is_available(self) -> bool:
        """Report whether this backend can be used.

        The base implementation always answers yes; subclasses may override
        it to probe for optional dependencies or resources.

        Returns
        -------
        bool
            True when the backend is usable.
        """
        available = True
        return available
def _load_multiprocessing_backend() -> CMCBackend:
    """Import and instantiate the multiprocessing backend (the universal fallback)."""
    from homodyne.optimization.cmc.backends.multiprocessing import (
        MultiprocessingBackend,
    )

    return MultiprocessingBackend()


def select_backend(
    config: CMCConfig,
) -> CMCBackend:
    """Select appropriate backend based on configuration.

    Recognised values of ``config.backend_name``:

    - ``"auto"`` / ``"multiprocessing"``: the multiprocessing backend.
    - ``"jax"``: deprecated alias; logs a warning and maps to
      multiprocessing (not pjit, because the pjit backend processes shards
      sequentially).
    - ``"pjit"`` / ``"pbs"``: the corresponding backend. If its module
      cannot be imported, a warning is logged and the multiprocessing
      backend is returned instead.

    Parameters
    ----------
    config : CMCConfig
        CMC configuration; only ``config.backend_name`` is read.

    Returns
    -------
    CMCBackend
        Selected backend instance.

    Raises
    ------
    ValueError
        If ``config.backend_name`` is not one of the recognised names above.
        Unavailable (un-importable) backends do NOT raise; they fall back to
        multiprocessing.
    """
    backend_name = config.backend_name

    if backend_name == "auto":
        # Default to multiprocessing for CPU
        backend_name = "multiprocessing"

    # Backward compatibility: allow legacy "jax" alias.
    # NOTE: Map to multiprocessing, not pjit, because pjit backend is sequential
    # (it processes shards one at a time in a for loop, not in parallel)
    if backend_name == "jax":
        logger.warning(
            "CMC backend 'jax' is deprecated; mapping to 'multiprocessing' "
            "for parallel execution. Set backend_config.name to "
            "'multiprocessing' or 'auto' instead."
        )
        backend_name = "multiprocessing"

    if backend_name == "multiprocessing":
        return _load_multiprocessing_backend()

    if backend_name == "pjit":
        try:
            from homodyne.optimization.cmc.backends.pjit import PjitBackend

            return PjitBackend()
        except ImportError:
            logger.warning(
                "pjit backend not available, falling back to multiprocessing"
            )
            return _load_multiprocessing_backend()

    if backend_name == "pbs":
        try:
            from homodyne.optimization.cmc.backends.pbs import PBSBackend

            return PBSBackend()
        except ImportError:
            logger.warning("PBS backend not available, falling back to multiprocessing")
            return _load_multiprocessing_backend()

    raise ValueError(f"Unknown backend: {backend_name}")
def _filter_hierarchical_moments(
    means_arr: Any,
    vars_arr: Any,
    name: str,
    robust: bool,
) -> tuple[Any, Any]:
    """Apply the single-pass shard filters to per-shard moment arrays.

    For ``robust`` methods: MAD-based outlier rejection on the means (kept
    only if at least 3 inliers remain) and, after degenerate filtering,
    winsorization of the variances at the 5th/95th percentiles. For every
    consensus method: exclusion of degenerate shards whose variance is below
    1e-6 * median (stuck chains), unless that would drop every shard.
    """
    import numpy as np

    if robust and len(means_arr) >= 3:
        median_mean = np.median(means_arr)
        mad = np.median(np.abs(means_arr - median_mean))
        if mad > 0:
            # Modified Z-score with the commonly used 3.5 threshold.
            inlier_mask = 0.6745 * np.abs(means_arr - median_mean) / mad < 3.5
            n_inliers = int(np.sum(inlier_mask))
            if n_inliers < 3:
                logger.warning(
                    f"Hierarchical CMC: only {n_inliers} inliers for '{name}', "
                    "falling back to standard combination"
                )
            elif n_inliers < len(means_arr):
                logger.debug(
                    f"Hierarchical CMC: {name} excluded "
                    f"{len(means_arr) - n_inliers}/{len(means_arr)} "
                    "outlier shards (MAD-based detection)"
                )
                means_arr = means_arr[inlier_mask]
                vars_arr = vars_arr[inlier_mask]

    # Degenerate-shard exclusion: shards with variance < 1e-6 * median
    # indicate stuck chains that would dominate precision weighting.
    if len(vars_arr) >= 3:
        median_var = np.median(vars_arr)
        if median_var > 0:
            degenerate_mask = vars_arr < (median_var * 1e-6)
            n_degenerate = int(np.sum(degenerate_mask))
            if 0 < n_degenerate < len(vars_arr):
                logger.warning(
                    f"Hierarchical CMC: {n_degenerate} degenerate "
                    f"shard(s) for '{name}' (var < 1e-6 * median); "
                    f"excluding"
                )
                keep = ~degenerate_mask
                means_arr = means_arr[keep]
                vars_arr = vars_arr[keep]

    if robust and len(vars_arr) >= 5:
        var_low, var_high = np.percentile(vars_arr, [5, 95])
        vars_arr = np.clip(vars_arr, var_low, var_high)

    return means_arr, vars_arr


def _combine_by_moment_accumulation(
    shard_samples: list[MCMCSamples],
    method: str,
    chunk_size: int,
) -> MCMCSamples:
    """Precision-weight per-shard moments of ALL shards, then draw once (P1-R6-01)."""
    import numpy as np

    from homodyne.optimization.cmc.sampler import MCMCSamples

    n_shards = len(shard_samples)
    logger.info(
        f"Hierarchical combination (moment-accumulation): {n_shards} shards "
        f"in chunks of {chunk_size}"
    )

    first = shard_samples[0]
    param_names = first.param_names
    n_chains = first.n_chains
    n_samples = first.n_samples
    n_chunks = (n_shards + chunk_size - 1) // chunk_size
    # "auto" resolves to robust_consensus_mc, as in _combine_shard_chunk.
    robust = method in ("robust_consensus_mc", "auto")

    # Pass 1: per-shard (mean, variance) summaries, O(K * n_params) floats.
    # Shards with any non-finite sample are skipped for that parameter.
    shard_means: dict[str, list[float]] = {name: [] for name in param_names}
    shard_vars: dict[str, list[float]] = {name: [] for name in param_names}
    n_nonfinite: dict[str, int] = dict.fromkeys(param_names, 0)
    for chunk_start in range(0, n_shards, chunk_size):
        for shard in shard_samples[chunk_start : chunk_start + chunk_size]:
            for name in param_names:
                flat = shard.samples[name].flatten()
                if not np.all(np.isfinite(flat)):
                    n_nonfinite[name] += 1
                    continue
                shard_means[name].append(float(np.mean(flat)))
                shard_vars[name].append(float(np.var(flat, ddof=1)))
        logger.debug(f"Accumulated chunk {chunk_start // chunk_size + 1}/{n_chunks}")

    # Pass 2: filter, precision-weight, and draw a single Gaussian sample set.
    rng = np.random.default_rng(42)
    combined_samples: dict[str, Any] = {}
    for name in param_names:
        if not shard_means[name]:
            # Same contract as the single-pass path: never fabricate samples.
            raise ValueError(
                f"All shards excluded for parameter '{name}' due to non-finite samples. "
                "Cannot combine posteriors - check NUTS divergence diagnostics."
            )
        if n_nonfinite[name] > 0:
            logger.warning(
                f"Hierarchical CMC: {n_nonfinite[name]} non-finite shards "
                f"excluded for '{name}'"
            )

        means_arr, vars_arr = _filter_hierarchical_moments(
            np.array(shard_means[name]), np.array(shard_vars[name]), name, robust
        )

        precisions = 1.0 / np.maximum(vars_arr, 1e-10)
        prec_sum = float(np.sum(precisions))
        # Guards keep a NaN precision sum from propagating (NaN > 0 is False).
        combined_variance = max(1.0 / prec_sum, 1e-12) if prec_sum > 0 else 1.0
        combined_mean = (
            float(np.sum(precisions * means_arr) / prec_sum) if prec_sum > 0 else 0.0
        )
        combined_samples[name] = rng.normal(
            loc=combined_mean,
            scale=np.sqrt(combined_variance),
            size=(n_chains, n_samples),
        )

    # Extra fields (diagnostics metadata) are taken from the first chunk only.
    combined_extra: dict[str, Any] = {}
    first_chunk = shard_samples[:chunk_size]
    for key in first_chunk[0].extra_fields.keys():
        all_extra = [
            s.extra_fields.get(key) for s in first_chunk if key in s.extra_fields
        ]
        if not all_extra:
            continue
        try:
            if all_extra[0].ndim == 0:
                combined_extra[key] = np.stack(all_extra, axis=0)
            else:
                combined_extra[key] = np.concatenate(all_extra, axis=0)
        except Exception as e:
            # Non-array or shape-incompatible values: keep the first one.
            logger.warning(f"Failed to combine extra field '{key}': {e}")
            combined_extra[key] = all_extra[0]

    return MCMCSamples(
        samples=combined_samples,
        param_names=param_names,
        n_chains=n_chains,
        n_samples=n_samples,
        extra_fields=combined_extra,
        num_shards=sum(getattr(s, "num_shards", 1) for s in shard_samples),
    )


def combine_shard_samples(
    shard_samples: list[MCMCSamples],
    method: str = "weighted_gaussian",
    chunk_size: int = 500,
) -> MCMCSamples:
    """Combine samples from multiple shards.

    For K <= chunk_size shards, uses a single-pass combination.

    For K > chunk_size shards with a consensus method ("robust_consensus_mc",
    "consensus_mc" or "auto"), per-shard posterior moments (mean, variance)
    are collected for every shard and combined by precision weighting in one
    step, followed by a single Gaussian draw. This avoids the
    precision-multiplication artefact that arises when recursive combination
    re-applies precision-weighting to synthetically drawn intermediate
    samples (P1-R6-01). The same shard filters as the single-pass path are
    applied: non-finite and degenerate (near-zero variance) shards are
    excluded, and for "robust_consensus_mc"/"auto" MAD outlier rejection and
    variance winsorization are applied as well.

    For K > chunk_size with the deprecated element-wise methods
    ("weighted_gaussian", "simple_average"), shards are combined chunk by
    chunk and the intermediate results are combined recursively.

    Memory: every shard result stays referenced by ``shard_samples`` for the
    whole call; chunking only bounds the recursion fan-in (deprecated
    methods) and drives progress logging. The moment path needs only
    O(K * n_params) additional floats.

    Parameters
    ----------
    shard_samples : list[MCMCSamples]
        Samples from each shard. Must be non-empty.
    method : str
        Combination method: "robust_consensus_mc" (recommended),
        "consensus_mc", "weighted_gaussian", "simple_average", or "auto".
    chunk_size : int
        Shard count above which hierarchical combination is used, and the
        chunk size for it. Must be >= 1 (>= 2 for the deprecated methods).

    Returns
    -------
    MCMCSamples
        Combined samples.

    Raises
    ------
    ValueError
        If ``shard_samples`` is empty, ``chunk_size`` is invalid, or every
        shard has non-finite samples for some parameter.
    """
    if not shard_samples:
        raise ValueError("combine_shard_samples requires at least one shard")
    if len(shard_samples) == 1:
        return shard_samples[0]
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1 (got {chunk_size})")

    if len(shard_samples) <= chunk_size:
        return _combine_shard_chunk(shard_samples, method)

    # P1-R6-01: Do NOT draw intermediate samples and re-combine them for the
    # consensus methods; accumulate moments and draw once at the end.
    if method in ("robust_consensus_mc", "consensus_mc", "auto"):
        return _combine_by_moment_accumulation(shard_samples, method, chunk_size)

    # Deprecated element-wise methods don't support moment accumulation:
    # fall back to chunked recursion.
    if chunk_size < 2:
        # With chunk_size == 1 every chunk is returned unchanged, so the
        # recursion would never shrink the input.
        raise ValueError(
            f"chunk_size must be >= 2 for chunked recursion (got {chunk_size})"
        )

    logger.info(
        f"Hierarchical combination (chunked recursion): {len(shard_samples)} shards "
        f"in chunks of {chunk_size} (method={method!r})"
    )
    n_chunks = (len(shard_samples) + chunk_size - 1) // chunk_size
    intermediate_results = []
    for chunk_start in range(0, len(shard_samples), chunk_size):
        chunk_idx = chunk_start // chunk_size
        intermediate_results.append(
            _combine_shard_chunk(
                shard_samples[chunk_start : chunk_start + chunk_size],
                method,
                chunk_seed=chunk_idx,
            )
        )
        logger.debug(f"Combined chunk {chunk_idx + 1}/{n_chunks}")

    return combine_shard_samples(intermediate_results, method, chunk_size)
def _finite_shard_moments(
    shard_samples: list[MCMCSamples],
    name: str,
    label: str,
) -> tuple[Any, Any]:
    """Return per-shard (means, variances) numpy arrays for parameter ``name``.

    P0-1: shards containing any non-finite sample for ``name`` are skipped
    (with a warning prefixed by ``label``) to prevent NaN propagation.

    Raises
    ------
    ValueError
        If every shard was skipped.
    """
    import numpy as np

    means = []
    variances = []
    for shard in shard_samples:
        samples = shard.samples[name].flatten()
        finite = np.isfinite(samples)
        if not np.all(finite):
            n_bad = int(np.sum(~finite))
            logger.warning(
                f"{label}: shard has {n_bad}/{len(samples)} non-finite "
                f"samples for '{name}'; excluding shard for this parameter"
            )
            continue
        means.append(np.mean(samples))
        # ddof=1 (unbiased): negligible for 6000+ samples, but ddof=0 would
        # underestimate the width of small adaptive-sampled shards.
        variances.append(np.var(samples, ddof=1))

    if not means:
        raise ValueError(
            f"All shards excluded for parameter '{name}' due to non-finite samples. "
            "Cannot combine posteriors - check NUTS divergence diagnostics."
        )
    return np.array(means), np.array(variances)


def _exclude_degenerate_shards(
    means: Any,
    variances: Any,
    name: str,
    label: str,
) -> tuple[Any, Any]:
    """Drop shards with variance < 1e-6 * median (chains stuck at their init).

    Such shards would get precision ~1e10 and silently dominate the combined
    posterior. Only applied with >= 3 shards and a positive median, and never
    when it would drop every shard.
    """
    import numpy as np

    if len(variances) < 3:
        return means, variances
    median_var = np.median(variances)
    if not median_var > 0:  # also skips a NaN median
        return means, variances
    degenerate_mask = variances < (median_var * 1e-6)
    n_degenerate = int(np.sum(degenerate_mask))
    if 0 < n_degenerate < len(variances):
        logger.warning(
            f"{label}: {n_degenerate} degenerate shard(s) "
            f"for '{name}' (variance < 1e-6 * median); excluding"
        )
        keep = ~degenerate_mask
        return means[keep], variances[keep]
    return means, variances


def _precision_weighted_moments(means: Any, variances: Any) -> tuple[Any, Any]:
    """Consensus MC moments: return (combined_mean, combined_variance).

    Combined precision is the sum of per-shard precisions (per-shard variance
    floored at 1e-10). P2-3: the combined variance is floored at 1e-12 to
    prevent point-mass posteriors. The mean is sum(p_s * mu_s) / sum(p_s),
    computed from the unfloored precision sum so the variance floor can never
    rescale the mean.
    """
    precisions = [1.0 / max(v, 1e-10) for v in variances]
    combined_precision = sum(precisions)
    combined_variance = max(1.0 / combined_precision, 1e-12)
    weighted_mean_sum = sum(p * m for p, m in zip(precisions, means, strict=True))
    return weighted_mean_sum / combined_precision, combined_variance


def _combine_extra_fields(shard_samples: list[MCMCSamples]) -> dict[str, Any]:
    """Merge per-shard ``extra_fields``: stack scalars, concatenate arrays on axis 0."""
    import numpy as np

    combined_extra: dict[str, Any] = {}
    for key in shard_samples[0].extra_fields.keys():
        all_extra = [
            s.extra_fields.get(key) for s in shard_samples if key in s.extra_fields
        ]
        if not all_extra:
            continue
        try:
            if all_extra[0].ndim == 0:
                # Zero-dimensional values are stacked into a 1-D array
                combined_extra[key] = np.stack(all_extra, axis=0)
            else:
                # Arrays are concatenated along the chain dimension (axis=0)
                combined_extra[key] = np.concatenate(all_extra, axis=0)
        except Exception as e:
            # Fallback for incompatible shapes or non-array values
            logger.warning(f"Failed to combine extra field '{key}': {e}")
            combined_extra[key] = all_extra[0]
    return combined_extra


def _combine_shard_chunk(
    shard_samples: list[MCMCSamples],
    method: str,
    chunk_seed: int = 0,
) -> MCMCSamples:
    """Combine a chunk of shard samples (internal helper).

    Parameters
    ----------
    shard_samples : list[MCMCSamples]
        Samples from each shard in the chunk. Must be non-empty.
    method : str
        Combination method:

        - "consensus_mc": Consensus Monte Carlo (precision-weighted moments)
        - "robust_consensus_mc" / "auto": robust CMC with MAD outlier
          rejection and variance winsorization (Jan 2026)
        - "simple_average": element-wise averaging (deprecated)
        - "weighted_gaussian" (and any unrecognised value): legacy
          element-wise weighted averaging (deprecated)
    chunk_seed : int
        Offset added to the base RNG seed 42 so that different chunks of a
        hierarchical combination draw independent samples (P2-R5-02).

    Returns
    -------
    MCMCSamples
        Combined samples for this chunk. A single-shard chunk is returned
        unchanged.

    Raises
    ------
    ValueError
        If ``shard_samples`` is empty, or (consensus methods) if every shard
        has non-finite samples for some parameter.

    Notes
    -----
    "consensus_mc" implements Consensus Monte Carlo (Scott et al., 2016):

    1. For each shard s, compute posterior mean μ_s and variance σ²_s
    2. Combined precision: 1/σ² = Σ_s (1/σ²_s)
    3. Combined mean: μ = Σ_s (μ_s / σ²_s) / Σ_s (1/σ²_s)
    4. Generate new samples from N(μ, σ²)

    "robust_consensus_mc" additionally excludes outlier shards by MAD-based
    modified Z-score (kept only if >= 3 inliers remain) and winsorizes
    shard variances at the 5th/95th percentiles. Both consensus methods
    exclude degenerate (near-zero variance) shards first.
    """
    import numpy as np

    if not shard_samples:
        raise ValueError("Cannot combine an empty list of shard samples")
    if len(shard_samples) == 1:
        return shard_samples[0]

    from homodyne.optimization.cmc.sampler import MCMCSamples

    param_names = shard_samples[0].param_names
    n_chains = shard_samples[0].n_chains
    n_samples = shard_samples[0].n_samples
    combined_samples: dict[str, Any] = {}

    if method in ("robust_consensus_mc", "auto"):
        # P2-R5-02: chunk-unique, deterministic seed.
        rng = np.random.default_rng(42 + chunk_seed)
        for name in param_names:
            means_arr, vars_arr = _finite_shard_moments(
                shard_samples, name, "Robust CMC"
            )
            n_available = len(means_arr)

            # MAD outlier detection (modified Z-score, threshold 3.5). With
            # fewer than 3 shards there is nothing to reject robustly, so
            # all are kept without a fallback warning.
            inlier_mask = np.ones(n_available, dtype=bool)
            if n_available >= 3:
                median_mean = np.median(means_arr)
                mad = np.median(np.abs(means_arr - median_mean))
                if mad > 0:
                    modified_z = 0.6745 * np.abs(means_arr - median_mean) / mad
                    candidate_mask = modified_z < 3.5
                    n_inliers = int(np.sum(candidate_mask))
                    if n_inliers >= 3:
                        inlier_mask = candidate_mask
                    else:
                        logger.warning(
                            f"Robust CMC: Only {n_inliers} inliers for {name}, "
                            "falling back to standard combination"
                        )

            filtered_means = means_arr[inlier_mask]
            filtered_vars = vars_arr[inlier_mask]
            # Counted after any fallback, so it reflects shards actually dropped.
            n_excluded = n_available - len(filtered_means)

            filtered_means, filtered_vars = _exclude_degenerate_shards(
                filtered_means, filtered_vars, name, "Robust CMC"
            )

            # Winsorize extreme variances (cap at 5th and 95th percentiles)
            if len(filtered_vars) >= 5:
                var_low, var_high = np.percentile(filtered_vars, [5, 95])
                filtered_vars = np.clip(filtered_vars, var_low, var_high)

            combined_mean, combined_variance = _precision_weighted_moments(
                filtered_means, filtered_vars
            )
            combined_samples[name] = rng.normal(
                loc=combined_mean,
                scale=np.sqrt(combined_variance),
                size=(n_chains, n_samples),
            )

            if n_excluded > 0:
                logger.debug(
                    f"Robust CMC: {name} excluded {n_excluded}/{n_available} "
                    f"outlier shards (MAD-based detection)"
                )

    elif method == "consensus_mc":
        # P2-R5-02: chunk-unique, deterministic seed.
        rng = np.random.default_rng(42 + chunk_seed)
        for name in param_names:
            means_arr, vars_arr = _finite_shard_moments(
                shard_samples, name, "Consensus MC"
            )
            means_arr, vars_arr = _exclude_degenerate_shards(
                means_arr, vars_arr, name, "Consensus MC"
            )
            combined_mean, combined_variance = _precision_weighted_moments(
                means_arr, vars_arr
            )
            combined_samples[name] = rng.normal(
                loc=combined_mean,
                scale=np.sqrt(combined_variance),
                size=(n_chains, n_samples),
            )

    elif method == "simple_average":
        warnings.warn(
            "combination_method='simple_average' is deprecated since v2.12.0 "
            "and will be removed in v3.0. Use 'consensus_mc' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Legacy: Simple element-wise average across shards (deprecated)
        for name in param_names:
            all_shard_samples = [s.samples[name] for s in shard_samples]
            # NaN-safe: NUTS can produce NaN samples from divergent transitions
            combined_samples[name] = np.nanmean(all_shard_samples, axis=0)

    else:  # weighted_gaussian (legacy default; also any unrecognised method)
        warnings.warn(
            "combination_method='weighted_gaussian' is deprecated since v2.12.0 "
            "and will be removed in v3.0. Use 'consensus_mc' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Legacy: Element-wise weighted averaging (deprecated).
        # WARNING: mathematically incorrect; kept for backward compatibility.
        for name in param_names:
            all_shard_samples = [s.samples[name] for s in shard_samples]
            variances = [float(np.nanvar(s)) for s in all_shard_samples]
            precisions = [1.0 / max(v, 1e-10) for v in variances]
            total_precision = sum(precisions)
            weights = [p / total_precision for p in precisions]
            combined_samples[name] = sum(
                w * s for w, s in zip(weights, all_shard_samples, strict=False)
            )

    # Track total shards for correct divergence rate calculation
    total_shards = sum(getattr(s, "num_shards", 1) for s in shard_samples)

    return MCMCSamples(
        samples=combined_samples,
        param_names=param_names,
        n_chains=n_chains,
        n_samples=n_samples,
        extra_fields=_combine_extra_fields(shard_samples),
        num_shards=total_shards,
    )
def combine_shard_samples_bimodal(
    shard_samples: list[MCMCSamples],
    cluster_assignments: tuple[list[int], list[int]],
    bimodal_detections: list[dict[str, Any]],
    modal_params: list[str],
    co_occurrence: dict[str, Any],
    method: str = "consensus_mc",
    chunk_seed: int = 0,
) -> tuple[MCMCSamples, BimodalConsensusResult]:
    """Combine shard samples using mode-aware consensus.

    For bimodal shards, uses per-component GMM statistics instead of
    full-posterior statistics to avoid density-trough corruption.

    Parameters
    ----------
    shard_samples : list[MCMCSamples]
        All successful shard samples.
    cluster_assignments : tuple[list[int], list[int]]
        (lower_cluster_shards, upper_cluster_shards) from cluster_shard_modes().
        Bimodal shards may appear in both lists. At least one list must be
        non-empty.
    bimodal_detections : list[dict[str, Any]]
        Per-detection records with "shard", "param", "mode1", "mode2",
        "std1", "std2", "weights".
    modal_params : list[str]
        Parameters that triggered bimodal detection.
    co_occurrence : dict[str, Any]
        Cross-parameter co-occurrence info.
    method : str
        Base combination method for non-modal params. Accepted for interface
        compatibility; the current implementation does not consult it and
        uses precision-weighted consensus per mode for every parameter.
    chunk_seed : int
        Offset added to the base RNG seed 42.

    Returns
    -------
    tuple[MCMCSamples, BimodalConsensusResult]
        (combined_samples, bimodal_result) where combined_samples has
        mixture-drawn primary samples and bimodal_result has per-mode details.

    Raises
    ------
    ValueError
        If both clusters in ``cluster_assignments`` are empty.
    """
    cluster_lower, cluster_upper = cluster_assignments
    if not cluster_lower and not cluster_upper:
        # Mode weights are proportional to cluster sizes; with no shards in
        # either cluster they are undefined.
        raise ValueError(
            "combine_shard_samples_bimodal requires at least one shard in "
            "the lower or upper cluster"
        )

    import numpy as np

    from homodyne.optimization.cmc.diagnostics import (
        BimodalConsensusResult,
        ModeCluster,
    )
    from homodyne.optimization.cmc.sampler import MCMCSamples

    n_total = len(shard_samples)
    param_names = shard_samples[0].param_names
    n_chains = shard_samples[0].n_chains
    n_samples = shard_samples[0].n_samples

    rng = np.random.default_rng(42 + chunk_seed)

    # Index bimodal detections by (shard, param) for fast lookup
    bimodal_index: dict[tuple[int, str], dict[str, Any]] = {}
    for det in bimodal_detections:
        bimodal_index[(det["shard"], det["param"])] = det

    modal_set = set(modal_params)

    def _consensus_for_cluster(
        cluster_shards: list[int],
        is_lower: bool,
    ) -> dict[str, tuple[float, float]]:
        """Compute consensus (mean, std) for each param in a cluster.

        Returns dict of {param: (combined_mean, combined_std)}.
        """
        result: dict[str, tuple[float, float]] = {}
        for name in param_names:
            shard_means: list[float] = []
            shard_variances: list[float] = []

            for shard_idx in cluster_shards:
                key = (shard_idx, name)
                if name in modal_set and key in bimodal_index:
                    # Bimodal shard + modal param: use the component of the
                    # 2-component fit that belongs to this cluster.
                    det = bimodal_index[key]
                    m1, m2 = det["mode1"], det["mode2"]
                    s1, s2 = det["std1"], det["std2"]
                    lo, hi = sorted([(m1, s1), (m2, s2)], key=lambda x: x[0])
                    chosen = lo if is_lower else hi
                    shard_means.append(chosen[0])
                    shard_variances.append(chosen[1] ** 2)
                else:
                    # Unimodal shard or non-modal param: use full posterior
                    samples = shard_samples[shard_idx].samples[name].flatten()
                    # NaN-safe: NUTS can produce NaN samples from divergent transitions
                    if not np.all(np.isfinite(samples)):
                        n_bad = int(np.sum(~np.isfinite(samples)))
                        logger.warning(
                            f"Bimodal CMC: shard {shard_idx} has "
                            f"{n_bad}/{len(samples)} non-finite samples "
                            f"for '{name}'; using finite subset"
                        )
                        finite_samples = samples[np.isfinite(samples)]
                        if finite_samples.size == 0:
                            continue
                        shard_means.append(float(np.mean(finite_samples)))
                        shard_variances.append(
                            float(
                                np.var(finite_samples, ddof=1)
                                if finite_samples.size > 1
                                else 0.0
                            )
                        )
                    else:
                        shard_means.append(float(np.mean(samples)))
                        shard_variances.append(float(np.var(samples, ddof=1)))

            # Exclude degenerate shards (near-zero variance) before consensus,
            # but never all of them: dropping every shard would leave no
            # estimate and silently yield a mean of 0.0.
            if len(shard_variances) >= 3:
                median_var = float(np.median(shard_variances))
                degenerate_threshold = 1e-6 * median_var if median_var > 0 else 1e-30
                valid_mask = [v >= degenerate_threshold for v in shard_variances]
                if any(valid_mask):
                    shard_means = [
                        m for m, ok in zip(shard_means, valid_mask, strict=True) if ok
                    ]
                    shard_variances = [
                        v
                        for v, ok in zip(shard_variances, valid_mask, strict=True)
                        if ok
                    ]

            if not shard_means and cluster_shards:
                logger.warning(
                    f"Bimodal CMC: no usable shard estimates for '{name}' in the "
                    f"{'lower' if is_lower else 'upper'} cluster; using mean 0.0"
                )

            if len(shard_means) < 3:
                # Too few shards: use simple mean
                combined_mean = float(np.mean(shard_means)) if shard_means else 0.0
                combined_std = (
                    float(np.std(shard_means)) if len(shard_means) > 1 else 1e-6
                )
            else:
                # Precision-weighted consensus
                precisions = [1.0 / max(v, 1e-10) for v in shard_variances]
                combined_precision = sum(precisions)
                combined_variance = 1.0 / combined_precision
                weighted_mean_sum = sum(
                    p * m for p, m in zip(precisions, shard_means, strict=False)
                )
                combined_mean = combined_variance * weighted_mean_sum
                combined_std = float(np.sqrt(combined_variance))

            result[name] = (float(combined_mean), float(combined_std))
        return result

    # Run per-mode consensus
    lower_stats = _consensus_for_cluster(cluster_lower, is_lower=True)
    upper_stats = _consensus_for_cluster(cluster_upper, is_lower=False)

    # Mode weights are proportional to the cluster list lengths. Bimodal
    # shards are listed in both clusters and so count toward both weights.
    n_lower_listed = len(cluster_lower)
    n_upper_listed = len(cluster_upper)
    total_listed = n_lower_listed + n_upper_listed  # > 0, checked above
    w_lower = n_lower_listed / total_listed
    w_upper = n_upper_listed / total_listed

    # Generate per-mode samples
    n_lower_samples = int(round(w_lower * n_samples))
    n_upper_samples = n_samples - n_lower_samples

    lower_samples: dict[str, np.ndarray] = {}
    upper_samples: dict[str, np.ndarray] = {}
    combined_samples: dict[str, np.ndarray] = {}

    for name in param_names:
        lo_mean, lo_std = lower_stats[name]
        up_mean, up_std = upper_stats[name]

        lower_samples[name] = rng.normal(
            loc=lo_mean,
            scale=max(lo_std, 1e-10),
            size=(n_chains, n_lower_samples),
        )
        upper_samples[name] = rng.normal(
            loc=up_mean,
            scale=max(up_std, 1e-10),
            size=(n_chains, n_upper_samples),
        )

        # Mixture-draw: concatenate and shuffle within each chain
        mixed = np.concatenate([lower_samples[name], upper_samples[name]], axis=1)
        for c in range(n_chains):
            rng.shuffle(mixed[c])
        combined_samples[name] = mixed

    # Build ModeCluster objects with independent draws from each mode's consensus
    # Gaussian. These are separate from the mixture-drawn primary samples above;
    # they provide full per-mode sample sets for downstream analysis.
    mode_lower = ModeCluster(
        mean={n: lower_stats[n][0] for n in param_names},
        std={n: lower_stats[n][1] for n in param_names},
        weight=w_lower,
        n_shards=len(cluster_lower),
        samples={
            n: rng.normal(
                loc=lower_stats[n][0],
                scale=max(lower_stats[n][1], 1e-10),
                size=(n_chains, n_samples),
            )
            for n in param_names
        },
    )
    mode_upper = ModeCluster(
        mean={n: upper_stats[n][0] for n in param_names},
        std={n: upper_stats[n][1] for n in param_names},
        weight=w_upper,
        n_shards=len(cluster_upper),
        samples={
            n: rng.normal(
                loc=upper_stats[n][0],
                scale=max(upper_stats[n][1], 1e-10),
                size=(n_chains, n_samples),
            )
            for n in param_names
        },
    )

    bimodal_result = BimodalConsensusResult(
        modes=[mode_lower, mode_upper],
        modal_params=modal_params,
        co_occurrence=co_occurrence,
    )

    # Combine extra fields from all shards
    combined_extra: dict[str, Any] = {}
    for key in shard_samples[0].extra_fields.keys():
        all_extra = [
            s.extra_fields.get(key) for s in shard_samples if key in s.extra_fields
        ]
        if all_extra:
            try:
                if all_extra[0].ndim == 0:
                    combined_extra[key] = np.stack(all_extra, axis=0)
                else:
                    combined_extra[key] = np.concatenate(all_extra, axis=0)
            except Exception as e:
                logger.warning(f"Failed to combine extra field '{key}': {e}")
                combined_extra[key] = all_extra[0]

    combined = MCMCSamples(
        samples=combined_samples,
        param_names=param_names,
        n_chains=n_chains,
        n_samples=n_samples,
        extra_fields=combined_extra,
        num_shards=n_total,
    )

    return combined, bimodal_result