homodyne.optimization.cmc.backends¶
The backends sub-package provides pluggable execution backends for CMC.
Each backend is responsible for distributing per-shard MCMC work across available
compute resources and combining the results.
| Backend | Description |
|---|---|
| | Spawns N worker processes (physical_cores/2 − 1); recommended for CPU |
| | JAX |
| | HPC batch scheduler backend (PBS/SLURM) |
| | Persistent worker pool; amortizes JAX init across shards |
| | Abstract |
Set backend_name: "auto" (default) to let Homodyne select the optimal
backend for the current environment.
Multiprocessing Backend¶
The recommended backend for CPU-based systems. Key architecture:
- Worker spawn: N = max(1, physical_cores // 2 − 1) worker processes are spawned using the spawn start method (required for JAX safety).
- Shared memory: SharedDataManager shares config, parameter space, time grids, and per-shard data arrays across workers via multiprocessing.shared_memory. The create_shared_shard_arrays() method places each shard’s numpy arrays (data, t1, t2, phi_unique, phi_indices) in shared memory, eliminating the per-process serialization overhead that the spawn mechanism would otherwise incur.
- XLA configuration: the parent process sets JAX_ENABLE_X64=1 in homodyne/__init__.py and cli/main.py before any JAX import. Each spawned worker also sets JAX_ENABLE_X64 and configures XLA_FLAGS with --xla_force_host_platform_device_count=4 before importing JAX, providing 4 virtual devices per worker for parallel chains. The redundant worker-side setting is required because spawn-mode workers start as fresh processes that do not inherit the parent’s jax.config state.
- Thread environment: OMP_NUM_THREADS is set to 1–2 per worker to prevent thread oversubscription. OMP_PROC_BIND and OMP_PLACES are cleared before spawning and restored afterwards.
- Adaptive polling: the manager adjusts the result-queue poll interval based on shard activity, reducing CPU spin.
- Batch PRNG: all shard random keys are pre-generated in a single JAX call before spawning, avoiding repeated JAX initialisation.
- LPT scheduling: shards are dispatched using noise-weighted Longest Processing Time first ordering. Cost is estimated as n_points × (1 + normalized_noise), so the most expensive shards are dispatched first to minimize tail latency.
- JIT compilation cache: workers call jax.config.update() to enable the persistent compilation cache with min_compile_time_secs=0. The first worker compiles all JIT functions; subsequent workers load from the disk cache (2.3× worker startup speedup).
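The LPT ordering described in the list above can be sketched in a few lines. This is a minimal illustration rather than Homodyne's implementation: the Shard dataclass, its noise field, and the normalization by the largest noise value are assumptions made for the example. Only the cost formula n_points × (1 + normalized_noise) comes from the description.

```python
from dataclasses import dataclass


@dataclass
class Shard:
    """Hypothetical shard descriptor (not Homodyne's own type)."""
    shard_id: int
    n_points: int
    noise: float  # assumed per-shard noise estimate


def lpt_order(shards):
    """Return shards sorted by descending estimated cost (LPT first)."""
    # Assumption: noise is normalized by the largest noise value seen.
    max_noise = max((s.noise for s in shards), default=0.0)

    def cost(s):
        normalized = s.noise / max_noise if max_noise > 0 else 0.0
        return s.n_points * (1.0 + normalized)

    return sorted(shards, key=cost, reverse=True)


shards = [
    Shard(0, n_points=1000, noise=0.1),
    Shard(1, n_points=1000, noise=0.4),
    Shard(2, n_points=3000, noise=0.2),
]
order = [s.shard_id for s in lpt_order(shards)]
print(order)  # largest shard first, then the noisier of the two small ones
```

Dispatching in this order means the longest-running shards start while every worker is still busy, so no single slow shard is left running alone at the end.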
- class homodyne.optimization.cmc.backends.multiprocessing.MultiprocessingBackend[source]
Bases: CMCBackend
CMC backend using Python multiprocessing.
Runs MCMC sampling in parallel across CPU cores using Python’s multiprocessing module.
- __init__(n_workers=None, spawn_method='spawn')[source]
Initialize multiprocessing backend.
- run(model, model_kwargs, config, shards=None, initial_values=None, parameter_space=None, analysis_mode='static', progress_bar=True)[source]
Run MCMC sampling across shards.
- Parameters:
model (Callable) – NumPyro model function.
config (CMCConfig) – CMC configuration.
initial_values (dict[str, float] | None) – Initial parameter values.
parameter_space (ParameterSpace | None) – Parameter space for priors.
analysis_mode (str) – Analysis mode.
progress_bar (bool) – Whether to show progress bar for shard completion.
- Returns:
Combined samples from all shards.
- Return type:
MCMCSamples
Chain Execution Methods¶
Controls how MCMC chains are executed within each worker process.
| Method | Best For | Description |
|---|---|---|
| | CPU multiprocessing (default) | Uses |
| | Single-process only | Uses |
| | Debugging | Runs chains one at a time. No parallelism. |
Warning
chain_method: "vectorized" is NOT recommended for the multiprocessing
backend. Workers drop to 1–2 active CPUs because vmap does not distribute
across the 4 virtual JAX devices. Always use "parallel" in production.
XLA Device Setup¶
Each worker configures XLA before importing JAX:
# Executed inside each worker process (before JAX import)
import os
os.environ["JAX_ENABLE_X64"] = "1"
os.environ["XLA_FLAGS"] = (
    "--xla_force_host_platform_device_count=4"
)
import jax # JAX sees 4 virtual CPU devices
This gives each worker 4 virtual devices for parallel chain execution.
The parent process restores its original environment after all workers have
been spawned. After importing JAX, each worker also enables the persistent
JIT compilation cache:
import jax
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
Warning
In JAX 0.8+, os.environ["JAX_COMPILATION_CACHE_DIR"] alone does NOT
enable the persistent cache. The jax.config.update() call is required.
optimization:
  cmc:
    backend_config:
      name: "auto"
    per_shard_mcmc:
      chain_method: "parallel"
      num_chains: 4
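The set-then-restore handling of environment variables around worker spawning can be sketched with a small context manager. This is an illustration under stated assumptions, not the backend's actual code: temporary_env is a hypothetical helper name, and the example only demonstrates the pattern of changing os.environ in the parent and restoring the original values afterwards.

```python
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(updates):
    """Apply environment changes inside the block, then restore them.

    `updates` maps a variable name to a string value, or to None to unset it.
    (Hypothetical helper; the name is not part of Homodyne.)
    """
    saved = {name: os.environ.get(name) for name in updates}
    try:
        for name, value in updates.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        yield
    finally:
        # Put every touched variable back to its original state.
        for name, old in saved.items():
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old


os.environ["OMP_PROC_BIND"] = "true"  # pretend the user's shell set this
with temporary_env({
    "XLA_FLAGS": "--xla_force_host_platform_device_count=4",
    "OMP_PROC_BIND": None,
}):
    # Processes started here inherit the modified environment.
    inside = (os.environ.get("XLA_FLAGS"), os.environ.get("OMP_PROC_BIND"))
after = os.environ.get("OMP_PROC_BIND")
```

Restoring in a finally clause keeps the parent's environment intact even if spawning a worker raises.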
Heterogeneity Detection¶
The multiprocessing backend uses bounds-aware coefficient of variation (CV) to detect heterogeneous shards before combining posteriors.
For near-zero parameters (e.g., \(\dot\gamma_0 \sim 10^{-3}\)), dividing the standard deviation by the mean would artificially inflate the CV and therefore the heterogeneity score. Instead, the scale is computed relative to the parameter’s bounds range:
# For near-zero params: scale = param_range * 0.01
# Falls back to: scale = max(abs(mean), 1e-10)
cv = std / scale
Shards with abnormally high cross-shard CV are flagged for review and may be excluded before the consensus combination step.
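A runnable sketch of the bounds-aware scale is given below. Several details are assumptions made for illustration: the function name bounds_aware_cv, the use of per-shard posterior means as input, the population standard deviation, and the rule that a parameter counts as near zero when |mean| is below 1% of its bounds range. Only the two scale formulas come from the snippet above.

```python
import statistics


def bounds_aware_cv(shard_means, lower, upper, near_zero_frac=0.01):
    """Cross-shard CV using a bounds-based scale for near-zero parameters."""
    mean = statistics.fmean(shard_means)
    std = statistics.pstdev(shard_means)
    param_range = upper - lower
    # Assumption: "near zero" means |mean| < near_zero_frac * param_range.
    if abs(mean) < near_zero_frac * param_range:
        scale = param_range * 0.01
    else:
        scale = max(abs(mean), 1e-10)
    return std / scale


# Parameter of order 1e-3 inside bounds [0, 1]: the bounds-based scale applies.
cv_small = bounds_aware_cv([0.9e-3, 1.0e-3, 1.1e-3], lower=0.0, upper=1.0)

# Parameter far from zero: this reduces to the ordinary std / |mean|.
cv_large = bounds_aware_cv([90.0, 100.0, 110.0], lower=0.0, upper=1000.0)

print(cv_small, cv_large)
```

Both inputs have the same relative spread, but the near-zero parameter is scored against 1% of its bounds range rather than its tiny mean, so it receives the smaller CV.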
Worker Pool¶
The worker_pool module provides a persistent worker pool that avoids
repeated JAX re-initialization overhead. Workers are spawned once, process
multiple shards via task queues, and shut down when the pool context exits.
Execution falls back to per-shard process spawning when n_shards < 3, where pool
overhead exceeds the amortization benefit.
- homodyne.optimization.cmc.backends.worker_pool.should_use_pool(n_shards, n_workers)[source]
Determine if worker pool is beneficial.
- class homodyne.optimization.cmc.backends.worker_pool.WorkerPool[source]
Bases: object
Persistent process pool for CMC shard dispatch.
Workers are spawned once, process multiple tasks via queues, and shut down when the pool is no longer needed.
- Parameters:
n_workers (int) – Number of persistent worker processes.
worker_fn (Callable[..., dict[str, Any] | None]) – Function each worker calls per task. Signature: worker_fn(task: dict, **init_kwargs) -> dict | None. Must be picklable (module-level function). If it returns None, the pool does not put a result on the result queue (useful when the worker manages its own queue).
worker_init_kwargs (dict[str, Any]) – One-time kwargs passed to every worker_fn call.
worker_init_fn (Callable[..., None] | None) – Optional one-time initialization function called once per worker before the event loop starts. Signature: worker_init_fn(worker_id: int, **init_kwargs) -> None. Use for expensive setup like JAX/OMP initialization.
- __init__(n_workers, worker_fn, worker_init_kwargs, worker_init_fn=None, startup_timeout=120.0)[source]
- property n_workers: int
Number of worker processes.
- property result_queue: Queue
The shared result queue drained by the parent.
- submit(task)[source]
Submit a task to the next available worker (round-robin).
- get_result(timeout=300.0)[source]
Block until a result is available.
- Parameters:
timeout (float) – Maximum seconds to wait.
- Returns:
Result from a worker.
- Return type:
- Raises:
queue.Empty – If no result within timeout.
Base Backend and Combination Utilities¶
- class homodyne.optimization.cmc.backends.base.CMCBackend[source]
Bases: ABC
Abstract base class for CMC execution backends.
Backends handle the parallel execution of MCMC sampling across data shards and the combination of results.
- abstractmethod run(model, model_kwargs, config, shards=None)[source]
Run MCMC sampling (potentially across shards).
- Parameters:
- Returns:
Combined samples from all shards.
- Return type:
MCMCSamples
- homodyne.optimization.cmc.backends.base.combine_shard_samples(shard_samples, method='weighted_gaussian', chunk_size=500)[source]
Combine samples from multiple shards.
For K <= chunk_size shards, uses a single-pass combination.
For K > chunk_size shards (hierarchical mode), accumulates posterior moments (mean, variance) across chunks without drawing intermediate synthetic samples. A single Gaussian draw is performed at the end from the aggregated moments. This avoids the precision-multiplication artefact that arises when recursive combination re-applies precision-weighting to synthetically drawn intermediate samples (P1-R6-01).
Memory scaling:
Each shard result: ~100KB (13 params x 4 chains x 1500 samples x 8 bytes)
Hierarchical (chunk_size=500): processes at most chunk_size shards at once (~50MB), then releases them. Moment accumulation uses O(n_params) space.
- Parameters:
shard_samples (list[MCMCSamples]) – Samples from each shard.
method (str) – Combination method: “robust_consensus_mc” (recommended), “consensus_mc”, “weighted_gaussian”, “simple_average”, or “auto”.
chunk_size (int) – Number of shards to process per chunk for hierarchical combination. Default 500 keeps peak memory under ~50MB per processing step.
- Returns:
Combined samples.
- Return type:
MCMCSamples
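The moment-accumulation idea behind the hierarchical mode can be sketched for a single scalar parameter. This is a simplified illustration, not the library's code: the function name accumulate_moments and the (mean, variance) tuple input are assumptions, and the standard precision-weighted Gaussian rule is used here (combined precision is the sum of shard precisions, combined mean is the precision-weighted average).

```python
import math
import random


def accumulate_moments(shard_moments, chunk_size=500):
    """Accumulate precision-weighted moments chunk by chunk.

    shard_moments: list of (mean, variance) tuples for one parameter.
    Returns the combined (mean, variance).
    """
    total_precision = 0.0
    total_weighted_mean = 0.0
    for start in range(0, len(shard_moments), chunk_size):
        chunk = shard_moments[start:start + chunk_size]
        # Only running sums leave the chunk; no synthetic samples are drawn.
        for mean, var in chunk:
            precision = 1.0 / var
            total_precision += precision
            total_weighted_mean += precision * mean
    combined_var = 1.0 / total_precision
    return total_weighted_mean * combined_var, combined_var


rng = random.Random(0)
moments = [(1.0 + 0.01 * rng.gauss(0, 1), 0.04) for _ in range(1200)]

mean_chunked, var_chunked = accumulate_moments(moments, chunk_size=500)
mean_single, var_single = accumulate_moments(moments, chunk_size=len(moments))

# A single Gaussian draw at the very end, from the aggregated moments.
draws = [rng.gauss(mean_chunked, math.sqrt(var_chunked)) for _ in range(1000)]
```

Because only sums are carried between chunks, the chunked result matches the single-pass result, and each shard's precision is counted exactly once.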
Usage Examples¶
Selecting the multiprocessing backend explicitly¶
optimization:
  cmc:
    backend_config:
      name: "multiprocessing"
    per_shard_mcmc:
      chain_method: "parallel"
      num_chains: 4
Inspecting worker count¶
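import psutil
from homodyne.optimization.cmc.backends.multiprocessing import (
    MultiprocessingBackend,
)

# __init__ accepts n_workers and spawn_method; n_workers defaults to None
backend = MultiprocessingBackend(n_workers=None)

# Reproduce the documented default: max(1, physical_cores // 2 - 1)
physical_cores = psutil.cpu_count(logical=False) or 1  # cpu_count may return None
n_workers = max(1, physical_cores // 2 - 1)
print(f"Worker processes: {n_workers}")
print("Virtual devices/worker: 4 (via xla_force_host_platform_device_count)")
print(f"Total parallel chains: {n_workers * 4}")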
Manually running a backend¶
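from homodyne.optimization.cmc.backends import select_backend
from homodyne.optimization.cmc.config import CMCConfig

config = CMCConfig()
backend = select_backend(config.backend_name, config)

# xpcs_model, model_kwargs, prepared_shards, init_vals, and param_space
# are assumed to be defined by the caller.
# model, model_kwargs, and config are required arguments of run().
result = backend.run(
    model=xpcs_model,
    model_kwargs=model_kwargs,
    config=config,
    shards=prepared_shards,
    initial_values=init_vals,
    parameter_space=param_space,
)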
Multiprocessing backend for CMC execution.
This module provides parallel MCMC execution using Python’s multiprocessing module for CPU-based parallelism.
Optimizations (v2.9.1):
- Batch PRNG key generation: Pre-generate all shard keys in single JAX call
- Adaptive polling: Adjust poll interval based on shard activity
- Event.wait heartbeat: Efficient heartbeat using Event.wait(timeout)
Optimizations (v2.22.2):
- LPT scheduling: Dispatch highest-cost shards first (size + noise weighted)
- Per-shard shared memory: Shard arrays stored in shared memory (avoids pickle overhead)
- deque for pending shards: O(1) popleft instead of O(n) list.pop(0)
- JIT cache fix: Enable persistent compilation cache via jax.config.update (env var alone insufficient in JAX 0.8+, min_compile_time lowered to 0)