Homodyne CMC (Consensus Monte Carlo) Fitting Architecture¶
Complete documentation of the CMC (Consensus Monte Carlo) fitting system in homodyne.
Version: 2.23.2
Last Updated: May 2026
High-Level Architecture¶
┌─────────────────────────────────────────────────────────────────────────────────┐
│                                USER ENTRY POINTS                                │
│                                                                                 │
│  CLI: homodyne --method cmc             API: fit_mcmc_jax(data, config)         │
│            │                                           │                        │
│            ▼                                           │                        │
│  ┌────────────────────────┐                            │                        │
│  │ AUTOMATIC NLSQ WARMUP  │ (v2.20.0)                  │                        │
│  │ fit_nlsq_jax() first   │◄───────────────────────────┤ (optional nlsq_result) │
│  │ unless --no-nlsq-...   │                            │                        │
│  └────────────────────────┘                            │                        │
│            │                                           │                        │
│            └─────────────────────┬─────────────────────┘                        │
│                                  │                                              │
│                                  ▼                                              │
│             fit_mcmc_jax(data, config, nlsq_result=...)                         │
│                              (core.py)                                          │
└─────────────────────────────────────────────────────────────────────────────────┘
│
════════════════════════════════════╪══════════════════════════════════════════════
▼
┌─────────────────────────────────────────────────────────────────────────────────┐
│ 1. DATA PREPARATION │
│ (data_prep.py) │
│ │
│ Validation → Diagonal Filtering → Noise Estimation │
└─────────────────────────────────────────────────────────────────────────────────┘
│
════════════════════════════════════╪══════════════════════════════════════════════
▼
┌─────────────────────────────────────────────────────────────────────────────────┐
│ 2. SHARDING DECISION │
│ │
│ Auto shard size → Stratified or Random sharding │
│ │
│ Single Shard ◄──────────────► Multiple Shards │
└─────────────────────────────────────────────────────────────────────────────────┘
│
════════════════════════════════════╪══════════════════════════════════════════════
▼
┌─────────────────────────────────────────────────────────────────────────────────┐
│ 3. BACKEND SELECTION │
│ │
│ MultiprocessingBackend │ PjitBackend │ PBSBackend │
└─────────────────────────────────────────────────────────────────────────────────┘
│
════════════════════════════════════╪══════════════════════════════════════════════
▼
┌─────────────────────────────────────────────────────────────────────────────────┐
│ 4. PER-SHARD NUTS SAMPLING │
│ (sampler.py) │
│ │
│ Z-Space Transform → Preflight → NUTS (dense_mass) → Sample Extraction │
└─────────────────────────────────────────────────────────────────────────────────┘
│
════════════════════════════════════╪══════════════════════════════════════════════
▼
┌─────────────────────────────────────────────────────────────────────────────────┐
│ 5. SAMPLE COMBINATION │
│ (backends/base.py) │
│ │
│ Bimodal Detection → Mode-Aware or Standard Consensus MC → Combined Posterior │
└─────────────────────────────────────────────────────────────────────────────────┘
│
════════════════════════════════════╪══════════════════════════════════════════════
▼
┌─────────────────────────────────────────────────────────────────────────────────┐
│ 6. RESULT & DIAGNOSTICS │
│ (results.py) │
│ │
│ R-hat │ ESS │ Divergences → CMCResult + ArviZ │
└─────────────────────────────────────────────────────────────────────────────────┘
1. Entry Point & Orchestration¶
File: core.py
fit_mcmc_jax() Signature¶
def fit_mcmc_jax(
    data: np.ndarray,  # Pooled C2 correlation (n_total,)
    t1: np.ndarray,  # Time coordinates
    t2: np.ndarray,  # Time coordinates
    phi: np.ndarray,  # Phi angles
    q: float,  # Physics parameter
    L: float,  # Physics parameter
    analysis_mode: str,  # "static" or "laminar_flow"
    method: str = "mcmc",
    cmc_config: dict | None = None,
    initial_values: dict | None = None,
    parameter_space: ParameterSpace | None = None,
    dt: float | None = None,
    output_dir: Path | str | None = None,
    progress_bar: bool = True,
    run_id: str | None = None,
    nlsq_result: dict | None = None,  # NLSQ warm-start (reduces divergences)
    **kwargs,  # Forward compatibility
) -> CMCResult: ...
Orchestration Flow¶
┌───────────────────────────────────────────────────────────────────────────┐
│ fit_mcmc_jax() Orchestration │
│ │
│ 1. Normalize analysis mode, generate run_id │
│ 2. Validate and prepare pooled data │
│ └─ prepare_mcmc_data() → filter diagonals, estimate noise │
│ 3. Determine shard size via _resolve_max_points_per_shard() │
│ 4. Construct time_grid with proper dt spacing │
│ 5. Shard data if needed: │
│ ├─ shard_data_stratified() for multiple phi angles │
│ └─ shard_data_random() for single phi angle │
│ 6. Select backend: multiprocessing (default), pjit, or pbs │
│ 7. Execute: │
│ ├─ If shards: backend.run() → parallel NUTS → combine │
│ └─ If single: run_nuts_sampling() directly │
│ 8. Create CMCResult with diagnostics │
└───────────────────────────────────────────────────────────────────────────┘
2. Data Preparation & Sharding¶
File: data_prep.py
Data Validation (prepare_mcmc_data)¶
┌───────────────────────────────────────────────────────────────────────────┐
│ prepare_mcmc_data() │
│ │
│ • Ensure pooled arrays are 1D with matching lengths │
│ • Check for NaN/Inf values │
│ • Filter diagonal points (t1 == t2) - autocorrelation artifacts (v2.14.2)│
│ • Extract unique phi angles and create index mapping │
│ • Estimate noise using robust MAD (Median Absolute Deviation) │
│ • Return PreparedData dataclass │
└───────────────────────────────────────────────────────────────────────────┘
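The diagonal filtering and noise estimation steps can be sketched in plain Python. This is an illustration under assumptions, not the code in data_prep.py: the helper names filter_diagonal and mad_noise_scale are hypothetical, and applying the MAD directly to the data values is an assumption, since the list above does not say which quantity the estimator sees.

```python
from statistics import median


def filter_diagonal(data, t1, t2):
    """Drop points with t1 == t2 (hypothetical helper)."""
    keep = [i for i in range(len(data)) if t1[i] != t2[i]]
    return (
        [data[i] for i in keep],
        [t1[i] for i in keep],
        [t2[i] for i in keep],
    )


def mad_noise_scale(values):
    """Robust noise scale; 1.4826 * MAD matches the std for Gaussian noise."""
    med = median(values)
    mad = median(abs(v - med) for v in values)
    return 1.4826 * mad


# One diagonal point (t1 == t2 == 0.0) and one more (0.2, 0.2) are removed.
c2, a, b = filter_diagonal([1.0, 2.0, 3.0], [0.0, 0.1, 0.2], [0.0, 0.2, 0.2])
sigma = mad_noise_scale([1.0, 2.0, 3.0, 4.0, 100.0])  # outlier barely matters
```

The MAD is used instead of the standard deviation because a single extreme value (100.0 in the example) leaves the estimate essentially unchanged.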
3. Auto Shard Size Selection¶
Function: _resolve_max_points_per_shard() in core.py
Critical Design Principles (v2.20.0)¶
Minimum Shard Size Enforcement:
laminar_flow: 3,000 points minimum (reparameterization fixes bimodal posteriors)
static: 5,000 points minimum
Dynamic max_shards Scaling:
┌───────────────────────────────────────────────────────────────────────────┐
│ max_shards by Dataset Size (v2.20.0) │
│ │
│ Dataset Size │ max_shards │ Rationale │
│ ────────────────┼────────────┼────────────────────────────────────────── │
│ < 10M points │ 2,000 │ Standard parallel workload │
│ 10M - 100M │ 10,000 │ Balanced shard count │
│ 100M - 1B │ 50,000 │ High parallelism for large datasets │
│ 1B+ │ 100,000 │ Extreme scale support │
└───────────────────────────────────────────────────────────────────────────┘
Angle-Aware Scaling¶
The shard size is adjusted based on the number of phi angles to ensure each shard has sufficient data per angle:
┌───────────────────────────────────────────────────────────────────────────┐
│ Angle Factor by n_phi │
│ │
│ n_phi │ angle_factor │ Effect on Shard Size │
│ ───────┼──────────────┼───────────────────────────────────────────────── │
│ ≤ 3 │ 0.6 │ 40% reduction (ensures coverage per angle) │
│ 4-5 │ 0.7 │ 30% reduction │
│ 6-10 │ 0.85 │ 15% reduction │
│ > 10 │ 1.0 │ No reduction (many angles spread data) │
└───────────────────────────────────────────────────────────────────────────┘
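The two lookup tables above translate directly into code. The following is a sketch only: the function names are hypothetical, and the treatment of values that fall exactly on a boundary (for example exactly 10M points) is an assumption, since the tables do not specify it.

```python
def max_shards_for(n_points: int) -> int:
    """max_shards by dataset size, following the v2.20.0 table."""
    if n_points < 10_000_000:
        return 2_000
    if n_points < 100_000_000:
        return 10_000
    if n_points < 1_000_000_000:
        return 50_000
    return 100_000


def angle_factor(n_phi: int) -> float:
    """Multiplier applied to the base shard size for a given n_phi."""
    if n_phi <= 3:
        return 0.6
    if n_phi <= 5:
        return 0.7
    if n_phi <= 10:
        return 0.85
    return 1.0


# An 8K base size with 3 angles shrinks to 4.8K points per shard.
adjusted = round(8_000 * angle_factor(3))
```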
Decision Logic by Mode¶
┌───────────────────────────────────────────────────────────────────────────┐
│ LAMINAR FLOW MODE (7 parameters, complex gradients) │
│ │
│ Dataset Size │ Base Size │ After n_phi≤3 │ Est. Shards │ Per-Shard │
│ ────────────────┼───────────┼───────────────┼─────────────┼───────────── │
│ < 2M points │ 8K │ 4.8K │ ~400 │ ~1-2 min │
│ 2M - 50M │ 5K │ 3K │ 600-16K │ ~1 min │
│ 50M - 100M │ 5K │ 3K │ 10K-20K │ ~1 min │
│ 100M - 1B │ 8K │ 4.8K │ 20K-50K │ <1 min │
│ 1B+ │ 10K │ 6K │ 100K+ │ <1 min │
│ │
│ MINIMUM ENFORCED: 3,000 points per shard (reparameterization fixes bimodal posteriors) │
└───────────────────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────────────────┐
│ STATIC MODE (3 parameters, simple gradients) │
│ │
│ Dataset Size │ max_points_per_shard │ Est. Shards │
│ ────────────────┼──────────────────────┼────────────── │
│ < 50M points │ 10K │ ~5K │
│ 50M - 100M │ 15K │ ~3K-7K │
│ 100M+ │ 20K │ ~5K+ │
│ │
│ MINIMUM ENFORCED: 5,000 points per shard │
└───────────────────────────────────────────────────────────────────────────┘
Memory Capping¶
┌───────────────────────────────────────────────────────────────────────────┐
│ Memory Capping Logic │
│ │
│ Per-shard result size: ~300KB │
│ (13 params × 2 chains × 1500 samples × 8 bytes) │
│ │
│ Dynamic max_shards by dataset size (see table above) │
│ │
│ If shard count would exceed max_shards: │
│ • Increases shard size (no subsampling, all data used) │
│ • For laminar_flow: caps adjusted shard size at 50K max │
│ • Enforces minimum shard size (3K laminar_flow, 5K static) │
│ │
│ Platform Scaling (based on dynamic max_shards): │
│ ├─ 3M dataset: ~600-1000 shards (manageable on personal systems) │
│ ├─ 100M dataset: ~10K-20K shards (requires cluster) │
│ └─ 1B dataset: ~100K shards (extreme scale, HPC required) │
└───────────────────────────────────────────────────────────────────────────┘
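The capping rule can be sketched as a single function. This is a hedged illustration, not the body of _resolve_max_points_per_shard(): the name cap_shard_size is hypothetical, and the order in which the 50K ceiling and the minimum size are applied is an assumption.

```python
import math


def cap_shard_size(n_points, shard_size, max_shards, mode):
    """Grow the shard size until the shard count fits under max_shards."""
    min_size = 3_000 if mode == "laminar_flow" else 5_000
    size = max(shard_size, min_size)
    if math.ceil(n_points / size) > max_shards:
        # Fewer, larger shards; no data is dropped.
        size = math.ceil(n_points / max_shards)
        if mode == "laminar_flow":
            size = min(size, 50_000)  # documented laminar_flow ceiling
    return max(size, min_size)


# 9M points at 3K per shard would need 3,000 shards; with max_shards=2,000
# the shard size grows to 4,500 points.
size = cap_shard_size(9_000_000, 3_000, 2_000, "laminar_flow")
```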
4. Time Grid Construction¶
Critical Fix (December 2025)
┌───────────────────────────────────────────────────────────────────────────┐
│ Time Grid Construction │
│ │
│ Problem: Previously used np.unique(t1, t2) which gave incorrect grid │
│ density when data is subsampled │
│ │
│ Solution: Construct time_grid explicitly with config dt spacing: │
│ │
│ dt_used = dt if dt is not None else inferred_dt │
│ t_max = max(t1_pooled.max(), t2_pooled.max()) │
│ n_time_points = int(round(t_max / dt_used)) + 1 │
│ time_grid = np.linspace(0.0, t_max, n_time_points) │
│ │
│ This ensures physics integration (trapezoidal cumsum) uses correct │
│ grid density matching NLSQ │
└───────────────────────────────────────────────────────────────────────────┘
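The effect of the fix is easiest to see on subsampled data. The sketch below uses plain Python lists in place of NumPy arrays; build_time_grid is a hypothetical helper name, and the list comprehension stands in for np.linspace(0.0, t_max, n_time_points).

```python
def build_time_grid(t1, t2, dt=None, inferred_dt=None):
    """Uniform grid from 0 to t_max with spacing dt (hypothetical helper)."""
    dt_used = dt if dt is not None else inferred_dt
    t_max = max(max(t1), max(t2))
    n_time_points = int(round(t_max / dt_used)) + 1
    if n_time_points == 1:
        return [0.0]
    # Same spacing as np.linspace(0.0, t_max, n_time_points)
    step = t_max / (n_time_points - 1)
    return [i * step for i in range(n_time_points)]


# Subsampled data: only four distinct time values survive.
t1 = [0.0, 0.3]
t2 = [0.5, 1.0]
grid = build_time_grid(t1, t2, dt=0.1)
n_unique = len(set(t1) | set(t2))
```

Here the explicit grid has 11 points at spacing 0.1, while the unique observed times would give only 4 irregularly spaced points, which is the density mismatch described above.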
5. Physics Model¶
File: model.py
Five Model Variants (v2.22.2)¶
| Model | Purpose | Per-Angle Mode | Params (laminar_flow, 23 angles) |
|-------|---------|----------------|----------------------------------|
| xpcs_model_scaled() | Gradient-balanced z-space sampling | individual | 54 (46 per-angle + 7 physical + 1 σ) |
| xpcs_model_constant() | Fixed per-angle scaling (not sampled) | constant | 8 (7 physical + 1 σ) |
| xpcs_model_averaged() | Sampled averaged contrast/offset | auto | 10 (2 averaged + 7 physical + 1 σ) |
| xpcs_model_constant_averaged() | Fixed averaged scaling (NLSQ parity) | constant_averaged | 8 (7 physical + 1 σ) |
| xpcs_model_reparameterized() | Reparameterized sampling space | auto + reparam | 10 (2 averaged + 7 physical + 1 σ) |
Per-Angle Mode Selection (v2.22.2)¶
┌───────────────────────────────────────────────────────────────────────────┐
│ Per-Angle Mode Decision (get_effective_per_angle_mode) │
│ │
│ Priority: nlsq_per_angle_mode > explicit config > auto-selection │
│ │
│ 1. NLSQ warm-start present with per_angle_mode? │
│ ├─ YES + both sides "auto": → "constant_averaged" │
│ │ • Fixes scaling for stability (fewer sampled params) │
│ │ • Uses xpcs_model_constant_averaged() │
│ │ │
│ └─ YES + explicit mode: → match NLSQ mode │
│ • Ensures parameterization parity │
│ │
│ 2. per_angle_mode = "auto" (default)? │
│ ├─ n_phi >= threshold (3): → "auto" │
│ │ • Uses xpcs_model_averaged() │
│ │ • SAMPLES single averaged contrast + offset │
│ │ • 10 params (2 averaged + 7 physical + 1 sigma) │
│ │ • If use_reparameterization=True: │
│ │ Uses xpcs_model_reparameterized() instead │
│ │ │
│ └─ n_phi < threshold (3): → "individual" │
│ • Uses xpcs_model_scaled() │
│ • Samples per-angle contrast/offset │
│ • 8 + 2×n_phi sampled params │
│ │
│ 3. Explicit mode: │
│ ├─ "constant": xpcs_model_constant() │
│ │ • Fixed per-angle values (not averaged, not sampled) │
│ │ • 8 params (7 physical + 1 sigma) │
│ │ │
│ ├─ "constant_averaged": xpcs_model_constant_averaged() │
│ │ • Fixed averaged values (NLSQ parity, not sampled) │
│ │ • 8 params (7 physical + 1 sigma) │
│ │ │
│ └─ "individual": xpcs_model_scaled() │
│ • Sample per-angle contrast/offset │
│ • 8 + 2×n_phi sampled params │
└───────────────────────────────────────────────────────────────────────────┘
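The decision tree above can be condensed into a small function. This is one reading of the diagram, not the body of get_effective_per_angle_mode(): the function and argument names are hypothetical, and treating every NLSQ case other than "both sides auto" as "match the NLSQ mode" is an assumption.

```python
def effective_per_angle_mode(config_mode, n_phi, nlsq_mode=None, threshold=3):
    """Resolve the per-angle mode with NLSQ > explicit config > auto priority."""
    # 1. An NLSQ warm-start that carries a per_angle_mode takes priority.
    if nlsq_mode is not None:
        if nlsq_mode == "auto" and config_mode == "auto":
            return "constant_averaged"
        return nlsq_mode
    # 2. Auto-selection based on the number of phi angles.
    if config_mode == "auto":
        return "auto" if n_phi >= threshold else "individual"
    # 3. Explicit mode is used as given.
    return config_mode


def n_sampled_params(mode, n_phi, n_physical=7):
    """Sampled parameter count (physical + sigma + any sampled scaling)."""
    if mode == "individual":
        return 2 * n_phi + n_physical + 1
    if mode == "auto":
        return 2 + n_physical + 1
    return n_physical + 1  # constant / constant_averaged


mode = effective_per_angle_mode("auto", n_phi=23, nlsq_mode="auto")
```

For laminar_flow with 23 angles, n_sampled_params reproduces the counts quoted in this section: 54 for individual, 10 for auto, and 8 for the two fixed modes.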
Key Distinction: Five Model Modes
| Mode | Scaling Behavior | Values Sampled? | Params (laminar_flow) |
|------|------------------|-----------------|-----------------------|
| auto (n_phi ≥ 3) | Single averaged contrast/offset | Yes (sampled) | 10 |
| auto + reparam | Single averaged + log-space transforms | Yes (sampled) | 10 |
| constant | Per-angle from quantile estimation | No (fixed per-angle) | 8 |
| constant_averaged | Averaged from quantile estimation | No (fixed averaged) | 8 |
| individual | Per-angle contrast/offset | Yes (sampled per-angle) | 8 + 2×n_phi |
xpcs_model_scaled() Structure¶
┌───────────────────────────────────────────────────────────────────────────┐
│ Sampling Order (CRITICAL for init_to_value to work) │
│ │
│ 1. PER-ANGLE CONTRAST PARAMETERS (FIRST) │
│ for i in range(n_phi): │
│ contrast_i = sample_scaled_parameter(f"contrast_{i}", scaling) │
│ │
│ 2. PER-ANGLE OFFSET PARAMETERS (SECOND) │
│ for i in range(n_phi): │
│ offset_i = sample_scaled_parameter(f"offset_{i}", scaling) │
│ │
│ 3. PHYSICAL PARAMETERS (THIRD) │
│ Static: D0, alpha, D_offset │
│ Laminar flow: + gamma_dot_t0, beta, gamma_dot_t_offset, phi0 │
│ │
│ 4. NOISE PARAMETER (FOURTH) │
│ sigma ~ HalfNormal(scale=noise_scale × 3.0) │
└───────────────────────────────────────────────────────────────────────────┘
xpcs_model_constant() Structure (v2.18.0)¶
The constant model takes fixed contrast/offset arrays instead of sampling them:
┌───────────────────────────────────────────────────────────────────────────┐
│ xpcs_model_constant() - Fixed Per-Angle Scaling │
│ │
│ REQUIRED INPUTS (not sampled): │
│ fixed_contrast: jnp.ndarray (n_phi,) - Pre-computed contrast values │
│ fixed_offset: jnp.ndarray (n_phi,) - Pre-computed offset values │
│ │
│ SAMPLING ORDER (8 parameters total): │
│ │
│ 1. PHYSICAL PARAMETERS (FIRST) │
│ Static: D0, alpha, D_offset │
│ Laminar flow: + gamma_dot_t0, beta, gamma_dot_t_offset, phi0 │
│ │
│ 2. NOISE PARAMETER (SECOND) │
│ sigma ~ HalfNormal(scale=noise_scale × 3.0) │
│ │
│ NO per-angle contrast/offset sampling - they are FIXED from quantiles │
└───────────────────────────────────────────────────────────────────────────┘
Quantile-Based Scaling Estimation (core.py):
import numpy as np
from homodyne.core.scaling_utils import estimate_per_angle_scaling

# Estimate per-angle contrast/offset from raw C2 data
estimates = estimate_per_angle_scaling(
    c2_data, t1, t2, phi_indices, n_phi,
    contrast_bounds, offset_bounds
)
# Returns: {"contrast_0": 0.4, "offset_0": 0.95, ...}

# For AUTO mode: average the estimates
if config.per_angle_mode == "auto":
    contrast_avg = np.mean([estimates[f"contrast_{i}"] for i in range(n_phi)])
    offset_avg = np.mean([estimates[f"offset_{i}"] for i in range(n_phi)])
    fixed_contrast = np.full(n_phi, contrast_avg)  # Same for all angles
    fixed_offset = np.full(n_phi, offset_avg)
else:
    # CONSTANT mode: use per-angle estimates directly
    fixed_contrast = np.array([estimates[f"contrast_{i}"] for i in range(n_phi)])
    fixed_offset = np.array([estimates[f"offset_{i}"] for i in range(n_phi)])
xpcs_model_averaged() Structure (v2.22.2)¶
The averaged model samples a single averaged contrast and offset, then broadcasts to
all angles. This is the default model for auto mode (n_phi >= 3).
┌───────────────────────────────────────────────────────────────────────────┐
│ xpcs_model_averaged() - Sampled Averaged Scaling (auto mode) │
│ │
│ SAMPLING ORDER (10 parameters total for laminar_flow): │
│ │
│ 1. AVERAGED SCALING PARAMETERS (FIRST, sampled) │
│ contrast_avg ~ TruncatedNormal (single value, broadcast to n_phi) │
│ offset_avg ~ TruncatedNormal (single value, broadcast to n_phi) │
│ │
│ 2. PHYSICAL PARAMETERS (SECOND) │
│ Static: D0, alpha, D_offset │
│ Laminar flow: + gamma_dot_t0, beta, gamma_dot_t_offset, phi0 │
│ │
│ 3. NOISE PARAMETER (THIRD) │
│ sigma ~ HalfNormal(scale=noise_scale × 3.0) │
│ │
│ KEY: 10 params vs 54 for individual mode (laminar_flow, 23 angles) │
│ Prevents per-angle parameter absorption degeneracy │
└───────────────────────────────────────────────────────────────────────────┘
xpcs_model_constant_averaged() Structure (v2.22.2)¶
Uses fixed averaged contrast/offset (not sampled). Provides exact NLSQ parity when warm-start is available. Selected automatically when both CMC and NLSQ use “auto” mode with NLSQ warm-start present.
┌───────────────────────────────────────────────────────────────────────────┐
│ xpcs_model_constant_averaged() - Fixed Averaged Scaling (NLSQ parity) │
│ │
│ REQUIRED INPUTS (not sampled): │
│ fixed_contrast: jnp.ndarray (n_phi,) → averaged to single value │
│ fixed_offset: jnp.ndarray (n_phi,) → averaged to single value │
│ │
│ SAMPLING ORDER (8 parameters total): │
│ │
│ 1. PHYSICAL PARAMETERS (FIRST) │
│ Static: D0, alpha, D_offset │
│ Laminar flow: + gamma_dot_t0, beta, gamma_dot_t_offset, phi0 │
│ │
│ 2. NOISE PARAMETER (SECOND) │
│ sigma ~ HalfNormal(scale=noise_scale × 3.0) │
│ │
│ CRITICAL: Fixes the parameter shift issue where xpcs_model_averaged() │
│ samples contrast/offset, introducing extra uncertainty that biases │
│ physical parameters. Uses FIXED averaged values for NLSQ parity. │
└───────────────────────────────────────────────────────────────────────────┘
xpcs_model_reparameterized() Structure (v2.23.0)¶
Transforms correlated parameters to orthogonal sampling space for better NUTS exploration:
┌───────────────────────────────────────────────────────────────────────────┐
│ xpcs_model_reparameterized() - Reference-Time Reparameterized Sampling │
│ │
│ REPARAMETERIZATION TRANSFORMS: │
│ D0, alpha → log_D_ref, alpha │
│ where D_ref = D0 * t_ref^alpha (decorrelates) │
│ D_offset → D_offset_ratio = D_offset / D_ref │
│ TruncatedNormal(low=-1+ε): supports negative D_offset│
│ (jammed/arrested systems); ratio ≤ -1 non-physical │
│ gamma_dot_t0 → log_gamma_ref │
│ where gamma_ref = gamma_dot_t0 * t_ref^beta │
│ │
│ SAMPLING ORDER (10 parameters total for laminar_flow): │
│ │
│ 1. AVERAGED SCALING (sampled, as in xpcs_model_averaged) │
│ 2. REPARAMETERIZED PHYSICAL PARAMETERS │
│ log_D_ref, alpha, D_offset_ratio │
│ + log_gamma_ref, beta, gamma_dot_t_offset, phi0 (laminar_flow) │
│ 3. DETERMINISTIC TRANSFORMS (in trace for output) │
│ D0, D_offset, gamma_dot_t0 computed from reparameterized values │
│ 4. NOISE PARAMETER │
│ sigma ~ HalfNormal(scale=noise_scale × 1.5) │
│ │
│ EXTRA INPUT: t_ref (reference time, from compute_t_ref(dt, t_max)) │
│ EXTRA INPUT: reparam_config (ReparamConfig, optional) │
└───────────────────────────────────────────────────────────────────────────┘
Physics Computation¶
# Compute g1 using exact same physics as NLSQ
g1_all_phi = compute_g1_total(params, t1, t2, phi_unique, q, L, dt, time_grid)
# g1_all_phi shape: (n_phi, n_points)
# Map per-point g1 using phi indices
g1_per_point = g1_all_phi[phi_indices, point_idx]
# Apply per-angle scaling
c2_theory = contrast[phi_indices] * g1_per_point**2 + offset[phi_indices]
Likelihood¶
sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))
numpyro.sample("obs", dist.Normal(c2_theory, sigma), obs=data)
6. Gradient Balancing (Z-Space)¶
File: scaling.py
The Problem¶
Parameters span vastly different magnitudes:
D0: ~10⁴ (diffusion)
alpha: ~10⁰ (exponent)
gamma_dot_t0: ~10⁻³ (shear)
contrast: ~10⁻¹ (optical)
This causes 10⁶:1 gradient imbalance → 0% NUTS acceptance rate
The Solution: Non-Centered Reparameterization¶
┌───────────────────────────────────────────────────────────────────────────┐
│ Z-Space Transformation │
│ │
│ Sample in z-space: z ~ Normal(0, 1) │
│ Transform to original: P = center + scale × z │
│ Apply smooth bounds: P = smooth_bound(P, low, high) using tanh │
│ │
│ Smooth bounding (avoids hard clipping artifacts): │
│ smooth_bound(x) = mid + (half × tanh((x - mid) / half)) │
│ │
│ ParameterScaling dataclass: │
│ name: str # Parameter name │
│ center: float # Midpoint of bounds or prior mean │
│ scale: float # (high - low) / 4 or prior std │
│ low: float # Lower bound │
│ high: float # Upper bound │
│ │
│ Key methods: │
│ to_normalized(value): Original space → z-space (for initialization) │
│ to_original(z_value): z-space → original space (in model) │
└───────────────────────────────────────────────────────────────────────────┘
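A minimal sketch of the z-space transform and the tanh-based smooth bound, using only the standard library. The field names follow the ParameterScaling description above, but the method bodies are assumptions; in particular, implementing to_normalized as the exact inverse of to_original (via atanh) is a guess at the behavior, and the D0 bounds in the example are made up for illustration.

```python
import math
from dataclasses import dataclass


@dataclass
class ParameterScaling:
    name: str
    center: float
    scale: float
    low: float
    high: float

    def _mid_half(self):
        return 0.5 * (self.low + self.high), 0.5 * (self.high - self.low)

    def to_original(self, z):
        """z-space -> original space, squashed smoothly into (low, high)."""
        mid, half = self._mid_half()
        raw = self.center + self.scale * z
        return mid + half * math.tanh((raw - mid) / half)

    def to_normalized(self, value):
        """Original space -> z-space; value must lie strictly inside the bounds."""
        mid, half = self._mid_half()
        raw = mid + half * math.atanh((value - mid) / half)
        return (raw - self.center) / self.scale


# Illustrative bounds: center is the midpoint, scale is (high - low) / 4.
d0 = ParameterScaling("D0", center=50_050.0, scale=24_975.0, low=100.0, high=100_000.0)
value = d0.to_original(0.7)
z_back = d0.to_normalized(value)
```

Because every parameter is sampled as z ~ Normal(0, 1), the sampler sees unit-scale coordinates regardless of whether the physical value is of order 10⁴ or 10⁻³.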
7. NUTS Sampling¶
File: sampler.py
run_nuts_sampling() Workflow¶
┌─────────────────────────────────────────────────────────────────────────────────┐
│ NUTS SAMPLING WORKFLOW │
│ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 1. PARAMETER ORDER RESOLUTION │ │
│ │ get_param_names_in_order(n_phi, analysis_mode) │ │
│ │ → [contrast_0, ..., offset_0, ..., D0, alpha, ...] │ │
│ │ + "sigma" as final parameter │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 2. MCMC-SAFE D0 CHECK (_compute_mcmc_safe_d0) │ │
│ │ • Detects if initial D0 causes g1 → 0 everywhere (vanishing gradients) │ │
│ │ • Computes scaled D0 that produces g1 ≈ 0.5 at typical time lag │ │
│ │ • Ensures gradients are alive for NUTS exploration │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 3. BUILD FULL INITIAL VALUES │ │
│ │ build_init_values_dict() from priors.py │ │
│ │ • Data-driven contrast/offset estimation │ │
│ │ • Combine with config initial_values │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 4. GRADIENT BALANCING TRANSFORMATION │ │
│ │ scalings = compute_scaling_factors(parameter_space, n_phi, mode) │ │
│ │ z_space_init = transform_initial_values_to_z(full_init, scalings) │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 5. PREFLIGHT VALIDATION │ │
│ │ _preflight_log_density(model, params, ...) │ │
│ │ • Catches non-finite log density before expensive sampling │ │
│ │ • Validates model can compute gradients at init point │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 6. NUTS KERNEL CREATION │ │
│ │ │ │
│ │ kernel = NUTS( │ │
│ │ model, │ │
│ │ init_strategy=init_to_value(values=z_space_init), │ │
│ │ target_accept_prob=0.85, # For laminar_flow mode, automatically │ │
│ │ # elevated to 0.9 if below 0.9 │ │
│ │ dense_mass=True # CRITICAL: Learn cross-correlations │ │
│ │ ) │ │
│ │ │ │
│ │ Why dense_mass=True: │ │
│ │ A diagonal mass matrix adapts only per-dimension scales, so it │ │
│ │ cannot capture correlations between parameters (e.g. D0 and alpha). │ │
│ │ A dense matrix learns the full covariance structure during warmup. │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 7. MCMC EXECUTION │ │
│ │ │ │
│ │ mcmc = MCMC( │ │
│ │ kernel, │ │
│ │ num_warmup=500, │ │
│ │ num_samples=1500, │ │
│ │ num_chains=4 │ │
│ │ ) │ │
│ │ mcmc.run( │ │
│ │ rng_key, │ │
│ │ extra_fields=("accept_prob", "diverging", "num_steps"), │ │
│ │ **model_kwargs │ │
│ │ ) │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 8. COMPUTATION BLOCKING (Critical for proper timing) │ │
│ │ │ │
│ │ jax.block_until_ready(last_state) │ │
│ │ │ │
│ │ Without this, JAX's asynchronous dispatch defers the wait to │ │
│ │ device_get(), causing misleading timing measurements │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────────────────────────┐ │
│ │ 9. SAMPLE EXTRACTION │ │
│ │ │ │
│ │ • Extract in group_by_chain format: (n_chains, n_samples) │ │
│ │ • Compute per-shard diagnostics (accept_prob, divergences, step_size) │ │
│ │ • Return (MCMCSamples, SamplingStats) │ │
│ └───────────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────────┘
Key Dataclasses¶
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class SamplingStats:
    warmup_time: float
    sampling_time: float
    total_time: float
    num_divergent: int
    accept_prob: float
    step_size: float
    step_size_min: float | None
    step_size_max: float | None
    inverse_mass_matrix_summary: str | None
    tree_depth: float


@dataclass
class MCMCSamples:
    samples: dict[str, np.ndarray]  # {name: (n_chains, n_samples)}
    param_names: list[str]
    n_chains: int
    n_samples: int
    extra_fields: dict[str, Any]  # diverging, accept_prob, etc.
    num_shards: int = 1  # For correct divergence rate in CMC
    bimodal_consensus: Any = None  # BimodalConsensusResult (Feb 2026)
8. Backend Execution¶
File: backends/
Backend Selection¶
def select_backend(config: CMCConfig) -> CMCBackend:
    backend_name = config.backend_name
    if backend_name == "auto":
        backend_name = "multiprocessing"  # Default for CPU
    if backend_name == "multiprocessing":
        return MultiprocessingBackend()
    elif backend_name == "pjit":
        return PjitBackend()  # Sequential JAX execution
    elif backend_name == "pbs":
        return PBSBackend()  # HPC cluster execution
MultiprocessingBackend (Primary)¶
File: backends/multiprocessing.py
┌───────────────────────────────────────────────────────────────────────────┐
│ MultiprocessingBackend Architecture │
│ │
│ 1. Pre-generate all shard PRNG keys in single JAX call │
│ (_generate_shard_keys - amortizes JAX compilation) │
│ │
│ 2. Compute LPT schedule via _compute_lpt_schedule() │
│ Noise-weighted Longest Processing Time ordering: │
│ cost = n_points * (1 + normalized_noise) │
│ Largest/noisiest shards dispatched first → minimize tail latency │
│ │
│ 3. Create worker pool with adaptive thread limiting │
│ physical_cores = psutil.cpu_count(logical=False) │
│ threads_per_worker = physical_cores // n_workers │
│ (see also: backends/worker_pool.py for pool management) │
│ │
│ 4. Load shard data arrays via shared memory │
│ _load_shared_shard_data() / _load_shared_array() │
│ Eliminates per-process serialization overhead │
│ │
│ 5. Submit shards to queue-based workers │
│ │
│ 6. Monitor with progress bar and timeout handling │
│ • Adaptive polling interval (reduces CPU spinning) │
│ • Event.wait() with timeout (efficient heartbeat) │
│ │
│ 7. Combine results via combine_shard_samples() │
└───────────────────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────────────────┐
│ Worker Function (_run_shard_worker) │
│ │
│ • Set XLA environment variables for thread safety │
│ • Import JAX, deserialize parameter_space │
│ • Call run_nuts_sampling() on shard │
│ • Return serialized MCMCSamples │
└───────────────────────────────────────────────────────────────────────────┘
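The LPT ordering and thread-limiting steps of the MultiprocessingBackend can be sketched as follows. The function names are hypothetical stand-ins for _compute_lpt_schedule() and the pool setup; normalizing the noise by its maximum and clamping the thread count to at least 1 are assumptions not stated above.

```python
def lpt_order(n_points, noise):
    """Shard indices sorted by descending cost = n_points * (1 + normalized_noise)."""
    max_noise = max(noise) or 1.0  # assumed normalization: divide by the max
    cost = [n * (1.0 + s / max_noise) for n, s in zip(n_points, noise)]
    return sorted(range(len(cost)), key=lambda i: cost[i], reverse=True)


def threads_per_worker(physical_cores, n_workers):
    """physical_cores // n_workers, clamped so no worker gets zero threads."""
    return max(1, physical_cores // n_workers)


# Shard 2 is both large and noisy, so it is dispatched first.
order = lpt_order(n_points=[1_000, 5_000, 5_000], noise=[0.1, 0.1, 0.2])
```

Dispatching the most expensive shards first means the cheap ones fill the gaps at the end of the run, which is what keeps tail latency low.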
WorkerPool Manager¶
File: backends/worker_pool.py (354 lines)
Manages process pool lifecycle, health monitoring, and graceful shutdown
separate from the dispatch logic in multiprocessing.py.
┌───────────────────────────────────────────────────────────────────────────┐
│ Worker Pool Responsibilities │
│ │
│ • Process pool creation with spawn context │
│ • Worker health monitoring (detect hung/dead workers) │
│ • Graceful shutdown with configurable timeout │
│ • Restart of failed workers up to max_restarts limit │
│ • Thread count enforcement (OMP_NUM_THREADS per worker) │
└───────────────────────────────────────────────────────────────────────────┘
Chain Parallelism (chain_method)¶
File: sampler.py (run_nuts_sampling)
```yaml
# Config field in per_shard_mcmc:
chain_method: "parallel" # "parallel" (default) or "sequential"
```
┌───────────────────────────────────────────────────────────────────────────┐
│ Chain Method Selection │
│ │
│ "parallel" (default): │
│ NumPyro runs num_chains NUTS chains concurrently via JAX pmap, │
│ one chain per XLA device. The XLA device count dynamically matches │
│ num_chains via the HOMODYNE_CMC_NUM_CHAINS env var. │
│ Fastest on multi-core CPUs where per-chain NUTS is the bottleneck. │
│ │
│ "sequential": │
│ Chains run one at a time. Lower peak memory. │
│ Preferred when per-shard data is very small. │
│ │
│ Auto-fallback (in run_nuts_sampling): │
│ if chain_method == "parallel" and shard_size < 500: │
│ effective_chain_method = "sequential" │
│ Rationale: For tiny shards, NUTS runs in microseconds per step; │
│ the multi-device dispatch overhead for parallel chains exceeds the │
│ compute benefit. │
│ │
│ Config validation (config.py): │
│ chain_method must be in ["parallel", "sequential"] │
└───────────────────────────────────────────────────────────────────────────┘
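The validation and auto-fallback rules fit in one small function. This is a sketch under assumptions: the function name and the min_parallel_size argument are hypothetical, and only the 500-point threshold and the two allowed values come from the box above.

```python
def effective_chain_method(chain_method, shard_size, min_parallel_size=500):
    """Validate chain_method and fall back to sequential for tiny shards."""
    if chain_method not in ("parallel", "sequential"):
        raise ValueError(f"chain_method must be 'parallel' or 'sequential', got {chain_method!r}")
    if chain_method == "parallel" and shard_size < min_parallel_size:
        return "sequential"
    return chain_method


small = effective_chain_method("parallel", shard_size=300)
large = effective_chain_method("parallel", shard_size=5_000)
```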
9. Sample Combination¶
File: backends/base.py
Hierarchical Combination¶
┌───────────────────────────────────────────────────────────────────────────┐
│ Hierarchical Combination for Large Shard Counts │
│ │
│ If K > 500 shards: │
│ 1. Combine in chunks of 500 │
│ 2. Recursively combine chunks │
│ 3. Reduces peak memory from O(K) to O(500) per combination step │
│ 4. Garbage collection between chunks │
└───────────────────────────────────────────────────────────────────────────┘
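The chunked recursion can be illustrated on per-shard moments. This is a simplified sketch: the real combination works on sample arrays, whereas here each shard is reduced to a (mean, variance) pair so the example stays short. Both function names are hypothetical. Because precisions add, combining chunk by chunk gives the same answer as combining all shards at once.

```python
def combine_moments(parts):
    """Precision-weighted combination of (mean, variance) pairs."""
    precision = sum(1.0 / var for _, var in parts)
    mean = sum(mu / var for mu, var in parts) / precision
    return mean, 1.0 / precision


def combine_hierarchical(parts, combine, chunk_size=500):
    """Combine in chunks of chunk_size until one chunk remains."""
    while len(parts) > chunk_size:
        parts = [
            combine(parts[i:i + chunk_size])
            for i in range(0, len(parts), chunk_size)
        ]
    return combine(parts)


# 1,200 synthetic shard summaries: three chunks, then one final combine.
shards = [(float(i % 7), 1.0 + (i % 3)) for i in range(1200)]
flat = combine_moments(shards)
hier = combine_hierarchical(shards, combine_moments)
```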
Combination Methods¶
╔═══════════════════════════════════════════════════════════════════════════╗
║ CONSENSUS_MC (Default, v2.12.0+) ║
║ Implements Scott et al. (2016) correctly ║
╠═══════════════════════════════════════════════════════════════════════════╣
║ ║
║ For each parameter: ║
║ 1. Compute per-shard mean μ_s and variance σ²_s ║
║ 2. Combined precision: 1/σ²_combined = Σ_s (1/σ²_s) ║
║ 3. Combined mean: μ_combined = σ²_combined × Σ_s (μ_s / σ²_s) ║
║ 4. Generate new samples: N(μ_combined, σ²_combined) ║
║ ║
║ Precision-weighted combination of posterior moments ║
║ ║
║ LIMITATION: Biases toward low-variance shards when heterogeneity exists ║
╚═══════════════════════════════════════════════════════════════════════════╝
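The four steps of the precision-weighted combination, sketched for a single parameter with the standard library. The function name consensus_combine is hypothetical, and the population variance (the np.var default) is assumed for σ²_s; a shard with zero variance would need special handling that is omitted here.

```python
import random
from statistics import fmean, pvariance


def consensus_combine(shard_samples, n_out=1000, seed=0):
    """Combine per-shard draws of one parameter via precision weighting."""
    mus = [fmean(s) for s in shard_samples]          # step 1: per-shard mean
    variances = [pvariance(s) for s in shard_samples]  # step 1: per-shard variance
    precision = sum(1.0 / v for v in variances)      # step 2: precisions add
    var_c = 1.0 / precision
    mu_c = var_c * sum(m / v for m, v in zip(mus, variances))  # step 3
    rng = random.Random(seed)
    draws = [rng.gauss(mu_c, var_c ** 0.5) for _ in range(n_out)]  # step 4
    return mu_c, var_c, draws


# Shard A: mean 1, variance 1. Shard B: mean 4, variance 4.
mu_c, var_c, draws = consensus_combine([[0.0, 2.0], [2.0, 6.0]])
```

In this example the combined mean is 1.6 rather than the unweighted 2.5: the tighter shard pulls the result toward itself, which is the low-variance bias noted in the limitation above.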
╔═══════════════════════════════════════════════════════════════════════════╗
║ MODE-AWARE CONSENSUS MC (v2.22.0+) ║
║ Handles bimodal per-shard posteriors correctly ║
╠═══════════════════════════════════════════════════════════════════════════╣
║ ║
║ Problem: Standard consensus_mc assumes per-shard posteriors are ║
║ approximately Gaussian. When shards have bimodal posteriors ║
║ (e.g., D0 ~19K and ~32K), np.mean falls between modes in the ║
║ density trough, and np.var is inflated by w1*w2*(mu1-mu2)^2. ║
║ ║
║ Algorithm (combine_shard_samples_bimodal): ║
║ 1. Detect bimodal shards via per-shard GMM (check_shard_bimodality) ║
║ 2. Summarize cross-shard bimodality patterns ║
║ 3. Jointly cluster shards into two mode populations ║
║ (range-normalized feature vectors, seeded from cross-shard means) ║
║ 4. For each mode cluster, run precision-weighted consensus: ║
║ • Bimodal shards: use per-component GMM stats (mu, sigma^2) ║
║ • Unimodal shards: use full posterior stats ║
║ 5. Generate mixture-drawn output samples ║
║ 6. Attach BimodalConsensusResult to MCMCSamples.bimodal_consensus ║
║ ║
║ Auto-triggered: When cross-shard bimodal fraction > 5% for any param ║
║ Fallback: <3 shards in a cluster → simple mean instead of consensus ║
║ ║
║ Output: MCMCSamples with mixture samples + BimodalConsensusResult: ║
║ modes: list[ModeCluster] # Per-mode consensus stats + samples ║
║ modal_params: list[str] # Parameters that triggered detection ║
║ co_occurrence: dict # D0-alpha correlation info ║
╚═══════════════════════════════════════════════════════════════════════════╝
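The variance inflation term w1*w2*(mu1-mu2)^2 quoted in the problem statement follows from the moments of a two-component mixture. The short check below uses the D0 modes mentioned above (about 19K and 32K); the equal weights and the within-mode variance are made-up illustration values.

```python
from statistics import fmean, pvariance


def mixture_moments(w1, mu1, var1, mu2, var2):
    """Mean and variance of a two-component mixture with weights w1, 1 - w1."""
    w2 = 1.0 - w1
    mean = w1 * mu1 + w2 * mu2
    var = w1 * var1 + w2 * var2 + w1 * w2 * (mu1 - mu2) ** 2
    return mean, var


mean, var = mixture_moments(0.5, 19_000.0, 250_000.0, 32_000.0, 250_000.0)

# Cross-check against a tiny population with the same two modes.
population = [18_500.0, 19_500.0, 31_500.0, 32_500.0]
empirical_mean = fmean(population)
empirical_var = pvariance(population)
```

The mean lands at 25,500, between the modes where there is almost no posterior mass, and the variance is dominated by the between-mode term (42.25M out of 42.5M), so the shard would enter a standard consensus step with a misleading center and a very low precision.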
╔═══════════════════════════════════════════════════════════════════════════╗
║ ROBUST_CONSENSUS_MC (v2.20.0+) ║
║ Outlier-resistant combination for heterogeneous shards ║
╠═══════════════════════════════════════════════════════════════════════════╣
║ ║
║ Algorithm: ║
║ 1. Compute per-shard means and variances ║
║ 2. Detect outlier shards using MAD (Median Absolute Deviation): ║
║ median_mean = median(shard_means) ║
║ mad = median(|shard_means - median_mean|) ║
║ outlier if |mean - median_mean| > 3 × 1.4826 × mad ║
║ 3. Exclude outlier shards from combination ║
║ 4. Apply standard consensus_mc to retained shards ║
║ ║
║ Use when: High per-shard heterogeneity detected (CV > 0.5) ║
║ Auto-enabled: When heterogeneity_abort=False but CV > threshold ║
╚═══════════════════════════════════════════════════════════════════════════╝
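The MAD outlier rule in step 2 can be written as a short function. This is a sketch for one parameter; the function name is hypothetical, and the handling of the degenerate case mad == 0 (every shard that differs from the median is flagged) is a consequence of the formula rather than a documented behavior.

```python
from statistics import median


def mad_outlier_mask(shard_means, k=3.0):
    """True for shards whose mean is more than k * 1.4826 * MAD from the median."""
    median_mean = median(shard_means)
    mad = median(abs(m - median_mean) for m in shard_means)
    threshold = k * 1.4826 * mad
    return [abs(m - median_mean) > threshold for m in shard_means]


shard_means = [1.0, 1.1, 0.9, 1.05, 5.0]
mask = mad_outlier_mask(shard_means)
retained = [m for m, is_outlier in zip(shard_means, mask) if not is_outlier]
```

Only the shard with mean 5.0 is excluded; the retained shards would then go through the standard precision-weighted combination.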
┌───────────────────────────────────────────────────────────────────────────┐
│ WEIGHTED_GAUSSIAN (Deprecated) │
│ Element-wise weighted averaging (mathematically incorrect) │
└───────────────────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────────────────┐
│ SIMPLE_AVERAGE (Deprecated) │
│ Element-wise mean across shards │
└───────────────────────────────────────────────────────────────────────────┘
10. Result Creation & Diagnostics¶
File: results.py
CMCResult Structure¶
@dataclass
class CMCResult:
    # Core results
    parameters: np.ndarray  # Posterior means
    uncertainties: np.ndarray  # Posterior stds
    param_names: list[str]
    # MCMC-specific
    samples: dict[str, np.ndarray]  # {name: (n_chains, n_samples)}
    convergence_status: str  # "converged", "divergences", "not_converged"
    r_hat: dict[str, float]
    ess_bulk: dict[str, float]
    ess_tail: dict[str, float]
    divergences: int
    # ArviZ
    inference_data: az.InferenceData
    # Timing
    execution_time: float
    warmup_time: float
    # Config
    n_chains: int = 4
    n_samples: int = 2000
    n_warmup: int = 500
    analysis_mode: str = "static"
    per_angle_mode: str = "auto"  # Effective per-angle mode used
    # CMC-specific
    num_shards: int = 1  # For correct divergence rate
    # Quality & diagnostics
    covariance: np.ndarray | None = None
    chi_squared: float | None = None
    reduced_chi_squared: float | None = None
    device_info: dict | None = None
    recovery_actions: list[str] | None = None
    quality_flag: str | None = None  # "good", "warning", "poor"
    # Per-angle parameter statistics
    mean_params: ParameterStats = field(default_factory=lambda: ParameterStats([], []))
    std_params: ParameterStats = field(default_factory=lambda: ParameterStats([], []))
    mean_contrast: np.ndarray | None = None
    std_contrast: np.ndarray | None = None
    mean_offset: np.ndarray | None = None
    std_offset: np.ndarray | None = None
CMCResult.from_mcmc_samples() Workflow¶
┌───────────────────────────────────────────────────────────────────────────┐
│ Result Creation │
│ │
│ 1. Compute R-hat from samples: per-parameter convergence statistic │
│ 2. Compute ESS (bulk & tail): effective sample size │
│ 3. Check convergence thresholds: │
│ ├─ R-hat < 1.1 │
│ ├─ ESS > 400 (min_ess default) │
│ └─ Divergence rate < 10% (max_divergence_rate default) │
│ 4. Aggregate legacy stats (contrast, offset means/stds) │
│ 5. Create ArviZ InferenceData for plotting │
│ 6. Return CMCResult │
│ │
│ CRITICAL FIX (Dec 2025): Excludes _z (z-space) parameters from │
│ legacy stats. Scaled model samples contrast_0_z but registers │
│ contrast_0 as deterministic. │
└───────────────────────────────────────────────────────────────────────────┘
Additional Diagnostics Functions¶
File: diagnostics.py
Beyond R-hat/ESS, diagnostics.py provides post-run analysis utilities:
# Posterior quality vs NLSQ baseline
compute_nlsq_comparison_metrics(
    cmc_result: CMCResult,
    nlsq_result: dict,
    tolerance_sigma: float = 3.0,
) -> dict
# Returns: per-parameter {diff_pct, z_score, status} table
# Flags parameters exceeding tolerance (default 3σ)

# Posterior uncertainty contraction (CMC vs prior)
compute_posterior_contraction(
    cmc_result: CMCResult,
    prior_std: dict[str, float],
) -> dict[str, float]
# Values near 1.0 = strong data constraint; near 0.0 = prior-dominated

# Precision analysis for parameter absorption detection
compute_precision_analysis(
    cmc_result: CMCResult,
    nlsq_result: dict | None = None,
) -> dict

# End-of-run structured log summary
log_analysis_summary(
    cmc_result: CMCResult,
    logger: logging.Logger,
    nlsq_result: dict | None = None,
) -> None
# Emits: shard count, divergence rate, R-hat table, NLSQ comparison (if provided)
Convergence Status Determination¶
├─ "converged": All R-hat < 1.1 AND All ESS > 400 AND divergence rate < 10%
├─ "divergences": Divergence rate >= 10%
└─ "not_converged": R-hat >= 1.1 OR ESS < 400
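The three rules above can be written as a small classifier. This is a sketch under stated assumptions, not the logic in results.py: the function name `convergence_status` is hypothetical, the divergence check is assumed to take precedence when several conditions fail at once, and the behavior at exactly ESS = 400 follows the "not_converged" row because the rules above leave that boundary unspecified.

```python
def convergence_status(r_hat, ess, n_divergent, n_total,
                       max_r_hat=1.1, min_ess=400.0, max_divergence_rate=0.10):
    """Classify a run from per-parameter R-hat/ESS dicts and divergence counts.

    n_total is the total number of post-warmup draws across all shards and chains.
    """
    divergence_rate = n_divergent / n_total if n_total else 0.0
    # Assumption: a high divergence rate is reported first.
    if divergence_rate >= max_divergence_rate:
        return "divergences"
    bad_r_hat = any(value >= max_r_hat for value in r_hat.values())
    low_ess = any(value < min_ess for value in ess.values())
    if bad_r_hat or low_ess:
        return "not_converged"
    return "converged"


status = convergence_status(
    r_hat={"D0": 1.01, "alpha": 1.02},
    ess={"D0": 950.0, "alpha": 1200.0},
    n_divergent=12,
    n_total=6000,
)
print(status)
```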
Complete Data Flow¶
fit_mcmc_jax() [core.py]
│
├─> prepare_mcmc_data() [data_prep.py]
│ ├─ Validate pooled arrays
│ ├─ Filter diagonal points (v2.14.2+)
│ └─ Extract phi info & estimate noise
│
├─> _resolve_max_points_per_shard()
│ ├─ Auto-size based on analysis_mode & dataset size
│ └─ Cap shard count to limit memory
│
├─> Construct time_grid (with proper dt spacing)
│
├─> Sharding decision (CMC needed?)
│ │
│ ├─ YES (large dataset, multiple angles):
│ │ ├─> shard_data_stratified() or shard_data_random()
│ │ └─> Create list[PreparedData] for each shard
│ │
│ └─ NO (small dataset):
│ └─> shards = None (single-shard mode)
│
├─> select_backend(config)
│ └─ MultiprocessingBackend() [or pjit/pbs]
│
├─> If shards:
│ │
│ └─> backend.run()
│ ├─ For each shard (parallel via multiprocessing):
│ │ ├─ Worker: _run_shard_worker()
│ │ │ ├─ Set thread limits (XLA_FLAGS, OMP_NUM_THREADS)
│ │ │ ├─ Reconstruct JAX key from tuple
│ │ │ ├─ Call run_nuts_sampling() → MCMCSamples
│ │ │ └─ Queue result
│ │ └─ Progress bar & timeout monitoring
│ │
│ ├─ Bimodal detection (check_shard_bimodality per shard)
│ │ ├─ If significant bimodality detected (>5% fraction):
│ │ │ ├─ summarize_cross_shard_bimodality()
│ │ │ ├─ cluster_shard_modes() → (lower_cluster, upper_cluster)
│ │ │ └─ combine_shard_samples_bimodal()
│ │ │ ├─ Per-mode consensus using component GMM stats
│ │ │ ├─ Mixture-drawn output samples
│ │ │ └─ Attach BimodalConsensusResult to MCMCSamples
│ │ └─ Else: standard combination
│ │
│ └─ Combine results:
│ ├─ Hierarchical combination for K > 500 shards
│ └─ combine_shard_samples(shards, method="robust_consensus_mc")
│ ├─ Per-param: combine means & variances
│ ├─ Generate new samples from combined Gaussian
│ └─ Return combined MCMCSamples
│
└─> Single-shard mode:
└─> run_nuts_sampling(xpcs_model_scaled, ...)
├─ MCMC-safe D0 check
├─ Build init values (data-driven contrast/offset)
├─ Transform to z-space
├─ Preflight validation
├─ NUTS sampling (dense_mass=True)
├─ Extract samples & stats
└─ Return (MCMCSamples, SamplingStats)
│
▼
CMCResult.from_mcmc_samples()
├─ Compute R-hat, ESS
├─ Check convergence thresholds
├─ Create ArviZ InferenceData
└─ Return CMCResult
Quick Reference Tables¶
Auto Shard Size Selection (v2.20.0)¶
Laminar Flow Mode (with n_phi ≤ 3, angle_factor = 0.6)¶
| Dataset Size | Base Size | After Scaling | max_shards | Est. Shards |
|---|---|---|---|---|
| < 2M | 8K | 4.8K | 2,000 | ~400 |
| 2M - 50M | 5K | 3K | 2,000 | 600-16K |
| 50M - 100M | 5K | 3K | 10,000 | 10K-20K |
| 100M - 1B | 8K | 4.8K | 50,000 | 20K-50K |
| 1B+ | 10K | 6K | 100,000 | 100K+ |
Minimum shard size: 3,000 points (reparameterization fixes bimodal posteriors)
Static Mode¶
| Dataset Size | max_points_per_shard | Est. Shards |
|---|---|---|
| < 50M | 10K | ~5K |
| 50M - 100M | 15K | ~3K-7K |
| 100M+ | 20K | ~5K+ |
Minimum shard size: 5,000 points
Dynamic max_shards by Dataset Size¶
| Dataset Size | max_shards | Rationale |
|---|---|---|
| < 10M points | 2,000 | Standard parallel workload |
| 10M - 100M | 10,000 | Balanced shard count |
| 100M - 1B | 50,000 | High parallelism for large datasets |
| 1B+ | 100,000 | Extreme scale support |
Mode-Specific Parameters¶
Individual per-angle mode (23 angles):
| Mode | Physical Params | Per-Angle Params | Total |
|---|---|---|---|
| static | 3: D₀, alpha, D_offset | 46: contrast + offset | 49 + sigma |
| laminar_flow | 7: + gamma_dot_t0, beta, gamma_dot_t_offset, phi0 | 46: contrast + offset | 53 + sigma |
Auto mode (averaged scaling, default for n_phi >= 3):
| Mode | Physical Params | Averaged Scaling | Total |
|---|---|---|---|
| static | 3 | 2: avg_contrast + avg_offset | 5 + sigma |
| laminar_flow | 7 | 2: avg_contrast + avg_offset | 9 + sigma |
Constant / constant_averaged modes (fixed scaling, not sampled):
| Mode | Physical Params | Scaling | Total |
|---|---|---|---|
| static | 3 | Fixed (0 sampled) | 3 + sigma |
| laminar_flow | 7 | Fixed (0 sampled) | 7 + sigma |
CMC Configuration Defaults (v2.22.2)¶
| Parameter | Default | Description |
|---|---|---|
| min_points_for_cmc | 100,000 | Auto-enable threshold |
| sharding_strategy | "random" | "stratified" or "random" |
| backend_name | "auto" | → "multiprocessing" |
| num_warmup | 500 | NUTS warmup iterations (pre-adaptive) |
| num_samples | 1500 | NUTS sampling iterations (pre-adaptive) |
| num_chains | 4 | Parallel chains |
| target_accept_prob | 0.85 | NUTS target acceptance. For laminar_flow mode, automatically elevated to 0.9 if configured value is below 0.9. |
| max_r_hat | 1.1 | Convergence threshold |
| min_ess | 400.0 | Minimum effective sample size |
| max_divergence_rate | 0.10 | Quality filter threshold |
| require_nlsq_warmstart | False | Enforce NLSQ warm-start |
| combination_method | "robust_consensus_mc" | Robust CMC with MAD outlier filtering |
| per_shard_timeout | 3600 | 1 hour max per shard |
| heartbeat_timeout | 600 | 10 min - terminate unresponsive workers |
| min_points_per_shard | 10,000 | Config field default; actual enforced minimum in code is MIN_SHARD_SIZE_LAMINAR=3000 |
| min_points_per_param | 1,500 | Minimum points per parameter per shard |
| max_parameter_cv | 1.0 | Heterogeneity abort threshold |
| heterogeneity_abort | True | Abort on high heterogeneity |
| adaptive_sampling | True | Scale warmup/samples based on shard size |
| max_tree_depth | 10 | NUTS depth (max 2^depth leapfrog steps) |
| min_warmup | 100 | Minimum warmup even for small datasets |
| min_samples | 200 | Minimum samples even for small datasets |
| use_nlsq_informed_priors | True | Build TruncatedNormal priors from NLSQ |
| nlsq_prior_width_factor | 2.0 | Width = NLSQ_std x factor (~95.4% coverage) |
| prior_tempering | True | Scale priors by 1/K per shard (Scott et al.) |
| reparameterization_d_total | True | Sample D_total = D0 + D_offset |
| reparameterization_log_gamma | True | Sample log(gamma_dot_t0) |
| bimodal_min_weight | 0.2 | Minimum weight for GMM bimodal detection |
| bimodal_min_separation | 0.5 | Minimum relative separation for bimodal |
| enable_jax_profiling | False | Enable jax.profiler tracing |
| seed | 42 | Base seed for PRNG key generation |
Key Files Reference¶
| File | Lines | Purpose |
|---|---|---|
| core.py | ~1566 | Main orchestration, shard size selection, runtime estimation |
| data_prep.py | ~848 | Validation, sharding (stratified & random), noise estimation |
| sampler.py | ~1326 | NUTS sampling, SamplingPlan, preflight checks, adaptive scaling |
| model.py | ~1168 | 5 model variants (scaled, constant, averaged, constant_averaged, reparameterized) |
| priors.py | ~1100 | Prior distributions, NLSQ-informed priors, data-driven estimation |
| results.py | ~789 | CMCResult dataclass, convergence diagnostics, quality flags |
| config.py | ~887 | CMCConfig parsing, validation, defaults, effective mode selection |
| diagnostics.py | ~1269 | R-hat, ESS, bimodal detection, cross-shard analysis, |
| reparameterization.py | ~336 | t_ref computation, log-space transforms, ReparamConfig |
| backends/base.py | ~887 | Abstract backend, |
| backends/multiprocessing.py | ~1953 | Parallel execution, shared memory, LPT scheduling, divergence filtering |
| backends/worker_pool.py | ~354 | Worker pool lifecycle, health monitoring, graceful shutdown |
| backends/pbs.py | ~494 | PBS/Torque HPC cluster backend |
| backends/pjit.py | ~269 | Single-process sequential backend (debug) |
| io.py | ~430 | Result serialization (JSON/NPZ), |
Critical Features & Fixes¶
v2.14.2: Diagonal Point Filtering¶
Excludes t1 == t2 points (autocorrelation peaks)
Prevents biasing fit with synthetic/interpolated data
v2.12.0: Correct Consensus MC¶
Implements Scott et al. (2016) properly
Precision-weighted combination of posterior moments
December 2025: Smooth Bounding¶
Replaced hard clipping with tanh-based smooth bounds
Prevents non-smooth behavior at parameter boundaries
Enables better HMC/NUTS adaptation during warmup
December 2025: Proper Time Grid Construction¶
Fixed incorrect grid density from np.unique(t1, t2)
Constructs with explicit dt spacing matching NLSQ
Dense Mass Matrix¶
dense_mass=True is CRITICAL for NUTS
Learns parameter covariance during warmup
Handles 10⁶:1 gradient imbalance from different parameter scales
January 2026: Quality Filtering & Warm-Start (v2.19.0)¶
Divergence-Based Shard Quality Filter
After root cause analysis of CMC runs with 28% divergence rates, automatic quality filtering was added:
┌───────────────────────────────────────────────────────────────────────────┐
│ Divergence Filtering (backends/multiprocessing.py) │
│ │
│ max_divergence_rate = config.max_divergence_rate (default: 0.10) │
│ │
│ After all shards complete: │
│ 1. Calculate divergence_rate = num_divergent / total_samples per shard │
│ 2. Filter out shards where divergence_rate > max_divergence_rate │
│ 3. Log excluded shards with divergence details │
│ 4. Re-check post-filter success rate against min_success_rate │
│ │
│ Purpose: Prevent corrupted posteriors from biasing consensus combination│
│ Shards with >10% divergences have unreliable posterior estimates │
└───────────────────────────────────────────────────────────────────────────┘
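A minimal sketch of the filter described above, assuming each shard reports a (num_divergent, total_samples) pair. The function names are hypothetical, and `min_success_rate` is left as a required argument because its default is not stated here.

```python
def filter_divergent_shards(shard_stats, max_divergence_rate=0.10):
    """Split shard ids into (kept, excluded) by per-shard divergence rate.

    shard_stats maps shard id -> (num_divergent, total_samples).
    """
    kept, excluded = [], []
    for shard_id, (num_divergent, total_samples) in sorted(shard_stats.items()):
        rate = num_divergent / total_samples
        # Only shards strictly above the threshold are dropped.
        if rate > max_divergence_rate:
            excluded.append(shard_id)
        else:
            kept.append(shard_id)
    return kept, excluded


def passes_success_rate(n_kept, n_total, min_success_rate):
    """Re-check the post-filter success rate (step 4)."""
    return n_kept / n_total >= min_success_rate


# Hypothetical run: shard 1 shows the 28% divergence rate discussed above.
stats = {0: (12, 6000), 1: (1680, 6000), 2: (600, 6000)}
kept, excluded = filter_divergent_shards(stats)
print("kept:", kept, "excluded:", excluded)
```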
January 2026: CMC Divergence & Precision Loss Fix (v2.20.0)¶
Root Cause Analysis
CMC was producing parameter estimates diverging significantly from NLSQ:
D0: -37% difference (12,444 vs 19,665)
D_offset: -92% difference (71 vs 844)
CMC uncertainties artificially small (precision-weighted bias)
Root causes identified:
Excessive sharding: 999 shards with only 3000 points each (data-starved)
No NLSQ warm-start: Cold start from config values → 28% divergence rate
Consensus MC bias: Precision-weighted combination biased toward low-variance shards
Automatic NLSQ→CMC Warm-Start (CLI)
┌───────────────────────────────────────────────────────────────────────────┐
│ AUTOMATIC NLSQ Warm-Start (cli/commands.py, v2.20.0) │
│ │
│ When user runs: homodyne --method cmc --config my_config.yaml │
│ │
│ The CLI AUTOMATICALLY: │
│ 1. Runs NLSQ optimization first (unless --no-nlsq-warmstart) │
│ 2. Uses NLSQ results as initial values for CMC │
│ 3. Reduces divergence rate from ~28% to <5% │
│ │
│ To disable (NOT recommended): │
│ homodyne --method cmc --no-nlsq-warmstart --config my_config.yaml │
└───────────────────────────────────────────────────────────────────────────┘
Heterogeneity Detection & Abort
┌───────────────────────────────────────────────────────────────────────────┐
│ Heterogeneity Detection (backends/multiprocessing.py, v2.20.0) │
│ │
│ After collecting all shard results: │
│ │
│ 1. Compute coefficient of variation (CV) for each parameter: │
│ CV = std(shard_means) / |mean(shard_means)| │
│ │
│ 2. Check against threshold (max_parameter_cv, default 1.0): │
│ • CV > threshold for critical params → high heterogeneity detected │
│ • Critical params: D0, D_offset (static); + gamma_dot_t0 (laminar) │
│ │
│ 3. If heterogeneity_abort=True (default): │
│ → Raises RuntimeError with actionable guidance: │
│ "High heterogeneity detected (D0 CV=1.80). Consider: │
│ 1. Run NLSQ first for warm-start │
│ 2. Increase min_points_per_shard │
│ 3. Reduce n_shards" │
│ │
│ 4. If heterogeneity_abort=False: │
│ → Falls back to robust_consensus_mc combination │
└───────────────────────────────────────────────────────────────────────────┘
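The CV check can be illustrated with the standard library. This is a sketch, not the backend code: the helper names and shard means are hypothetical, and the population standard deviation is an assumption (the box above does not say which estimator is used).

```python
import statistics


def parameter_cv(shard_means):
    """CV = std(shard_means) / |mean(shard_means)|, using the population std."""
    mean = statistics.fmean(shard_means)
    if mean == 0.0:
        return float("inf")  # CV is undefined at zero mean; treat as maximally heterogeneous
    return statistics.pstdev(shard_means) / abs(mean)


def heterogeneous_params(means_by_param, critical_params, max_parameter_cv=1.0):
    """Return {param: CV} for critical parameters whose CV exceeds the threshold."""
    flagged = {}
    for name in critical_params:
        cv = parameter_cv(means_by_param[name])
        if cv > max_parameter_cv:
            flagged[name] = cv
    return flagged


# Hypothetical per-shard posterior means.
means_by_param = {
    "D0": [19000.0, 19500.0, 20000.0, 19800.0, 19200.0],
    "D_offset": [-50.0, 60.0, 10.0, -30.0, 40.0],  # sign changes across shards
}
flagged = heterogeneous_params(means_by_param, ["D0", "D_offset"])
for name, cv in flagged.items():
    print(f"High heterogeneity: {name} CV={cv:.2f}")
```

A parameter whose shard means straddle zero, as D_offset does here, produces a very large CV because the denominator is small.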
Robust Consensus MC Combination
┌───────────────────────────────────────────────────────────────────────────┐
│ Robust Combination (backends/base.py, v2.20.0) │
│ │
│ When standard consensus_mc would produce biased results: │
│ │
│ 1. Compute per-shard means for each parameter │
│ 2. Detect outlier shards using MAD (Median Absolute Deviation): │
│ • median_mean = median(shard_means) │
│ • mad = median(|shard_means - median_mean|) │
│ • threshold = 3 × 1.4826 × mad │
│ • outlier if |mean - median_mean| > threshold │
│ 3. Exclude outlier shards from combination │
│ 4. Apply standard consensus_mc to retained shards │
│ │
│ Auto-enabled when: combination_method="robust_consensus_mc" │
│ Or when heterogeneity detected but heterogeneity_abort=False │
└───────────────────────────────────────────────────────────────────────────┘
Dynamic Shard Sizing for Large Datasets
┌───────────────────────────────────────────────────────────────────────────┐
│ Dynamic Shard Sizing (core.py, v2.20.0) │
│ │
│ Problem: Fixed max_shards=500/1000 doesn't scale to 100M+ datasets │
│ │
│ Solution: Dynamic max_shards based on dataset size │
│ │
│ Dataset Size │ max_shards │ Rationale │
│ ────────────────┼────────────┼───────────────────────────────────────── │
│ < 10M │ 2,000 │ Standard parallelism │
│ 10M - 100M │ 10,000 │ Balanced shard count │
│ 100M - 1B │ 50,000 │ High parallelism │
│ 1B+ │ 100,000 │ Extreme scale │
│ │
│ Additionally enforces MINIMUM shard size: │
│ • laminar_flow: 3,000 points (reparameterization fixes bimodal posteriors) │
│ • static: 5,000 points │
│ │
│ Angle-aware scaling factor (less aggressive, 0.3→0.6 for n_phi≤3) │
└───────────────────────────────────────────────────────────────────────────┘
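The lookup in the table above is easy to express as code. The sketch below is illustrative only: the function names are hypothetical, the handling of exact boundaries (for example exactly 10M points) is an assumption, and the way the cap and the minimum shard size interact in core.py may differ.

```python
import math


def dynamic_max_shards(n_points):
    """Map dataset size to the max_shards tier from the table above."""
    if n_points < 10_000_000:
        return 2_000
    if n_points < 100_000_000:
        return 10_000
    if n_points < 1_000_000_000:
        return 50_000
    return 100_000


def min_shard_size(analysis_mode):
    """Minimum points per shard: 3,000 for laminar_flow, 5,000 for static."""
    return 3_000 if analysis_mode == "laminar_flow" else 5_000


def capped_shard_count(n_points, points_per_shard, analysis_mode):
    """Assumed combination: enforce the minimum size, then cap the shard count."""
    size = max(points_per_shard, min_shard_size(analysis_mode))
    return min(math.ceil(n_points / size), dynamic_max_shards(n_points))


print(capped_shard_count(3_000_000, 1_000, "laminar_flow"))
```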
Enhanced CMC vs NLSQ Diagnostics
┌───────────────────────────────────────────────────────────────────────────┐
│ Enhanced Precision Analysis (diagnostics.py, v2.20.0) │
│ │
│ When comparing CMC to NLSQ results: │
│ │
│ 1. Logs detailed comparison table: │
│ Parameter │ NLSQ Value │ CMC Value │ Diff% │ Z-Score │ Status │
│ ──────────┼────────────┼───────────┼───────┼─────────┼───────────── │
│ D0 │ 19665 ±68 │ 12444 ±14 │ -37% │ 106.2 │ ⚠ EXCEEDS │
│ D_offset │ 844 ±0.9 │ 71 ±0.4 │ -92% │ 858.9 │ ⚠ EXCEEDS │
│ │
│ 2. Flags parameters exceeding tolerance (default 3σ) │
│ 3. Provides actionable guidance if discrepancies detected │
└───────────────────────────────────────────────────────────────────────────┘
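The table values can be reproduced with a few lines of arithmetic. In this sketch the function name and the "OK" label are hypothetical, and normalizing the z-score by the NLSQ uncertainty alone is an inference from the numbers in the table (7221 / 68 ≈ 106.2), not a statement about how diagnostics.py computes it.

```python
def compare_to_nlsq(nlsq_value, nlsq_std, cmc_value, tolerance_sigma=3.0):
    """Return diff_pct, z_score and a status flag for one parameter."""
    diff = cmc_value - nlsq_value
    diff_pct = 100.0 * diff / nlsq_value
    z_score = abs(diff) / nlsq_std  # assumption: NLSQ std only, inferred from the table
    status = "EXCEEDS" if z_score > tolerance_sigma else "OK"
    return {"diff_pct": diff_pct, "z_score": z_score, "status": status}


d0 = compare_to_nlsq(nlsq_value=19665.0, nlsq_std=68.0, cmc_value=12444.0)
d_offset = compare_to_nlsq(nlsq_value=844.0, nlsq_std=0.9, cmc_value=71.0)

for name, row in (("D0", d0), ("D_offset", d_offset)):
    print(f"{name:9s} {row['diff_pct']:6.0f}% z={row['z_score']:.1f} {row['status']}")
```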
CMCConfig Fields Added in v2.20.0¶
| Field | Type | Default | Description |
|---|---|---|---|
| max_divergence_rate | float | 0.10 | Filter shards exceeding this divergence rate |
| require_nlsq_warmstart | bool | False | Require NLSQ warm-start (API-level) |
| min_points_per_shard | int | 10,000 | Config field default; actual enforced minimum in code is MIN_SHARD_SIZE_LAMINAR=3000 |
| max_parameter_cv | float | 1.0 | Heterogeneity abort threshold |
| heterogeneity_abort | bool | True | Abort on high heterogeneity |
| min_points_per_param | int | 1,500 | Param-aware shard sizing floor |
February 2026: Mode-Aware Consensus MC (v2.22.0)¶
Standard consensus MC assumes per-shard posteriors are approximately Gaussian. When
shards have bimodal posteriors (e.g., D₀ ~19K and ~32K with 50/50 weight splits), the
naive mean falls in the density trough between modes and the variance is inflated by
w1·w2·(μ1−μ2)². Mode-aware consensus decomposes bimodal shards into per-component GMM
statistics and runs precision-weighted consensus separately per mode cluster, producing
a mixture-drawn output. See §9 for algorithm details.
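The trough and the inflated variance follow directly from the moments of a two-component mixture. The sketch below uses the D₀ modes quoted above with a 50/50 split; the per-mode standard deviation of 500 is a hypothetical value chosen only for illustration.

```python
def mixture_moments(w1, mu1, sigma1, mu2, sigma2):
    """Mean and variance terms of a two-component Gaussian mixture.

    Total variance = within-mode variance + w1 * w2 * (mu1 - mu2)^2.
    """
    w2 = 1.0 - w1
    mean = w1 * mu1 + w2 * mu2
    within = w1 * sigma1**2 + w2 * sigma2**2
    between = w1 * w2 * (mu1 - mu2) ** 2
    return mean, within, between


mean, within, between = mixture_moments(0.5, 19_000.0, 500.0, 32_000.0, 500.0)
print(f"naive mean: {mean:.0f} (between the modes)")
print(f"within-mode variance: {within:.0f}")
print(f"between-mode inflation: {between:.0f}")
```

With these numbers the naive mean sits midway between the two modes, and the between-mode term is more than two orders of magnitude larger than the within-mode variance, which is why a single Gaussian summary of such a shard is misleading.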
January 2026: Heterogeneity Prevention (v2.21.0)¶
Parameter Degeneracy in Laminar Flow Mode
The laminar_flow model has two known parameter degeneracies that can cause high
heterogeneity across CMC shards:
1. D₀/D_offset Linear Degeneracy
The diffusion contribution depends on D₀ + D_offset, creating a linear manifold in
parameter space where different (D₀, D_offset) pairs produce equivalent fits.
| Symptom | Cause |
|---|---|
| D_offset CV > 1.0 | Shards find different points along the D₀ + D_offset = const ridge |
| D_offset spans positive and negative | Ridge crosses zero for D_offset |
| High D₀ range despite good NLSQ fit | Compensating D_offset values |
Mitigation (automatic in v2.23.0+): CMC internally reparameterizes to
log_D_ref = log(D₀ × t_ref^α) and D_offset_ratio = D_offset / D_ref, which are
orthogonal and well-constrained. D_offset_ratio supports negative values (jammed/arrested
systems) via a TruncatedNormal(low=-1+ε) prior. Results are automatically converted
back to D₀/D_offset for output.
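The forward and inverse transforms implied by these definitions can be sketched as follows. The function names and the t_ref value are hypothetical, and alpha is assumed to pass through unchanged; the actual transforms live in reparameterization.py and may differ in detail.

```python
import math


def to_reparam(d0, alpha, d_offset, t_ref):
    """(D0, D_offset) -> (log_D_ref, D_offset_ratio) with D_ref = D0 * t_ref**alpha."""
    d_ref = d0 * t_ref**alpha
    return math.log(d_ref), d_offset / d_ref


def from_reparam(log_d_ref, d_offset_ratio, alpha, t_ref):
    """Inverse transform back to (D0, D_offset) for output."""
    d_ref = math.exp(log_d_ref)
    return d_ref / t_ref**alpha, d_offset_ratio * d_ref


# Round trip with illustrative values; a negative D_offset gives a negative ratio.
log_d_ref, ratio = to_reparam(d0=19665.0, alpha=-0.5, d_offset=844.0, t_ref=2.0)
d0_back, d_offset_back = from_reparam(log_d_ref, ratio, alpha=-0.5, t_ref=2.0)
print(f"log_D_ref={log_d_ref:.4f}, D_offset_ratio={ratio:.4f}")
print(f"recovered D0={d0_back:.1f}, D_offset={d_offset_back:.1f}")
```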
2. γ̇₀/β Multiplicative Correlation
The shear contribution scales as γ̇₀ · t^(1+β). Higher γ̇₀ with more negative β can
produce similar effects to lower γ̇₀ with less negative β.
| Symptom | Cause |
|---|---|
| gamma_dot_t0 CV > 1.0 | Shards explore the γ̇₀-β correlation ridge |
| gamma_dot_t0 spans 10-100× range | Compensating β values |
| beta moderate heterogeneity (CV ~0.5-0.8) | Correlated with γ̇₀ |
Mitigation (automatic in v2.21.0+): CMC samples log(γ̇₀) instead of γ̇₀ directly,
which improves conditioning and reduces posterior ridge exploration.
Bimodal Detection & Mode-Aware Consensus (v2.22.0)
┌───────────────────────────────────────────────────────────────────────────┐
│ Per-Shard Bimodal Detection (diagnostics.py) │
│ │
│ After MCMC sampling, each shard is checked for bimodal posteriors: │
│ │
│ 1. Fit 2-component GMM to each parameter's samples │
│ 2. Flag as bimodal if: │
│ • min(weights) > 0.2 (both modes significant) │
│ • relative_separation > 0.5 (modes well-separated) │
│ 3. Store per-detection record: │
│ {shard, param, mode1, mode2, std1, std2, weights, separation} │
└───────────────────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────────────────┐
│ Cross-Shard Bimodality Analysis (diagnostics.py) │
│ │
│ summarize_cross_shard_bimodality(): │
│ • Groups detections by parameter name │
│ • For params with bimodal fraction > 5%: │
│ - Computes mode means/stds across shards │
│ - Checks if consensus mean falls in density trough │
│ • Detects D₀-alpha co-occurrence (parameter degeneracy) │
│ │
│ cluster_shard_modes(): │
│ • Builds range-normalized feature vectors from per-shard modes │
│ • Seeds KMeans from cross-shard lower/upper means │
│ • Assigns each bimodal shard to lower or upper mode cluster │
│ • Returns (lower_indices, upper_indices) for combine step │
│ │
│ Structured logging via _log_bimodality_summary(): │
│ • Per-parameter table: bimodal%, mode means±std, separation │
│ • Consensus impact: density trough warnings, co-occurrence stats │
│ • Actionable guidance (shard size, prior tightening) │
└───────────────────────────────────────────────────────────────────────────┘
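The 5% trigger can be illustrated on the per-detection records described above. This sketch assumes records are dicts with at least "shard" and "param" keys; the function names and the example data are hypothetical and do not reproduce summarize_cross_shard_bimodality().

```python
def bimodal_fractions(detections, n_shards):
    """Fraction of shards flagged bimodal, per parameter."""
    shards_by_param = {}
    for record in detections:
        shards_by_param.setdefault(record["param"], set()).add(record["shard"])
    return {param: len(shards) / n_shards for param, shards in shards_by_param.items()}


def modal_params(detections, n_shards, min_fraction=0.05):
    """Parameters whose cross-shard bimodal fraction exceeds the trigger."""
    fractions = bimodal_fractions(detections, n_shards)
    return sorted(param for param, frac in fractions.items() if frac > min_fraction)


# Hypothetical run with 40 shards: D0 is bimodal in 6 shards, alpha in 1.
detections = [{"shard": i, "param": "D0"} for i in range(6)]
detections.append({"shard": 7, "param": "alpha"})

print(bimodal_fractions(detections, n_shards=40))
print(modal_params(detections, n_shards=40))
```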
┌───────────────────────────────────────────────────────────────────────────┐
│ Mode-Aware Combination (backends/base.py) │
│ │
│ combine_shard_samples_bimodal(): │
│ • For each mode cluster, runs precision-weighted consensus: │
│ - Bimodal shards: uses per-component GMM stats (mu, sigma²) │
│ - Unimodal shards: uses full posterior stats │
│ • Generates mixture-drawn output samples from both modes │
│ • Returns (MCMCSamples, BimodalConsensusResult) │
│ │
│ Key dataclasses: │
│ BimodalResult: per-shard GMM fit (means, stds, weights, separation) │
│ ModeCluster: per-mode consensus (means, stds, n_shards, samples) │
│ BimodalConsensusResult: both modes + modal_params + co_occurrence │
│ │
│ Auto-triggered in multiprocessing backend when bimodal fraction > 5% │
│ Falls back to standard consensus_mc if <3 shards in a cluster │
└───────────────────────────────────────────────────────────────────────────┘
Param-Aware Shard Sizing
┌───────────────────────────────────────────────────────────────────────────┐
│ Param-Aware Shard Sizing (config.py, v2.21.0) │
│ │
│ Problem: High-dimensional models need more points per shard │
│ Solution: Scale shard size with parameter count │
│ │
│ adjusted_max = max(base_max × param_factor, min_points_per_param × n_params) │
│ where param_factor = max(1.0, n_params / 7.0) │
│ │
│ Example (laminar_flow + individual scaling, 23 angles): │
│ • n_params = 7 + 46 + 1 = 54 │
│ • param_factor = 54/7 = 7.71 │
│ • min_required = 1500 × 54 = 81,000 points │
│ • For 500K points → ~6 shards (vs 50+ with default sizing) │
│ │
│ Prevents data starvation in high-dimensional per-angle modes │
└───────────────────────────────────────────────────────────────────────────┘
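The worked example in the box can be checked numerically. In this sketch the function name is hypothetical and `base_max=10_000` is an assumed value, since the box does not state the base shard size used in its example.

```python
def adjusted_max_points(base_max, n_params, min_points_per_param=1500):
    """Scale the shard size with parameter count, with a per-parameter floor."""
    param_factor = max(1.0, n_params / 7.0)
    return max(base_max * param_factor, min_points_per_param * n_params)


n_params = 7 + 46 + 1  # physical + per-angle contrast/offset + sigma
adjusted = adjusted_max_points(base_max=10_000, n_params=n_params)
approx_shards = 500_000 / adjusted

print(f"param_factor = {n_params / 7.0:.2f}")
print(f"adjusted max points per shard = {adjusted:,.0f}")
print(f"shards for 500K points = {approx_shards:.1f}")
```

With the assumed base size, the per-parameter floor (1500 × 54) is the binding term, which matches the 81,000-point minimum quoted in the box.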
Diagnostic Indicators
When heterogeneity abort triggers, check these indicators:
| Indicator | Healthy | Problematic |
|---|---|---|
| D_offset CV | < 0.5 | > 1.0 |
| D_offset range | Within ±20% of D₀ | Spans ±D₀ or sign changes |
| gamma_dot_t0 CV | < 0.5 | > 1.0 |
| Bimodal warnings | 0 | Multiple shards |
Configuration Options
If heterogeneity persists after v2.21.0+ mitigations:
optimization:
  cmc:
    reparameterization:
      d_total: true  # Default: true for laminar_flow
      log_gamma_dot: true  # Default: true for laminar_flow
    sharding:
      max_points_per_shard: 50000  # Increase for more statistical power
    validation:
      max_parameter_cv: 1.5  # Relax threshold if physical heterogeneity expected
February 2026: NLSQ-Informed Priors & Prior Tempering (v2.22.2)¶
NLSQ-Informed Priors
When nlsq_result is provided to fit_mcmc_jax(), the prior builder constructs
TruncatedNormal priors centered on NLSQ estimates with width = NLSQ_std x
nlsq_prior_width_factor (default 2.0, ~95.4% coverage). This dramatically reduces
warmup time and divergence rates.
┌───────────────────────────────────────────────────────────────────────────┐
│ NLSQ-Informed Prior Construction (priors.py, v2.22.2) │
│ │
│ For each physical parameter: │
│ center = nlsq_values[param] │
│ width = nlsq_uncertainties[param] × width_factor │
│ prior = TruncatedNormal(center, width, low=lb, high=ub) │
│ │
│ Config: use_nlsq_informed_priors=True (default) │
│ nlsq_prior_width_factor=2.0 (default) │
└───────────────────────────────────────────────────────────────────────────┘
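The construction in the box amounts to a per-parameter lookup. The sketch below returns plain dictionaries rather than distribution objects so that it runs without extra dependencies; the function name, the bounds, and the example values are hypothetical. In a NumPyro model each spec could presumably be passed to numpyro.distributions.TruncatedNormal(loc, scale, low=..., high=...).

```python
def nlsq_informed_prior_specs(nlsq_values, nlsq_uncertainties, bounds, width_factor=2.0):
    """Build TruncatedNormal parameters centered on the NLSQ estimates."""
    specs = {}
    for name, center in nlsq_values.items():
        low, high = bounds[name]
        specs[name] = {
            "loc": center,
            "scale": nlsq_uncertainties[name] * width_factor,
            "low": low,
            "high": high,
        }
    return specs


specs = nlsq_informed_prior_specs(
    nlsq_values={"D0": 19665.0, "D_offset": 844.0},
    nlsq_uncertainties={"D0": 68.0, "D_offset": 0.9},
    bounds={"D0": (100.0, 1.0e6), "D_offset": (-1.0e4, 1.0e4)},  # hypothetical bounds
)
for name, spec in specs.items():
    print(name, spec)
```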
Prior Tempering (Scott et al. 2016)
Without tempering, each of the K shards applies the full prior, so the combined posterior is proportional to prior^K x likelihood. With tempering, each shard uses prior^(1/K), which recovers the correct combined posterior, prior x likelihood. For a Normal(mu, sigma) prior, prior^(1/K) ∝ Normal(mu, sigma*sqrt(K)), so tempering amounts to widening the prior standard deviation by sqrt(num_shards).
Enabled by default via prior_tempering=True.
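The Normal case can be verified numerically: the tempered log-density and the log-density of the widened Normal should differ only by a constant that does not depend on x. The helper names and the example values below are hypothetical.

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def tempered_normal_std(sigma, num_shards):
    """Std of the Normal proportional to Normal(mu, sigma) ** (1 / num_shards)."""
    return sigma * math.sqrt(num_shards)


mu, sigma, num_shards = 2.0, 0.5, 16
wide_sigma = tempered_normal_std(sigma, num_shards)

# (1/K) * log prior(x) minus log Normal(x; mu, sigma*sqrt(K)) should be constant in x.
offsets = [
    normal_logpdf(x, mu, sigma) / num_shards - normal_logpdf(x, mu, wide_sigma)
    for x in (-3.0, 0.0, 2.0, 7.5)
]
print(f"widened std: {wide_sigma}")
print(f"spread of offsets: {max(offsets) - min(offsets):.2e}")
```

Because the offset is constant, the two densities agree up to normalization, which is all that matters for sampling.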
February 2026: Adaptive Sampling & SamplingPlan (v2.22.2)¶
Small shards receive fewer warmup iterations and samples to reduce NUTS overhead:
┌───────────────────────────────────────────────────────────────────────────┐
│ SamplingPlan (sampler.py, v2.22.2) │
│ │
│ SamplingPlan.from_config(config, shard_size, n_params): │
│ • Scales warmup/samples based on shard_size │
│ • Enforces min_warmup (100) and min_samples (200) │
│ • Records was_adapted flag │
│ │
│ IMPORTANT: Use SamplingPlan instead of config.num_warmup directly │
│ in sampling hot paths. config.num_warmup/num_samples are pre- │
│ adaptation defaults for logging and timeout estimation only. │
│ │
│ Shard Size │ Warmup │ Samples │ Reduction │
│ ───────────┼────────┼─────────┼────────── │
│ 50 pts │ 140 │ 350 │ 75% │
│ 5K pts │ 250 │ 750 │ 50% │
│ 50K+ pts │ 500 │ 1,500 │ None │
└───────────────────────────────────────────────────────────────────────────┘
February 2026: JAX Profiling Support (v2.22.2)¶
py-spy only profiles Python code; XLA runs native code invisible to it. Enable JAX profiling for XLA-level insights:
optimization:
  cmc:
    per_shard_mcmc:
      enable_jax_profiling: true
      jax_profile_dir: ./profiles/jax
View with TensorBoard: tensorboard --logdir=./profiles/jax
February 2026: Constant-Averaged Mode & NLSQ Parity (v2.22.2)¶
When both CMC and NLSQ use “auto” mode and NLSQ warm-start is present,
get_effective_per_angle_mode() automatically upgrades to “constant_averaged”. This
fixes scaling values and reduces the parameter count, preventing contrast/offset
sampling from absorbing physical parameter signal.