homodyne.optimization.cmc.backends

The backends sub-package provides pluggable execution backends for CMC. Each backend is responsible for distributing per-shard MCMC work across available compute resources and combining the results.

Backend          Description
multiprocessing  Spawns N = max(1, physical_cores // 2 - 1) worker processes; recommended for CPU
pjit             JAX pjit-based backend for single-host multi-device execution
pbs              HPC batch scheduler backend (PBS/SLURM)
worker_pool      Persistent worker pool; amortizes JAX initialization across shards
base             Abstract CMCBackend base class and shard combination utilities

Set backend_name: "auto" (default) to let Homodyne select the optimal backend for the current environment.


Multiprocessing Backend

The recommended backend for CPU-based systems. Key architectural features:

  1. Worker spawn: N = max(1, physical_cores // 2 - 1) worker processes are spawned using the spawn start method (required for JAX safety).

  2. Shared memory: SharedDataManager shares config, parameter space, time grids, and per-shard data arrays across workers via multiprocessing.shared_memory. The create_shared_shard_arrays() method places each shard's numpy arrays (data, t1, t2, phi_unique, phi_indices) in shared memory, eliminating per-process serialization overhead through the spawn mechanism.

  3. XLA configuration: the parent process sets JAX_ENABLE_X64=1 in homodyne/__init__.py and cli/main.py before any JAX import. Each spawned worker also sets JAX_ENABLE_X64 and configures XLA_FLAGS with --xla_force_host_platform_device_count=4 before importing JAX, providing 4 virtual devices per worker for parallel chains. The redundant worker-side setting is required because spawn-mode workers start as fresh processes that do not inherit the parent's jax.config state.

  4. Thread environment: OMP_NUM_THREADS is set to 1–2 per worker to prevent thread oversubscription. OMP_PROC_BIND and OMP_PLACES are cleared before spawning and restored afterwards.

  5. Adaptive polling: the manager adjusts the result-queue poll interval based on shard activity, reducing CPU spin.

  6. Batch PRNG: all shard random keys are pre-generated in a single JAX call before spawning, avoiding repeated JAX initialisation.

  7. LPT scheduling: shards are dispatched in noise-weighted Longest Processing Time (LPT) first order. Cost is estimated as n_points × (1 + normalized_noise), and the most expensive shards are dispatched first to minimize tail latency.

  8. JIT compilation cache: workers call jax.config.update() to enable the persistent compilation cache with min_compile_time_secs=0. The first worker compiles all JIT functions; subsequent workers load them from the disk cache (2.3× worker startup speedup).
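The LPT ordering in item 7 can be sketched in a few lines. This is an illustration, not homodyne's code: the shard dict keys (id, n_points, noise_scale) and the choice to normalize noise by the largest noise_scale across shards are assumptions; only the cost formula and the deque-based dispatch come from this page.

```python
from collections import deque


def lpt_order(shards):
    """Return shards in Longest-Processing-Time-first order as a deque.

    cost = n_points * (1 + normalized_noise); normalizing noise_scale by its
    maximum across shards is an assumption of this sketch.
    """
    max_noise = max(s["noise_scale"] for s in shards) or 1.0  # avoid division by 0

    def cost(shard):
        return shard["n_points"] * (1.0 + shard["noise_scale"] / max_noise)

    # Most expensive first; a deque gives O(1) popleft() at dispatch time.
    return deque(sorted(shards, key=cost, reverse=True))


# Hypothetical shards: costs are 1250, 1200 and 2000 respectively.
shards = [
    {"id": 0, "n_points": 1000, "noise_scale": 0.1},
    {"id": 1, "n_points": 1200, "noise_scale": 0.0},
    {"id": 2, "n_points": 1000, "noise_scale": 0.4},
]
pending = lpt_order(shards)
dispatch_order = []
while pending:
    dispatch_order.append(pending.popleft()["id"])
print(dispatch_order)  # [2, 0, 1]
```

Note that the noisier shard 2 overtakes the larger but noise-free shard 1; dispatching it first keeps it from becoming the last straggler.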

class homodyne.optimization.cmc.backends.multiprocessing.MultiprocessingBackend

Bases: CMCBackend

CMC backend using Python multiprocessing.

Runs MCMC sampling in parallel across CPU cores using Python’s multiprocessing module.

__init__(n_workers=None, spawn_method='spawn')

Initialize multiprocessing backend.

Parameters:
  • n_workers (int | None) – Number of worker processes. If None, uses CPU count.

  • spawn_method (str) – Process start method: “spawn”, “fork”, or “forkserver”.

get_name()

Get backend name.

Return type:

str

run(model, model_kwargs, config, shards=None, initial_values=None, parameter_space=None, analysis_mode='static', progress_bar=True)

Run MCMC sampling across shards.

Parameters:
  • model (Callable) – NumPyro model function.

  • model_kwargs (dict[str, Any]) – Common model arguments.

  • config (CMCConfig) – CMC configuration.

  • shards (list[PreparedData] | None) – Data shards.

  • initial_values (dict[str, float] | None) – Initial parameter values.

  • parameter_space (ParameterSpace | None) – Parameter space for priors.

  • analysis_mode (str) – Analysis mode.

  • progress_bar (bool) – Whether to show progress bar for shard completion.

Returns:

Combined samples from all shards.

Return type:

MCMCSamples

is_available()

Check if multiprocessing is available.

Return type:

bool

SharedDataManager

class homodyne.optimization.cmc.backends.multiprocessing.SharedDataManager

Bases: object

Manages shared memory blocks for data common to all CMC shards.

Uses multiprocessing.shared_memory to share config, parameter space, initial values, and time_grid across spawned worker processes, avoiding redundant pickling per shard.

Note on serialization: Uses pickle internally for trusted config dicts only (CMCConfig.to_dict(), ParameterSpace). This matches the existing multiprocessing behavior which also pickles all process arguments.

Use it as a context manager, or call cleanup() in a finally block.

__init__()

create_shared_bytes(name, data)

Store bytes in shared memory.

Return type:

dict[str, Any]

create_shared_array(name, array)

Store a numpy array in shared memory.

Return type:

dict[str, Any]

create_shared_dict(name, d)

Serialize a trusted internal dict to shared memory.

Only used for CMCConfig and ParameterSpace dicts — never for external/untrusted data.

Return type:

dict[str, Any]

create_shared_shard_arrays(shard_data_list)

Place per-shard numpy arrays into shared memory (packed format).

Instead of creating one SharedMemory segment per array per shard (n_shards × 5 segments, i.e. thousands of file descriptors for large runs), this method concatenates all shard arrays for each key into a single shared memory block. Only 5 SharedMemory segments are created, regardless of shard count.

Parameters:

shard_data_list (list[dict[str, Any]]) – List of shard data dicts, each containing numpy arrays (data, t1, t2, phi_unique, phi_indices) and a scalar noise_scale.

Returns:

List of lightweight shard references (shm names + offsets). Each ref dict is small enough to serialize cheaply through spawn.

Return type:

list[dict[str, Any]]
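A minimal sketch of the packed layout, handling one key at a time. pack_key and read_ref are hypothetical helpers, and the ref fields (shm_name, dtype, offset, shape) are illustrative rather than homodyne's actual ref format. Calling pack_key once per key (data, t1, t2, phi_unique, phi_indices) would yield the five segments described above.

```python
import numpy as np
from multiprocessing import shared_memory


def pack_key(shard_arrays):
    """Pack one key's arrays from every shard into a single SharedMemory block.

    Returns the block and one small ref dict per shard (name, dtype, offset, shape).
    """
    arrays = [np.ascontiguousarray(a) for a in shard_arrays]
    dtype = np.result_type(*arrays)
    total = sum(a.size for a in arrays)
    shm = shared_memory.SharedMemory(create=True, size=max(total * dtype.itemsize, 1))
    packed = np.ndarray((total,), dtype=dtype, buffer=shm.buf)
    refs, offset = [], 0
    for a in arrays:
        packed[offset:offset + a.size] = a.ravel()
        refs.append({"shm_name": shm.name, "dtype": dtype.str,
                     "offset": offset, "shape": a.shape})
        offset += a.size
    del packed  # release the buffer view so the block can be closed later
    return shm, refs


def read_ref(ref):
    """Worker side: attach by name and copy out one shard's array."""
    shm = shared_memory.SharedMemory(name=ref["shm_name"])
    dtype = np.dtype(ref["dtype"])
    size = int(np.prod(ref["shape"]))
    view = np.ndarray((size,), dtype=dtype, buffer=shm.buf,
                      offset=ref["offset"] * dtype.itemsize)  # offset is in bytes
    out = view.reshape(ref["shape"]).copy()
    del view  # all views must be gone before close()
    shm.close()
    return out


shard_arrays = [np.arange(6.0).reshape(2, 3), np.arange(4.0)]
shm, refs = pack_key(shard_arrays)
try:
    restored = [read_ref(r) for r in refs]
finally:
    shm.close()
    shm.unlink()  # the creating side frees the segment, as cleanup() does
print([r["offset"] for r in refs], restored[1])
```

Each ref is a handful of scalars, so it pickles cheaply through spawn no matter how large the shard arrays are.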

cleanup()

Release all shared memory blocks. Must be called in a finally block.

Return type:

None


Chain Execution Methods

Controls how MCMC chains are executed within each worker process.

Method        Best for                       Description
"parallel"    CPU multiprocessing (default)  Uses pmap across 4 virtual JAX devices per worker. Achieves full CPU utilisation with the multiprocessing backend.
"vectorized"  Single-process only            Uses vmap on a single device. Empirically 20× slower with the multiprocessing backend (101 s vs 4.9 s wall time).
"sequential"  Debugging                      Runs chains one at a time. No parallelism.

Warning

chain_method: "vectorized" is NOT recommended for the multiprocessing backend. Workers drop to 1–2 active CPUs because vmap does not distribute across the 4 virtual JAX devices. Always use "parallel" in production.
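To make the pmap/vmap distinction concrete, here is a toy sketch, not homodyne code: four independent "chains" (a Gaussian random walk standing in for NUTS) run once with jax.pmap, one chain per virtual device, and once with jax.vmap, all batched on a single device. It assumes a CPU-only JAX install; JAX_PLATFORMS=cpu is set so pmap targets the virtual host devices.

```python
import os

# Must run before JAX is imported, mirroring the worker-side setup.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ["JAX_ENABLE_X64"] = "1"

import jax
import jax.numpy as jnp

N_CHAINS = 4


def toy_chain(key):
    """Stand-in for one MCMC chain: endpoint of a 1000-step Gaussian random walk."""
    steps = jax.random.normal(key, (1000,))
    return jnp.cumsum(steps)[-1]


keys = jax.random.split(jax.random.PRNGKey(0), N_CHAINS)

# chain_method="parallel" analogue: one chain per virtual device.
parallel_out = jax.pmap(toy_chain)(keys)

# chain_method="vectorized" analogue: all chains batched on one device.
vectorized_out = jax.vmap(toy_chain)(keys)

print("devices:", jax.local_device_count())
print("parallel:", parallel_out)
print("vectorized:", vectorized_out)
```

Both calls produce the same draws for the same keys; what differs is where the work runs, which is why vmap leaves three of the four virtual devices idle.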


XLA Device Setup

Each worker configures XLA before importing JAX:

# Executed inside each worker process (before JAX import)
import os
os.environ["JAX_ENABLE_X64"] = "1"
os.environ["XLA_FLAGS"] = (
    "--xla_force_host_platform_device_count=4"
)
import jax   # JAX sees 4 virtual CPU devices

This gives each worker 4 virtual devices for parallel chain execution. The parent process restores its original environment after all workers have been spawned. After importing JAX, each worker also enables the persistent JIT compilation cache:

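import os
import jax

# cache_dir: any writable directory; this path is only an example
cache_dir = os.path.expanduser("~/.cache/jax_compilation_cache")
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)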

Warning

In JAX 0.8+, os.environ["JAX_COMPILATION_CACHE_DIR"] alone does NOT enable the persistent cache. The jax.config.update() call is required.

Example configuration (automatic backend selection, parallel chains):

optimization:
  cmc:
    backend_config:
      name: "auto"
    per_shard_mcmc:
      chain_method: "parallel"
      num_chains: 4

Heterogeneity Detection

The multiprocessing backend uses a bounds-aware coefficient of variation (CV) to detect heterogeneous shards before combining posteriors.

For near-zero parameters (e.g., \(\dot\gamma_0 \sim 10^{-3}\)), dividing the standard deviation by the mean would artificially inflate the CV and hence the heterogeneity score. Instead, the scale is computed relative to the parameter's bounds range:

# For near-zero params: scale = param_range * 0.01
# Falls back to: scale = max(abs(mean), 1e-10)
cv = std / scale

Shards with abnormally high cross-shard CV are flagged for review and may be excluded before the consensus combination step.
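The scale rule can be sketched as a small function. Several details are assumptions, not taken from homodyne: the CV is computed over per-shard posterior means, the standard deviation uses ddof=1, and the near_zero_frac test deciding when a parameter counts as "near zero" is a hypothetical criterion. Only the two scale formulas come from the pseudocode above.

```python
import numpy as np


def bounds_aware_cv(shard_means, lower, upper, near_zero_frac=0.01):
    """Cross-shard coefficient of variation for one parameter.

    shard_means: per-shard posterior means of the parameter.
    lower, upper: the parameter's prior bounds.
    near_zero_frac: hypothetical threshold; |mean| below
        near_zero_frac * (upper - lower) counts as "near zero".
    """
    shard_means = np.asarray(shard_means, dtype=float)
    mean = shard_means.mean()
    std = shard_means.std(ddof=1)  # sample std across shards (assumed)
    param_range = upper - lower
    if abs(mean) < near_zero_frac * param_range:
        scale = param_range * 0.01  # bounds-aware scale for near-zero params
    else:
        scale = max(abs(mean), 1e-10)  # ordinary CV denominator, guarded
    return std / scale


# Hypothetical gamma_dot_0-like parameter with bounds [0, 1]
means = [1.0e-3, 2.0e-3, 1.5e-3]
naive_cv = np.std(means, ddof=1) / np.mean(means)
cv = bounds_aware_cv(means, lower=0.0, upper=1.0)
print(f"naive CV: {naive_cv:.3f}, bounds-aware CV: {cv:.3f}")
# naive CV: 0.333, bounds-aware CV: 0.050
```

The shard means differ by only 5e-4 in absolute terms, yet the naive CV reports 33% spread; measuring against the bounds range keeps such shards from being flagged.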


Worker Pool

The worker_pool module provides a persistent worker pool that avoids repeated JAX re-initialization overhead. Workers are spawned once, process multiple shards via task queues, and shut down when the pool context exits.

When n_shards < 3, execution falls back to per-shard process spawning, because pool overhead would exceed the amortization benefit.

homodyne.optimization.cmc.backends.worker_pool.should_use_pool(n_shards, n_workers)

Determine if worker pool is beneficial.

Parameters:
  • n_shards (int) – Total shards to process.

  • n_workers (int) – Available worker count.

Returns:

True if pool amortization outweighs overhead.

Return type:

bool

class homodyne.optimization.cmc.backends.worker_pool.WorkerPool

Bases: object

Persistent process pool for CMC shard dispatch.

Workers are spawned once, process multiple tasks via queues, and shut down when the pool is no longer needed.

Parameters:
  • n_workers (int) – Number of persistent worker processes.

  • worker_fn (Callable[..., dict[str, Any] | None]) – Function each worker calls per task. Signature: worker_fn(task: dict, **init_kwargs) -> dict | None. Must be picklable (module-level function). If it returns None, the pool does not put a result on the result queue (useful when the worker manages its own queue).

  • worker_init_kwargs (dict[str, Any]) – Keyword arguments fixed once at pool construction and passed to every worker_fn call.

  • worker_init_fn (Callable[..., None] | None) – Optional one-time initialization function called once per worker before the event loop starts. Signature: worker_init_fn(worker_id: int, **init_kwargs) -> None. Use for expensive setup like JAX/OMP initialization.

__init__(n_workers, worker_fn, worker_init_kwargs, worker_init_fn=None, startup_timeout=120.0)
property n_workers: int

Number of worker processes.

property result_queue: Queue

The shared result queue drained by the parent.

is_alive()

Check if pool has active workers.

Return type:

bool

submit(task)

Submit a task to the next available worker (round-robin).

Parameters:

task (dict[str, Any]) – Task payload with at minimum a task_id key.

Return type:

None

get_result(timeout=300.0)

Block until a result is available.

Parameters:

timeout (float) – Maximum seconds to wait.

Returns:

Result from a worker.

Return type:

dict[str, Any]

Raises:

queue.Empty – If no result within timeout.

results_pending()

Check if results are available without blocking.

Return type:

bool

shutdown(timeout=10.0)

Send shutdown sentinels and join all workers.

Parameters:

timeout (float) – Maximum seconds to wait per worker.

Return type:

None
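The task-queue protocol documented above (per-worker task queues, round-robin submit, a shared result queue, sentinel shutdown, and None results being skipped) can be sketched as follows. To keep the example self-contained and portable, threads stand in for the spawned processes, so it shows the protocol only, not the JAX-initialization savings. ThreadWorkerPool and square_task are hypothetical names, not homodyne API.

```python
import queue
import threading
from typing import Any, Callable, Dict, List, Optional

_SENTINEL = None  # placed on each task queue to stop its worker


class ThreadWorkerPool:
    """Thread-based illustration of WorkerPool's queue protocol (hypothetical)."""

    def __init__(
        self,
        n_workers: int,
        worker_fn: Callable[..., Optional[Dict[str, Any]]],
        worker_init_kwargs: Dict[str, Any],
    ) -> None:
        self._task_queues: List[queue.Queue] = [queue.Queue() for _ in range(n_workers)]
        self._result_queue: queue.Queue = queue.Queue()
        self._next = 0
        self._threads = [
            threading.Thread(target=self._loop, args=(q, worker_fn, worker_init_kwargs), daemon=True)
            for q in self._task_queues
        ]
        for t in self._threads:
            t.start()

    def _loop(self, task_queue, worker_fn, init_kwargs) -> None:
        # Long-lived event loop: one worker handles many tasks until the sentinel.
        while True:
            task = task_queue.get()
            if task is _SENTINEL:
                return
            result = worker_fn(task, **init_kwargs)
            if result is not None:  # None: the worker reported by other means
                self._result_queue.put(result)

    def submit(self, task: Dict[str, Any]) -> None:
        # Round-robin over the per-worker task queues.
        self._task_queues[self._next].put(task)
        self._next = (self._next + 1) % len(self._task_queues)

    def get_result(self, timeout: float = 300.0) -> Dict[str, Any]:
        # Raises queue.Empty if nothing arrives within `timeout` seconds.
        return self._result_queue.get(timeout=timeout)

    def shutdown(self, timeout: float = 10.0) -> None:
        for q in self._task_queues:
            q.put(_SENTINEL)
        for t in self._threads:
            t.join(timeout=timeout)

    def is_alive(self) -> bool:
        return any(t.is_alive() for t in self._threads)


def square_task(task: Dict[str, Any], scale: float) -> Dict[str, Any]:
    """Module-level task function (must be picklable in the real, process-based pool)."""
    return {"task_id": task["task_id"], "value": task["x"] ** 2 * scale}


pool = ThreadWorkerPool(2, square_task, {"scale": 10})
for i in range(5):
    pool.submit({"task_id": i, "x": i})
results = sorted((pool.get_result(timeout=5.0) for _ in range(5)), key=lambda r: r["task_id"])
pool.shutdown()
print([r["value"] for r in results])  # [0, 10, 40, 90, 160]
```

Results arrive in completion order, not submission order, which is why each task carries a task_id.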


Base Backend and Combination Utilities

class homodyne.optimization.cmc.backends.base.CMCBackend

Bases: ABC

Abstract base class for CMC execution backends.

Backends handle the parallel execution of MCMC sampling across data shards and the combination of results.

abstractmethod run(model, model_kwargs, config, shards=None)

Run MCMC sampling (potentially across shards).

Parameters:
  • model (Callable) – NumPyro model function.

  • model_kwargs (dict[str, Any]) – Common model arguments.

  • config (CMCConfig) – CMC configuration.

  • shards (list[PreparedData] | None) – Data shards for parallel execution. If None, runs single-threaded on full data.

Returns:

Combined samples from all shards.

Return type:

MCMCSamples

abstractmethod get_name()

Get backend name.

Return type:

str

is_available()

Check if backend is available.

Returns:

True if backend can be used.

Return type:

bool

homodyne.optimization.cmc.backends.base.combine_shard_samples(shard_samples, method='weighted_gaussian', chunk_size=500)

Combine samples from multiple shards.

With K shards in total (K = len(shard_samples)) and K <= chunk_size, a single-pass combination is used.

For K > chunk_size shards (hierarchical mode), accumulates posterior moments (mean, variance) across chunks without drawing intermediate synthetic samples. A single Gaussian draw is performed at the end from the aggregated moments. This avoids the precision-multiplication artefact that arises when recursive combination re-applies precision-weighting to synthetically drawn intermediate samples (P1-R6-01).

Memory scaling:

  • Each shard result: ~624 KB (13 params × 4 chains × 1500 samples × 8 bytes).

  • Hierarchical (chunk_size=500): processes at most chunk_size shards at once (~300 MB at ~624 KB per shard), then releases them. Moment accumulation uses O(n_params) space.

Parameters:
  • shard_samples (list[MCMCSamples]) – Samples from each shard.

  • method (str) – Combination method: “robust_consensus_mc” (recommended), “consensus_mc”, “weighted_gaussian”, “simple_average”, or “auto”.

  • chunk_size (int) – Number of shards to process per chunk for hierarchical combination. The default of 500 bounds each processing step to at most 500 shard results held in memory at once.

Returns:

Combined samples.

Return type:

MCMCSamples
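The hierarchical mode can be illustrated with a diagonal-Gaussian, precision-weighted combination, the standard consensus Monte Carlo rule for Gaussian subposteriors. This is a sketch under that assumption: the library's weighted_gaussian and robust_consensus_mc methods may weight or robustify differently, and shards here are plain (n_samples, n_params) arrays rather than MCMCSamples. In the library, chunking bounds how many shard results are held at once; here the list is already in memory, so the sketch only shows that chunked moment accumulation matches a single pass and that one draw is made at the end.

```python
import numpy as np


def chunk_moments(chunk):
    """Precision sum and precision-weighted mean sum for one chunk of shards.

    Each shard is an array of posterior draws with shape (n_samples, n_params).
    """
    precision_sum = 0.0
    weighted_sum = 0.0
    for samples in chunk:
        precision = 1.0 / samples.var(axis=0, ddof=1)
        precision_sum = precision_sum + precision
        weighted_sum = weighted_sum + precision * samples.mean(axis=0)
    return precision_sum, weighted_sum


def combine_gaussian(shard_samples, chunk_size=500, n_draws=1000, seed=0):
    """Accumulate moments chunk by chunk, then draw once from the result."""
    precision_total = 0.0
    weighted_total = 0.0
    for start in range(0, len(shard_samples), chunk_size):
        p, w = chunk_moments(shard_samples[start:start + chunk_size])
        precision_total = precision_total + p  # O(n_params) running state
        weighted_total = weighted_total + w
    var = 1.0 / precision_total
    mean = weighted_total * var
    rng = np.random.default_rng(seed)  # single Gaussian draw at the very end
    draws = rng.normal(mean, np.sqrt(var), size=(n_draws, mean.size))
    return mean, var, draws


rng = np.random.default_rng(1)
shards = [rng.normal([1.0, -2.0], [0.5, 0.2], size=(1500, 2)) for _ in range(7)]
m_single, v_single, draws = combine_gaussian(shards, chunk_size=500)
m_chunked, v_chunked, _ = combine_gaussian(shards, chunk_size=3)
print(m_single, m_chunked)
```

Because precision sums are associative, any chunk size yields the same moments; no synthetic intermediate samples are ever re-weighted.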


Usage Examples

Selecting the multiprocessing backend explicitly

optimization:
  cmc:
    backend_config:
      name: "multiprocessing"
    per_shard_mcmc:
      chain_method: "parallel"
      num_chains: 4

Inspecting worker count

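import psutil
from homodyne.optimization.cmc.backends.multiprocessing import (
    MultiprocessingBackend,
)

# __init__(n_workers=None, spawn_method='spawn'); defaults are used here
backend = MultiprocessingBackend()
# psutil.cpu_count(logical=False) can return None on some platforms
physical_cores = psutil.cpu_count(logical=False) or 1
n_workers = max(1, physical_cores // 2 - 1)
print(f"Backend: {backend.get_name()}")
print(f"Worker processes: {n_workers}")
print("Virtual devices/worker: 4 (via xla_force_host_platform_device_count)")
print(f"Total parallel chains: {n_workers * 4}")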

Manually running a backend

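from homodyne.optimization.cmc.backends import select_backend
from homodyne.optimization.cmc.config import CMCConfig

config = CMCConfig()
backend = select_backend(config.backend_name, config)

# xpcs_model, model_kwargs, prepared_shards, init_vals and param_space
# come from earlier pipeline steps (model construction and data sharding).
# model, model_kwargs and config are required arguments of run().
result = backend.run(
    model=xpcs_model,
    model_kwargs=model_kwargs,
    config=config,
    shards=prepared_shards,
    initial_values=init_vals,
    parameter_space=param_space,
)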

Multiprocessing backend for CMC execution.

This module provides parallel MCMC execution using Python’s multiprocessing module for CPU-based parallelism.

Optimizations (v2.9.1):
  • Batch PRNG key generation: pre-generate all shard keys in a single JAX call

  • Adaptive polling: adjust the poll interval based on shard activity

  • Event.wait heartbeat: efficient heartbeat using Event.wait(timeout)

Optimizations (v2.22.2):
  • LPT scheduling: dispatch highest-cost shards first (size + noise weighted)

  • Per-shard shared memory: shard arrays stored in shared memory (avoids pickle overhead)

  • deque for pending shards: O(1) popleft instead of O(n) list.pop(0)

  • JIT cache fix: enable the persistent compilation cache via jax.config.update (the env var alone is insufficient in JAX 0.8+; min_compile_time lowered to 0)
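The adaptive-polling and Event.wait heartbeat optimizations from v2.9.1 can be sketched together. The back-off policy in next_poll_interval (reset to fast polling after a result, double while idle, with the min_s/max_s limits) is a hypothetical example, not homodyne's actual rule; the heartbeat shows why Event.wait(timeout) beats time.sleep(): it doubles as an interruptible sleep.

```python
import threading
import time


def next_poll_interval(current, got_result, min_s=0.05, max_s=2.0):
    """Hypothetical adaptive policy: poll fast after activity, back off x2 when idle."""
    return min_s if got_result else min(current * 2.0, max_s)


def heartbeat(stop, interval, beats):
    """Record a heartbeat every `interval` seconds until `stop` is set.

    stop.wait(timeout) returns False on timeout and True as soon as stop.set()
    is called, so shutdown never has to sit out a full time.sleep(interval).
    """
    while not stop.wait(timeout=interval):
        beats.append(time.monotonic())


stop = threading.Event()
beats = []
worker = threading.Thread(target=heartbeat, args=(stop, 0.05, beats))
worker.start()
time.sleep(0.3)
t0 = time.monotonic()
stop.set()
worker.join(timeout=1.0)
shutdown_latency = time.monotonic() - t0
print(f"{len(beats)} beats, stopped in {shutdown_latency * 1000:.1f} ms")
```

Backing off while idle is what reduces CPU spin: a manager waiting on long shards wakes up a few times per second at most instead of 20 times.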