ADR-004: Consensus Monte Carlo for Bayesian Inference
- Status: Accepted
- Date: 2024–2025
- Deciders: Core team
Context
XPCS datasets from modern synchrotron sources can exceed \(10^8\) data points (the full \(c_2(t_1, t_2)\) matrix at high temporal resolution). Standard MCMC methods, including NUTS, scale as \(O(n)\) per likelihood evaluation (for NUTS, per leapfrog step), where \(n\) is the number of data points. Each step evaluates the model and its gradient over the whole dataset. For \(10^8\) points, a single NUTS iteration with 100 leapfrog steps requires \(10^8 \times 100 = 10^{10}\) pointwise model evaluations. A chain needs hundreds to thousands of such iterations, which is computationally intractable even with JIT compilation.
The question is: how to make full Bayesian inference tractable on large XPCS datasets while preserving statistical validity?
Decision
Homodyne implements Consensus Monte Carlo (CMC, Scott et al. 2016) as its Bayesian backend. The algorithm is:
1. Shard: Partition the \(n\) data points into \(M\) shards of size \(n_s \approx n/M\).
2. Parallel NUTS: Run independent NUTS chains on each shard (in separate worker processes), obtaining \(K\) posterior samples \(\{\theta^{(k)}_s\}\) from the shard-specific posterior \(p_s(\theta | \mathcal{D}_s)\).
3. Consensus: Combine the \(M\) sets of shard samples into a single approximation of the global posterior \(p(\theta | \mathcal{D})\).
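The three steps can be sketched end to end on a toy problem. This is a minimal sketch, not homodyne's implementation. It makes three assumptions:
- A scalar Gaussian model with known noise, so each shard posterior can be sampled exactly. This stands in for per-shard NUTS.
- Every shard gets the fractionated prior \(p(\theta)^{1/M}\).
- The consensus rule is Scott et al.'s precision-weighted average.

All names are hypothetical.

```python
import random
import statistics

# Toy model (assumption for illustration): y_i ~ Normal(theta, SIGMA^2), SIGMA known,
# prior theta ~ Normal(0, TAU^2). Conjugacy gives exact shard posteriors, so shard
# draws are sampled directly instead of running NUTS.
rng = random.Random(0)
SIGMA, TAU, THETA_TRUE = 1.0, 10.0, 2.5
data = [rng.gauss(THETA_TRUE, SIGMA) for _ in range(20_000)]

def shard(points, m, rng):
    """Step 1: randomly partition the points into m shards of (nearly) equal size."""
    shuffled = points[:]
    rng.shuffle(shuffled)
    return [shuffled[s::m] for s in range(m)]

def shard_posterior_samples(ys, m, k, rng):
    """Step 2 (stand-in for per-shard NUTS): k exact draws from
    p(D_s | theta) * p(theta)^(1/m); the fractionated prior has precision 1/(m TAU^2)."""
    prec = len(ys) / SIGMA**2 + 1.0 / (m * TAU**2)
    mean = (sum(ys) / SIGMA**2) / prec  # prior mean is 0
    sd = prec ** -0.5
    return [rng.gauss(mean, sd) for _ in range(k)]

def consensus(shard_draws):
    """Step 3: weighted average of the k-th draw of every shard,
    with weights = inverse sample variance of each shard's draws."""
    weights = [1.0 / statistics.variance(d) for d in shard_draws]
    total = sum(weights)
    k = len(shard_draws[0])
    return [sum(w * d[i] for w, d in zip(weights, shard_draws)) / total for i in range(k)]

M, K = 8, 4000
draws = [shard_posterior_samples(s, M, K, rng) for s in shard(data, M, rng)]
combined = consensus(draws)

# Exact global posterior for comparison.
prec_full = len(data) / SIGMA**2 + 1.0 / TAU**2
mean_full = (sum(data) / SIGMA**2) / prec_full
print("mean:", statistics.fmean(combined), "exact:", mean_full)
print("sd:  ", statistics.stdev(combined), "exact:", prec_full ** -0.5)
```

The toy shard posteriors are Gaussian, so the combined draws match the exact global mean and standard deviation up to Monte Carlo error. For non-Gaussian shard posteriors the weighted average is only an approximation.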
The multiprocessing backend spawns \(N_\mathrm{workers} = \lfloor N_\mathrm{cores}/2\rfloor - 1\)
worker processes, each with 4 virtual JAX devices (via
--xla_force_host_platform_device_count=4). Each worker runs NUTS in parallel mode
(pmap over 4 devices), achieving near-full CPU utilization.
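The worker-count formula and the per-worker XLA flag can be sketched with the standard library alone. Several details are assumptions for illustration: the helper names, the clamp to at least one worker, and the replacement of any existing device-count flag. The flag has to be in a worker's environment before JAX initializes in that process; NumPyro's `numpyro.set_host_device_count` sets the same flag for the current process.

```python
import os

DEVICE_FLAG = "--xla_force_host_platform_device_count"

def n_workers(n_cores: int) -> int:
    """N_workers = floor(N_cores / 2) - 1; the max(1, ...) clamp is an assumption."""
    return max(1, n_cores // 2 - 1)

def worker_xla_flags(existing: str = "", devices: int = 4) -> str:
    """Return an XLA_FLAGS value giving a worker `devices` virtual CPU devices.

    Any device-count flag already present is replaced rather than duplicated.
    """
    kept = [tok for tok in existing.split() if not tok.startswith(DEVICE_FLAG)]
    return " ".join(kept + [f"{DEVICE_FLAG}={devices}"])

cores = os.cpu_count() or 1
print(f"{cores} cores -> {n_workers(cores)} workers x 4 virtual devices")
print("XLA_FLAGS =", worker_xla_flags(os.environ.get("XLA_FLAGS", "")))
```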
Chain Execution Method
Each worker process runs 4 NUTS chains using NumPyro’s parallel execution method
(pmap over 4 virtual JAX devices):
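from numpyro.infer import MCMC, NUTS
kernel = NUTS(model, max_tree_depth=10)
mcmc = MCMC(kernel, num_warmup=plan.n_warmup, num_samples=plan.n_samples,
            num_chains=4, chain_method="parallel")  # one chain per virtual device (pmap)
mcmc.run(rng_key, ..., extra_fields=("energy",))  # energy enables the BFMI diagnostic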
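The parallel method is empirically about 20x faster than the vectorized (vmap) method for the
multiprocessing backend (4.9 s vs 101 s wall time for identical workloads). The likely reason:
pmap places each chain on its own virtual CPU device, so the 4 chains run as independent, concurrent XLA computations.
vmap instead batches all chains into one computation on a single device. Because NUTS builds trajectories of data-dependent length, every chain in a vmap batch also waits for the longest trajectory at each iteration.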
Worker Environment
Before spawning workers, the backend:
1. Saves the current environment (OMP_PROC_BIND, OMP_PLACES).
2. Clears OMP_PROC_BIND and OMP_PLACES to prevent OpenMP thread-binding conflicts between workers.
3. Sets OMP_NUM_THREADS=1 or 2 per worker to prevent thread oversubscription (each worker manages its own JAX device count via XLA_FLAGS).
4. Restores the parent environment after all workers are spawned.
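The four steps map naturally onto a context manager. This is a minimal standard-library sketch, assuming workers are started inside the `with` block, since child processes inherit `os.environ` at start time. The helper name is hypothetical. So is saving OMP_NUM_THREADS alongside the two binding variables, which lets it be restored too.

```python
import os
from contextlib import contextmanager

@contextmanager
def worker_spawn_env(omp_threads: int = 1):
    """Prepare OpenMP variables for spawning workers, then restore the parent's values."""
    keys = ("OMP_PROC_BIND", "OMP_PLACES", "OMP_NUM_THREADS")
    saved = {k: os.environ.get(k) for k in keys}          # 1. save
    try:
        os.environ.pop("OMP_PROC_BIND", None)              # 2. clear thread binding
        os.environ.pop("OMP_PLACES", None)
        os.environ["OMP_NUM_THREADS"] = str(omp_threads)   # 3. limit threads per worker
        yield
    finally:
        for k, v in saved.items():                         # 4. restore parent env
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

# Example: a parent process that pins threads; workers would be started inside the block.
os.environ["OMP_PLACES"] = "cores"
with worker_spawn_env(omp_threads=1):
    inside = {k: os.environ.get(k) for k in ("OMP_PROC_BIND", "OMP_PLACES", "OMP_NUM_THREADS")}
    # e.g. start multiprocessing.get_context("spawn").Process(...) workers here
after = os.environ.get("OMP_PLACES")
```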
Rationale
1. CMC is asymptotically exact
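CMC recovers the global posterior when:
Measurements are independent (true for XPCS: each \((t_1, t_2, \phi)\) triplet is an independent measurement conditioned on the parameters).
Data points are assigned to shards at random, so each shard is a representative sample of the full dataset.
Every shard model uses the same prior raised to the power \(1/M\) (Scott et al.'s fractionated prior), so that the \(M\) shard priors multiply back to the global prior.
Under these conditions, \(p(\theta | \mathcal{D}) \propto \prod_{s=1}^{M} p_s(\theta | \mathcal{D}_s)\) (Scott et al. 2016). The consensus step only approximates this product from samples. Scott et al.'s weighted average of shard draws is exact when every shard posterior is Gaussian. It becomes exact asymptotically as shard size grows, because each shard posterior then approaches a Gaussian (Bernstein–von Mises).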
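In symbols, following Scott et al. 2016, the two lines below are:
- the first line: the decomposition, with the fractionated prior \(p(\theta)^{1/M}\) in every shard;
- the second line: the weighted-average combination of the \(k\)-th draw from each shard. Here \(\widehat{\Sigma}_s\) is the sample covariance of shard \(s\)'s draws.

When every \(p_s\) is Gaussian, the combined draws are distributed exactly as the product in the first line.

```latex
p(\theta \mid \mathcal{D}) \;\propto\; \prod_{s=1}^{M} \Big[ p(\mathcal{D}_s \mid \theta)\, p(\theta)^{1/M} \Big] \;\propto\; \prod_{s=1}^{M} p_s(\theta \mid \mathcal{D}_s)

\theta^{(k)} = \Big( \sum_{s=1}^{M} W_s \Big)^{-1} \sum_{s=1}^{M} W_s\, \theta^{(k)}_s,
\qquad W_s = \widehat{\Sigma}_s^{-1}, \qquad k = 1, \dots, K
```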
2. CMC enables linear scalability
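With \(M\) shards of size \(n_s\), \(P\) parallel workers, \(L\) leapfrog steps per NUTS iteration, and \(T\) warmup plus sampling iterations per chain:
Per-shard NUTS cost: \(O(n_s \cdot L)\) per iteration per chain, instead of \(O(n \cdot L)\) on the full dataset.
Consensus step: \(O(M \cdot K)\), negligible.
Wall time: the \(M = n/n_s\) shards run in \(\lceil M/P \rceil\) rounds of up to \(P\) concurrent shards, so wall time scales as \(O(\lceil M/P \rceil \cdot n_s \cdot L \cdot T) \approx O(n \cdot L \cdot T / P)\), linear in \(1/P\).
In practice, wall time is dominated by NUTS warmup, whose cost grows as \(O(n_s)\) per shard.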
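The round structure behind the wall-time estimate can be checked with a few lines of arithmetic. This is an idealized cost model with these assumptions:
- shards are of equal size;
- workers are equally fast;
- each worker handles one shard at a time;
- cost is counted in per-point model evaluations.

The function name and the example values (\(n_s = 10^6\), 1000 iterations) are hypothetical.

```python
import math

def cmc_wall_time(n: int, n_s: int, n_workers: int,
                  leapfrog_per_iter: int, iters: int) -> tuple[int, int]:
    """Return (rounds, cost) for the idealized model: cost counts per-point
    model evaluations on the critical path (one worker's sequence of shards)."""
    n_shards = math.ceil(n / n_s)                 # M
    rounds = math.ceil(n_shards / n_workers)      # shards are processed P at a time
    cost = rounds * n_s * leapfrog_per_iter * iters
    return rounds, cost

# Single-chain NUTS on the full dataset for comparison (no sharding, no parallelism).
n = 10**8
full_cost = n * 100 * 1000
for p in (7, 14, 28):
    rounds, cost = cmc_wall_time(n, n_s=10**6, n_workers=p,
                                 leapfrog_per_iter=100, iters=1000)
    print(f"P={p:2d}: {rounds:2d} rounds, speedup vs full data = {full_cost / cost:.1f}")
```

Doubling \(P\) roughly halves the cost until the rounding in \(\lceil M/P \rceil\) matters. For example, 100 shards take 15 rounds on 7 workers but 8 rounds, not 7.5, on 14 workers.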
3. NLSQ warm-start dramatically improves CMC quality
Without NLSQ initialization, shard-level NUTS chains have high divergence rates (~28%) because the default broad priors place chains far from the posterior mode. The NLSQ covariance matrix provides tight, data-informed priors that:
Reduce divergences to <5%.
Allow NUTS warmup to complete in fewer steps.
Prevent chains from exploring unphysical parameter regions.
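Turning an NLSQ fit into shard priors can be sketched with the standard library alone. The sketch assumes independent Normal priors centred on the NLSQ estimate, with widths taken from the covariance diagonal times an inflation factor. The helper name, the diagonal-only choice, and the factor of 3 are illustrative assumptions, not homodyne settings.

```python
import math

def nlsq_informed_prior(theta_hat, cov, inflation=3.0):
    """Hypothetical helper: one (loc, scale) Normal prior per parameter.

    loc   = NLSQ point estimate
    scale = inflation * sqrt(diag(cov)); off-diagonal covariance is ignored here.
    """
    priors = []
    for i, loc in enumerate(theta_hat):
        var = cov[i][i]
        # Reject unusable variances (non-positive, NaN, or infinite) from a poor fit.
        if not (var > 0 and math.isfinite(var)):
            raise ValueError(f"parameter {i}: NLSQ variance {var!r} is unusable")
        priors.append((loc, inflation * math.sqrt(var)))
    return priors

# Illustrative two-parameter fit (made-up numbers).
print(nlsq_informed_prior([1.0, 2.0], [[0.04, 0.01], [0.01, 0.25]]))
```

In NumPyro, each pair could become a `dist.Normal(loc, scale)` prior, and chains could be started at the NLSQ estimate with `numpyro.infer.init_to_value`. How homodyne wires these together is not specified here.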
4. Quality filtering prevents posterior corruption
Shards with divergence rate > 10% are excluded from the consensus. This is a conservative threshold that discards shards where NUTS clearly failed (bad geometry, wrong step size) while retaining the majority of shards with acceptable mixing.
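The exclusion rule reduces to a short filter. A minimal sketch, assuming each shard reports its divergent-transition count and its post-warmup sample count pooled over its chains. The `ShardResult` container and the error raised when no shard survives are hypothetical.

```python
from dataclasses import dataclass

@dataclass
class ShardResult:
    """Hypothetical container for one shard's sampling summary."""
    shard_id: int
    n_divergent: int
    n_samples: int

MAX_DIVERGENCE_RATE = 0.10  # threshold stated in this ADR

def keep_for_consensus(results, max_rate=MAX_DIVERGENCE_RATE):
    """Keep shards whose divergence rate is at most max_rate; rate > max_rate is excluded."""
    kept = [r for r in results if r.n_divergent / r.n_samples <= max_rate]
    if not kept:
        raise RuntimeError("every shard exceeded the divergence threshold")
    return kept

results = [ShardResult(0, 12, 4000), ShardResult(1, 400, 4000), ShardResult(2, 900, 4000)]
print([r.shard_id for r in keep_for_consensus(results)])  # -> [0, 1]
```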
Consequences
Positive:
Scales to very large datasets by increasing the number of shards while keeping the shard size, and therefore the per-shard NUTS cost, fixed.
Full posterior uncertainty quantification, not just linearized NLSQ errors.
ArviZ diagnostics (\(\hat{R}\), ESS, BFMI, divergence fraction) provide quantitative quality assessment.
Negative / Accepted trade-offs:
CMC produces an approximation of the global posterior, not the exact posterior. The approximation quality depends on shard size (larger shards → better approximation).
CMC spawns worker processes via multiprocessing; startup overhead is about 1–2 seconds per worker.
Auto shard-size selection may be suboptimal for unusual dataset characteristics; users can override with
max_points_per_shard: <integer>.
Alternatives Considered
A. Standard NUTS on full dataset
Exact posterior. Rejected because: \(O(n)\) per step is intractable for \(n > 10^6\).
B. Minibatch MCMC (stochastic gradient MCMC)
Scales better than standard NUTS. Rejected because: stochastic gradient MCMC has known bias and is difficult to diagnose; the convergence guarantees are weaker.
C. Variational inference (ADVI)
Fast (\(O(n_\mathrm{iter} \cdot n_\mathrm{minibatch})\) for \(n_\mathrm{iter}\) optimization steps). Rejected because: mean-field ADVI systematically underestimates posterior variance for correlated parameters. This is a known failure mode for the \((D_0, \dot{\gamma}_0)\) correlation in the laminar-flow model.
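The variance underestimation can be made concrete for a Gaussian target. Take a factorized Gaussian \(q\) that minimizes \(\mathrm{KL}(q \,\|\, p)\). Each of its marginal variances equals \(1/\Lambda_{ii}\), where \(\Lambda\) is the target's precision matrix. For two parameters with correlation \(\rho\), this shrinks each variance by a factor \(1 - \rho^2\). The covariance values below are illustrative, not fitted \((D_0, \dot{\gamma}_0)\) values.

```python
def mean_field_variances(cov):
    """Marginal variances of the KL(q || p)-optimal mean-field Gaussian q
    for a bivariate Gaussian target with 2x2 covariance `cov`.

    Each equals 1 / Lambda_ii with Lambda = inv(cov); for a 2x2 matrix,
    Lambda_00 = d / det and Lambda_11 = a / det.
    """
    (a, b), (c, d) = cov
    det = a * d - b * c
    return det / d, det / a

# Illustrative target: standard deviations 2.0 and 0.5, correlation 0.9.
sd1, sd2, rho = 2.0, 0.5, 0.9
cov = [[sd1**2, rho * sd1 * sd2], [rho * sd1 * sd2, sd2**2]]
q_var = mean_field_variances(cov)
print("true variances:", sd1**2, sd2**2, " mean-field:", q_var)  # both shrink by 1 - rho**2
```

With \(\rho = 0.9\), the mean-field variances are only 19% of the true marginal variances.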
D. Sequential Monte Carlo (SMC)
SMC is asymptotically exact and handles multimodal posteriors well. Rejected for the initial version because: SMC requires many sequential passes over the data, which is harder to parallelize across shards than independent NUTS chains. To be reconsidered in future work.
See also
- Scott et al. 2016: original CMC paper (References and Citations)
- Computational Methods: NUTS and CMC algorithm details
- homodyne.optimization.cmc.backends.multiprocessing: implementation
- homodyne.optimization.cmc.sampler: SamplingPlan with adaptive scaling