Source code for homodyne.optimization.cmc.backends.multiprocessing

"""Multiprocessing backend for CMC execution.

This module provides parallel MCMC execution using Python's
multiprocessing module for CPU-based parallelism.

Optimizations (v2.9.1):
- Batch PRNG key generation: Pre-generate all shard keys in single JAX call
- Adaptive polling: Adjust poll interval based on shard activity
- Event.wait heartbeat: Efficient heartbeat using Event.wait(timeout)

Optimizations (v2.22.2):
- LPT scheduling: Dispatch highest-cost shards first (size + noise weighted)
- Per-shard shared memory: Shard arrays stored in shared memory (avoids pickle overhead)
- deque for pending shards: O(1) popleft instead of O(n) list.pop(0)
- JIT cache fix: Enable persistent compilation cache via jax.config.update (env var alone insufficient in JAX 0.8+, min_compile_time lowered to 0)
"""

from __future__ import annotations

import logging
import multiprocessing as mp
import multiprocessing.shared_memory
import os
import queue
import threading
import time
from collections import deque
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
from tqdm import tqdm

from homodyne.optimization.cmc.backends.base import (
    CMCBackend,
    combine_shard_samples,
    combine_shard_samples_bimodal,
)
from homodyne.optimization.cmc.diagnostics import (
    check_shard_bimodality,
    cluster_shard_modes,
    summarize_cross_shard_bimodality,
)
from homodyne.utils.logging import get_logger, log_exception, with_context

if TYPE_CHECKING:
    from homodyne.config.parameter_space import ParameterSpace
    from homodyne.optimization.cmc.config import CMCConfig
    from homodyne.optimization.cmc.data_prep import PreparedData
    from homodyne.optimization.cmc.sampler import MCMCSamples

logger = get_logger(__name__)

# Keys for per-shard numpy arrays stored in shared memory.
# Used by SharedDataManager.create_shared_shard_arrays() and _load_shared_shard_data().
_SHARD_ARRAY_KEYS = ("data", "t1", "t2", "phi_unique", "phi_indices")


class SharedDataManager:
    """Manages shared memory blocks for data common to all CMC shards.

    Uses multiprocessing.shared_memory to share config, parameter space,
    initial values, and time_grid across spawned worker processes, avoiding
    redundant pickling per shard.

    Note on serialization: Uses pickle internally for trusted config dicts
    only (CMCConfig.to_dict(), ParameterSpace). This matches the existing
    multiprocessing behavior which also pickles all process arguments.

    Must be used as a context manager or call cleanup() in a finally block.
    """

    def __init__(self) -> None:
        # Every segment this manager created; cleanup() closes and unlinks them.
        self._shared_blocks: list[mp.shared_memory.SharedMemory] = []
        # Logical name -> lightweight reference dict handed to workers.
        self._refs: dict[str, dict[str, Any]] = {}

    def _allocate(self, nbytes: int) -> mp.shared_memory.SharedMemory:
        """Create a segment of at least ``nbytes`` bytes and track it.

        SharedMemory rejects ``size=0``, so empty payloads get a 1-byte
        segment. The segment is registered *before* it is populated so that
        cleanup() still releases it if the subsequent copy raises.
        """
        shm = mp.shared_memory.SharedMemory(create=True, size=max(1, nbytes))
        self._shared_blocks.append(shm)
        return shm

    def create_shared_bytes(self, name: str, data: bytes) -> dict[str, Any]:
        """Store bytes in shared memory.

        Parameters
        ----------
        name : str
            Logical name under which the reference is remembered.
        data : bytes
            Payload to copy into the segment (may be empty).

        Returns
        -------
        dict[str, Any]
            Reference with ``shm_name``, ``size`` (true payload length) and
            ``type == "bytes"``.
        """
        n_bytes = len(data)
        shm = self._allocate(n_bytes)
        shm.buf[:n_bytes] = data
        ref = {"shm_name": shm.name, "size": n_bytes, "type": "bytes"}
        self._refs[name] = ref
        return ref

    def create_shared_array(self, name: str, array: np.ndarray) -> dict[str, Any]:
        """Store a numpy array in shared memory.

        Parameters
        ----------
        name : str
            Logical name under which the reference is remembered.
        array : np.ndarray
            Array to copy; zero-size and 0-d arrays are supported.

        Returns
        -------
        dict[str, Any]
            Reference with ``shm_name``, ``shape``, ``dtype`` (as string) and
            ``type == "array"``.
        """
        shm = self._allocate(array.nbytes)
        shared_arr = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
        # Ellipsis assignment also works for 0-d arrays, where [:] would fail.
        shared_arr[...] = array
        ref = {
            "shm_name": shm.name,
            "shape": array.shape,
            "dtype": str(array.dtype),
            "type": "array",
        }
        self._refs[name] = ref
        return ref

    def create_shared_dict(self, name: str, d: dict) -> dict[str, Any]:
        """Serialize a trusted internal dict to shared memory.

        Only used for CMCConfig and ParameterSpace dicts — never for
        external/untrusted data.
        """
        import pickle as _pkl  # noqa: S403 — trusted internal data

        return self.create_shared_bytes(name, _pkl.dumps(d))

    def create_shared_shard_arrays(
        self, shard_data_list: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Place per-shard numpy arrays into shared memory (packed format).

        Instead of creating one SharedMemory segment per array per shard,
        this concatenates all shard arrays for each key into a single shared
        memory block, so only ``len(_SHARD_ARRAY_KEYS)`` segments are created
        regardless of shard count.

        Parameters
        ----------
        shard_data_list : list[dict[str, Any]]
            List of shard data dicts, each containing the arrays named in
            ``_SHARD_ARRAY_KEYS`` and a scalar ``noise_scale``. Arrays are
            flattened before packing.

        Returns
        -------
        list[dict[str, Any]]
            One lightweight reference per shard (shm names + element
            offsets/sizes). Empty when ``shard_data_list`` is empty.
        """
        if not shard_data_list:
            # np.concatenate([]) raises; there is simply nothing to share.
            return []

        # For each array key, concatenate all shards into one block.
        key_meta: dict[str, dict[str, Any]] = {}
        for key in _SHARD_ARRAY_KEYS:
            arrays = [
                np.ascontiguousarray(np.asanyarray(sd[key]).ravel())
                for sd in shard_data_list
            ]
            sizes = [arr.shape[0] for arr in arrays]
            combined = np.concatenate(arrays)

            shm = self._allocate(combined.nbytes)
            packed = np.ndarray(combined.shape, dtype=combined.dtype, buffer=shm.buf)
            packed[:] = combined

            # Element offset of each shard = running total of preceding sizes.
            offsets: list[int] = []
            running = 0
            for n_elems in sizes:
                offsets.append(running)
                running += n_elems

            key_meta[key] = {
                "shm_name": shm.name,
                "dtype": str(combined.dtype),
                "offsets": offsets,
                "sizes": sizes,
            }

        # Build per-shard refs that workers can slice from the packed blocks.
        shard_refs: list[dict[str, Any]] = []
        for i, sd in enumerate(shard_data_list):
            ref: dict[str, Any] = {"noise_scale": sd["noise_scale"]}
            for key in _SHARD_ARRAY_KEYS:
                meta = key_meta[key]
                ref[key] = {
                    "shm_name": meta["shm_name"],
                    "dtype": meta["dtype"],
                    "offset": meta["offsets"][i],
                    "size": meta["sizes"][i],
                }
            shard_refs.append(ref)
        return shard_refs

    def cleanup(self) -> None:
        """Release all shared memory blocks. Must be called in a finally block.

        close() and unlink() are attempted independently: a failure to close
        the local mapping must not prevent the segment from being unlinked,
        otherwise it would leak until reboot.
        """
        for shm in self._shared_blocks:
            try:
                shm.close()
            except (OSError, BufferError):
                # Best effort: still try to unlink below.
                pass
            try:
                shm.unlink()
            except OSError:
                # Already removed (FileNotFoundError) or otherwise gone.
                pass
        self._shared_blocks.clear()
        self._refs.clear()

    def __enter__(self) -> SharedDataManager:
        return self

    def __exit__(self, *exc: object) -> None:
        self.cleanup()
def _load_shared_bytes(ref: dict[str, Any]) -> bytes: """Reconstruct bytes from a shared memory reference.""" shm = mp.shared_memory.SharedMemory(name=ref["shm_name"], create=False) try: data = bytes(shm.buf[: ref["size"]]) finally: shm.close() return data def _load_shared_dict(ref: dict[str, Any]) -> dict: """Reconstruct a trusted internal dict from shared memory. Only used for CMCConfig and ParameterSpace dicts — never for external/untrusted data. """ import pickle as _pkl # noqa: S403 — trusted internal data return _pkl.loads(_load_shared_bytes(ref)) # noqa: S301 # nosec B301 def _load_shared_array(ref: dict[str, Any]) -> np.ndarray: """Reconstruct a numpy array from a shared memory reference.""" shm = mp.shared_memory.SharedMemory(name=ref["shm_name"], create=False) try: arr = np.ndarray( ref["shape"], dtype=np.dtype(ref["dtype"]), buffer=shm.buf ).copy() # Copy so we don't hold a reference to the shared buffer finally: shm.close() return arr def _load_shared_shard_data(shard_ref: dict[str, Any]) -> dict[str, Any]: """Reconstruct per-shard data arrays from packed shared memory. Each array key maps to a single concatenated SharedMemory block shared across all shards. The per-shard ref carries ``offset`` (element index) and ``size`` (element count) to slice this shard's portion. Parameters ---------- shard_ref : dict[str, Any] Lightweight shard reference created by ``SharedDataManager.create_shared_shard_arrays``. Returns ------- dict[str, Any] Shard data dict with numpy arrays (copied from shared memory) and scalar noise_scale. 
""" shard_data: dict[str, Any] = {"noise_scale": shard_ref["noise_scale"]} for key in _SHARD_ARRAY_KEYS: arr_ref = shard_ref[key] shm = mp.shared_memory.SharedMemory(name=arr_ref["shm_name"], create=False) try: dtype = np.dtype(arr_ref["dtype"]) offset = arr_ref["offset"] size = arr_ref["size"] # Map the full concatenated buffer, then slice this shard's region total_elements = len(shm.buf) // dtype.itemsize full_arr = np.ndarray((total_elements,), dtype=dtype, buffer=shm.buf) arr = full_arr[offset : offset + size].copy() finally: shm.close() shard_data[key] = arr return shard_data def _generate_shard_keys(n_shards: int, seed: int = 42) -> list[tuple[int, ...]]: """Pre-generate all shard PRNG keys in a single JAX call. This is more efficient than generating keys one-at-a-time in each worker, as it amortizes JAX compilation overhead across all shards. Parameters ---------- n_shards : int Number of shards to generate keys for. seed : int Base seed for PRNG key generation. Returns ------- list[tuple[int, int]] List of (key_high, key_low) tuples that can be used to reconstruct JAX PRNG keys in worker processes without importing JAX here. """ import jax import jax.numpy as jnp # Generate all keys at once using jax.random.split base_key = jax.random.PRNGKey(seed) # Split into n_shards + 1 keys (first is throwaway, rest are for shards) all_keys = jax.random.split(base_key, n_shards + 1) shard_keys = all_keys[1:] # Skip the first key # Convert to serializable format (tuples of ints). # JAX ≤0.4.30 uses uint32[2]; JAX 0.4.31+ uses typed keys (key<fry>[]). # Flatten to raw uint32 array to handle both formats. key_tuples = [] for key in shard_keys: raw = jax.random.key_data(key).flatten().astype(jnp.uint32) key_tuples.append(tuple(int(x) for x in raw)) return key_tuples def _get_physical_cores() -> int: """Get physical core count using psutil for accurate detection. Falls back to os.cpu_count() // 2 if psutil unavailable. 
""" try: import psutil physical = psutil.cpu_count(logical=False) if physical is not None: return physical except ImportError: pass # Fallback: assume hyperthreading (logical = 2 * physical) import os return max(1, (os.cpu_count() or 1) // 2) def _compute_lpt_schedule( shard_data_list: list[dict[str, Any]], ) -> deque[int]: """Order shard indices by descending estimated cost (LPT heuristic). Cost = n_points * (1 + normalized_noise), where noise is linearly scaled to [0, 1] across shards. Dispatching the most expensive shards first minimizes tail latency on identical parallel workers. Parameters ---------- shard_data_list : list[dict[str, Any]] Shard dicts with ``"data"`` (array) and ``"noise_scale"`` (float). Returns ------- deque[int] Shard indices sorted by descending cost. """ n_shards = len(shard_data_list) sizes = [len(shard_data_list[i]["data"]) for i in range(n_shards)] noises = [shard_data_list[i]["noise_scale"] for i in range(n_shards)] max_noise = max(noises) if noises else 1.0 min_noise = min(noises) if noises else 1.0 noise_range = max_noise - min_noise if noise_range > 0: costs = [ sizes[i] * (1.0 + (noises[i] - min_noise) / noise_range) for i in range(n_shards) ] else: costs = [float(s) for s in sizes] return deque(sorted(range(n_shards), key=lambda i: costs[i], reverse=True)) def _compute_threads_per_worker(total_threads: int, workers: int) -> int: """Derive a conservative thread budget per worker to avoid oversubscription. Uses psutil for accurate physical core detection when available, otherwise approximates physical cores as half of logical (common HT layout). Divides the budget across workers, clamping to at least 1. 
""" physical_cores = _get_physical_cores() # Use physical cores as the safe pool (avoid hyperthreading contention) safe_pool = max(1, min(total_threads, physical_cores)) worker_count = max(1, workers) return max(1, safe_pool // worker_count) def _run_shard_worker_with_queue( shard_idx: int, shard_ref: dict[str, Any], model_fn: Callable, config_ref: dict[str, Any], initial_values_ref: dict[str, Any] | None, ps_ref: dict[str, Any], shared_kwargs_ref: dict[str, Any], time_grid_ref: dict[str, Any] | None, n_phi: int, analysis_mode: str, threads_per_worker: int, result_queue: mp.Queue, rng_key_tuple: tuple[int, ...] | None = None, ) -> None: """Worker function that puts result in a queue for proper timeout handling. Accepts shared memory references instead of full dicts to avoid redundant pickling. Reconstructs shared data from shared memory blocks. Wraps all initialization and sampling in a top-level try/except to ensure that crashes during setup (imports, config reconstruction, model_kwargs) are captured and reported back to the parent via the result queue. 
""" try: # Reconstruct per-shard arrays from shared memory (avoids pickle overhead) shard_data = _load_shared_shard_data(shard_ref) # Reconstruct shared data from shared memory config_dict = _load_shared_dict(config_ref) parameter_space_dict = _load_shared_dict(ps_ref) shared_kwargs = _load_shared_dict(shared_kwargs_ref) initial_values = ( _load_shared_dict(initial_values_ref) if initial_values_ref is not None else None ) # Reconstruct time_grid time_grid = ( _load_shared_array(time_grid_ref) if time_grid_ref is not None else None ) # Merge shared kwargs into shard_data for backward-compatible worker interface shard_data["time_grid"] = time_grid shard_data.update(shared_kwargs) result = _run_shard_worker( shard_idx=shard_idx, shard_data=shard_data, model_fn=model_fn, config_dict=config_dict, initial_values=initial_values, parameter_space_dict=parameter_space_dict, n_phi=n_phi, analysis_mode=analysis_mode, threads_per_worker=threads_per_worker, result_queue=result_queue, rng_key_tuple=rng_key_tuple, ) except Exception as e: # Catch crashes during initialization (shared memory, imports, config # reconstruction) that occur before _run_shard_worker's internal try/except. import traceback result = { "type": "result", "success": False, "shard_idx": shard_idx, "error": f"Worker initialization failed: {e}", "error_category": "init_crash", "traceback": traceback.format_exc(), "duration": 0.0, } try: result_queue.put_nowait(result) except Exception: # noqa: S110 - Best-effort queue put, parent handles failures # If the queue is already full or closed, drop the result; the parent # loop will have marked the shard as failed. This is a best-effort send. pass def _pool_worker_init(worker_id: int, **init_kwargs: Any) -> None: """One-time initialization for persistent pool workers. Configures JAX/OMP environment and imports JAX once. Subsequent shards processed by this worker skip re-initialization (env vars are idempotent, JAX modules are cached). 
Parameters ---------- worker_id : int Index of this worker in the pool (0-based). **init_kwargs : Any Must contain ``threads_per_worker`` (int). """ import re as _re threads_per_worker = init_kwargs["threads_per_worker"] # Thread pinning (same as _run_shard_worker lines 521-528) os.environ["OMP_NUM_THREADS"] = str(threads_per_worker) os.environ["MKL_NUM_THREADS"] = str(threads_per_worker) os.environ["OPENBLAS_NUM_THREADS"] = str(threads_per_worker) os.environ.pop("OMP_PROC_BIND", None) os.environ.pop("OMP_PLACES", None) # Float64 + XLA device count BEFORE JAX import (CLAUDE.md rule #8) os.environ["JAX_ENABLE_X64"] = "true" if "JAX_COMPILATION_CACHE_DIR" not in os.environ: os.environ["JAX_COMPILATION_CACHE_DIR"] = str( Path(os.path.expanduser("~/.cache/homodyne/jax_cache")) ) _num_chains = int(os.environ.get("HOMODYNE_CMC_NUM_CHAINS", "4")) _xla_flags = os.environ.get("XLA_FLAGS", "") _xla_flags = _re.sub(r"--xla_force_host_platform_device_count=\d+", "", _xla_flags) os.environ["XLA_FLAGS"] = ( _xla_flags.strip() + f" --xla_force_host_platform_device_count={_num_chains}" ) import jax jax.config.update("jax_enable_x64", True) # Persistent compilation cache (CLAUDE.md rule #9) _cache_dir = os.environ.get( "JAX_COMPILATION_CACHE_DIR", str(Path(os.path.expanduser("~/.cache/homodyne/jax_cache"))), ) jax.config.update("jax_compilation_cache_dir", _cache_dir) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) def _pool_shard_worker(task: dict[str, Any], **init_kwargs: Any) -> None: """Per-shard worker for persistent pool dispatch. Reconstructs shard data from shared memory, runs NUTS sampling via ``_run_shard_worker``, and puts the result on ``result_queue``. Returns ``None`` so WorkerPool does not double-queue results. Parameters ---------- task : dict Per-shard task with keys: ``shard_idx``, ``shard_ref``, ``n_phi``, ``rng_key_tuple``. 
**init_kwargs : Any Shared across all shards: ``model_fn``, ``config_ref``, ``iv_ref``, ``ps_ref``, ``kwargs_ref``, ``tg_ref``, ``analysis_mode``, ``threads_per_worker``, ``result_queue``. """ shard_idx = task["shard_idx"] result_queue = init_kwargs["result_queue"] try: # Reconstruct per-shard arrays from shared memory shard_data = _load_shared_shard_data(task["shard_ref"]) # Reconstruct shared data from shared memory config_dict = _load_shared_dict(init_kwargs["config_ref"]) parameter_space_dict = _load_shared_dict(init_kwargs["ps_ref"]) shared_kwargs = _load_shared_dict(init_kwargs["kwargs_ref"]) initial_values = ( _load_shared_dict(init_kwargs["iv_ref"]) if init_kwargs["iv_ref"] is not None else None ) # Reconstruct time_grid time_grid = ( _load_shared_array(init_kwargs["tg_ref"]) if init_kwargs["tg_ref"] is not None else None ) # Merge shared kwargs into shard_data shard_data["time_grid"] = time_grid shard_data.update(shared_kwargs) result = _run_shard_worker( shard_idx=shard_idx, shard_data=shard_data, model_fn=init_kwargs["model_fn"], config_dict=config_dict, initial_values=initial_values, parameter_space_dict=parameter_space_dict, n_phi=task["n_phi"], analysis_mode=init_kwargs["analysis_mode"], threads_per_worker=init_kwargs["threads_per_worker"], result_queue=result_queue, rng_key_tuple=task["rng_key_tuple"], ) except MemoryError: result = { "type": "result", "success": False, "shard_idx": shard_idx, "error": f"Shard {shard_idx} failed: MemoryError", "error_category": "memory_error", "duration": 0.0, } except Exception as e: import traceback result = { "type": "result", "success": False, "shard_idx": shard_idx, "error": f"Shard {shard_idx} failed: {e}", "error_category": "shard_error", "traceback": traceback.format_exc(), "duration": 0.0, } try: result_queue.put(result, timeout=30) except Exception: # noqa: S110 pass return None def _run_shard_worker( shard_idx: int, shard_data: dict[str, Any], model_fn: Callable, config_dict: dict[str, Any], 
initial_values: dict[str, float] | None, parameter_space_dict: dict[str, Any], n_phi: int, analysis_mode: str, threads_per_worker: int = 2, result_queue: mp.Queue | None = None, rng_key_tuple: tuple[int, ...] | None = None, ) -> dict[str, Any]: """Worker function for processing a single shard. This runs in a separate process. Parameters ---------- shard_idx : int Shard index for logging. shard_data : dict[str, Any] Shard data dictionary. model_fn : Callable NumPyro model function. config_dict : dict[str, Any] CMC configuration as dict. initial_values : dict[str, float] | None Initial parameter values. parameter_space_dict : dict[str, Any] Serialized parameter space. n_phi : int Number of phi angles in this shard. analysis_mode : str Analysis mode. threads_per_worker : int Number of threads for JAX/XLA in this worker process. rng_key_tuple : tuple[int, ...] | None Pre-generated PRNG key as raw uint32 tuple. If None, generates a key based on shard_idx (legacy behavior). Returns ------- dict[str, Any] Serialized MCMCSamples. """ import os # Configure worker threading to avoid oversubscription across workers. # The parent process clears these before spawning, but we set them here # as a safety net in case the spawn context inherited stale values. os.environ["OMP_NUM_THREADS"] = str(threads_per_worker) os.environ["MKL_NUM_THREADS"] = str(threads_per_worker) os.environ["OPENBLAS_NUM_THREADS"] = str(threads_per_worker) # CRITICAL: Clear OMP_PROC_BIND and OMP_PLACES to prevent thread pinning. # When set, each worker's OpenMP runtime tries to pin threads to the same # physical cores, causing severe contention across concurrent workers. os.environ.pop("OMP_PROC_BIND", None) os.environ.pop("OMP_PLACES", None) # P0-1: Enable float64 precision and XLA device count BEFORE importing JAX. # Spawned workers don't inherit the parent's jax.config.x64_enabled state. 
# Without this, all JAX ops in workers run in float32, silently losing # precision for parameters spanning 6+ orders of magnitude (D0~1e4, gamma~1e-3). # P1-5: Ensure XLA_FLAGS propagates to spawned workers regardless of spawn # method (fork vs spawn). Parent sets JAX_ENABLE_X64 in homodyne/__init__.py # but spawn-mode workers start fresh processes and must re-set it. os.environ["JAX_ENABLE_X64"] = "true" if "JAX_COMPILATION_CACHE_DIR" not in os.environ: os.environ["JAX_COMPILATION_CACHE_DIR"] = str( Path(os.path.expanduser("~/.cache/homodyne/jax_cache")) ) # Dynamic XLA device count: read num_chains from env var set by parent process. # Previously hardcoded to 4, now matches config.num_chains so JAX sees the # correct number of virtual devices for parallel chain execution. import re as _re _num_chains = int(os.environ.get("HOMODYNE_CMC_NUM_CHAINS", "4")) _xla_flags = os.environ.get("XLA_FLAGS", "") _xla_flags = _re.sub(r"--xla_force_host_platform_device_count=\d+", "", _xla_flags) os.environ["XLA_FLAGS"] = ( _xla_flags.strip() + f" --xla_force_host_platform_device_count={_num_chains}" ) import jax jax.config.update("jax_enable_x64", True) # C3: Enable persistent compilation cache so subsequent workers reuse # compiled XLA programs from the first worker. The env var alone is # insufficient in JAX 0.8+ — we must also call jax.config.update(). # Additionally, CMC functions compile in 0.07-0.15s, below the default # 1.0s min_compile_time threshold, so we lower it to 0. 
_cache_dir = os.environ.get( "JAX_COMPILATION_CACHE_DIR", str(Path(os.path.expanduser("~/.cache/homodyne/jax_cache"))), ) jax.config.update("jax_compilation_cache_dir", _cache_dir) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) import jax.numpy as jnp from homodyne.config.parameter_space import ParameterSpace from homodyne.optimization.cmc.config import CMCConfig from homodyne.optimization.cmc.sampler import run_nuts_sampling start_time = time.perf_counter() worker_logger = get_logger( __name__, context={"run": config_dict.get("run_id"), "shard": shard_idx}, ) # T044: Log shard start with data range and point count n_points = len(shard_data["data"]) worker_logger.info( f"Shard {shard_idx} starting: {n_points:,} points, " f"n_phi={n_phi}, mode={analysis_mode}" ) # Reconstruct objects from dicts config = CMCConfig.from_dict(config_dict) # Reconstruct ParameterSpace # For now, we pass the config dict directly parameter_space = ParameterSpace.from_config( config_dict=parameter_space_dict, analysis_mode=analysis_mode, ) # Create RNG key for this shard # Use pre-generated key if available (batch optimization), else generate locally if rng_key_tuple is not None: # Reconstruct JAX PRNG key from raw uint32 data (handles both # legacy uint32[2] and typed-key formats via key_data round-trip) rng_key = jax.random.wrap_key_data(jnp.array(rng_key_tuple, dtype=jnp.uint32)) else: # Legacy behavior: generate key based on shard index rng_key = jax.random.PRNGKey(42 + shard_idx) # Prepare model kwargs - must match xpcs_model() signature # jnp.asarray avoids a copy when the source is already a contiguous ndarray. 
model_kwargs = { "data": jnp.asarray(shard_data["data"]), "t1": jnp.asarray(shard_data["t1"]), "t2": jnp.asarray(shard_data["t2"]), "phi_unique": jnp.asarray(shard_data["phi_unique"]), "phi_indices": jnp.asarray(shard_data["phi_indices"]), "q": shard_data["q"], "L": shard_data["L"], "dt": shard_data["dt"], "time_grid": ( jnp.asarray(shard_data["time_grid"]) if shard_data.get("time_grid") is not None else None ), "analysis_mode": analysis_mode, "parameter_space": parameter_space, "n_phi": n_phi, "noise_scale": shard_data.get("noise_scale", 0.1), } # Restore fixed parameters if present if shard_data.get("fixed_contrast") is not None: model_kwargs["fixed_contrast"] = shard_data["fixed_contrast"] if shard_data.get("fixed_offset") is not None: model_kwargs["fixed_offset"] = shard_data["fixed_offset"] # Restore per_angle_mode, nlsq_prior_config, and propagated kwargs per_angle_mode = shard_data.get("per_angle_mode", "individual") model_kwargs["per_angle_mode"] = per_angle_mode if shard_data.get("nlsq_prior_config") is not None: model_kwargs["nlsq_prior_config"] = shard_data["nlsq_prior_config"] # Restore num_shards for prior tempering (Feb 2026 fix) if shard_data.get("num_shards") is not None: model_kwargs["num_shards"] = shard_data["num_shards"] # Restore reparam_config from serialized dict (Feb 2026 fix) if shard_data.get("reparam_config_dict") is not None: from homodyne.optimization.cmc.reparameterization import ReparamConfig model_kwargs["reparam_config"] = ReparamConfig( **shard_data["reparam_config_dict"] ) # Restore t_ref for reference-time reparameterization if shard_data.get("t_ref") is not None: model_kwargs["t_ref"] = shard_data["t_ref"] # M1-worker: Free the numpy shard_data dict now that all values have been # extracted into model_kwargs (as JAX arrays) or as scalars. This releases # the numpy copies of data/t1/t2/phi_indices/time_grid, which otherwise # stay alive for the entire sampling duration alongside their JAX twins. 
del shard_data # P0-1: Pre-compute scaling factors ONCE before NUTS starts. from homodyne.optimization.cmc.scaling import compute_scaling_factors model_kwargs["scalings"] = compute_scaling_factors( parameter_space, n_phi, analysis_mode ) # P0-2: Pre-compute physics prefactors (constant for entire shard). import math as _math _q = model_kwargs["q"] _L = model_kwargs["L"] _dt = model_kwargs["dt"] model_kwargs["wavevector_q_squared_half_dt"] = jnp.asarray(0.5 * (_q**2) * _dt) model_kwargs["sinc_prefactor"] = jnp.asarray(0.5 / _math.pi * _q * _L * _dt) # P1-3: Pre-compute point_idx array (constant for entire shard). model_kwargs["point_idx"] = jnp.arange( model_kwargs["phi_indices"].shape[0], dtype=jnp.int32 ) # D2: Pre-compute shard-constant quantities (time_safe + searchsorted indices) # once before NUTS starts. Eliminates redundant work on every leapfrog step. try: from homodyne.core.physics_cmc import precompute_shard_grid _t1 = model_kwargs["t1"] _t2 = model_kwargs["t2"] _time_grid = model_kwargs.get("time_grid") _dt = model_kwargs.get("dt", 1e-3) if _time_grid is not None: model_kwargs["shard_grid"] = precompute_shard_grid( _time_grid, _t1, _t2, _dt ) except (ImportError, ValueError, RuntimeError) as _exc: # Non-fatal: fall back to legacy compute_g1_total path in model.py worker_logger.warning( f"precompute_shard_grid failed (using legacy path): " f"{type(_exc).__name__}: {_exc}" ) # Heartbeat thread to emit liveness updates back to the parent. # Optimization: Use Event.wait(timeout) instead of busy-wait loop. # This reduces wake-ups by 75% (from 4 per interval to 1). 
stop_hb = threading.Event() heartbeat_interval = 30.0 def _heartbeat_loop() -> None: while True: # Wait for stop signal OR timeout (whichever comes first) # This is much more efficient than sleep + check loop if stop_hb.wait(timeout=heartbeat_interval): # Event was set - time to exit break # Timeout expired - send heartbeat payload = { "type": "heartbeat", "shard_idx": shard_idx, "elapsed": time.perf_counter() - start_time, } if result_queue is not None: try: result_queue.put_nowait(payload) except Exception: # noqa: S110 - Best-effort heartbeat pass hb_thread = threading.Thread(target=_heartbeat_loop, daemon=True) hb_thread.start() try: # Run sampling samples, stats = run_nuts_sampling( model=model_fn, model_kwargs=model_kwargs, config=config, initial_values=initial_values, parameter_space=parameter_space, n_phi=n_phi, analysis_mode=analysis_mode, rng_key=rng_key, progress_bar=False, # Disable in worker per_angle_mode=per_angle_mode, ) duration = time.perf_counter() - start_time # M1-worker: Free shard input arrays now that sampling is done. # model_kwargs holds data/t1/t2/phi_indices/time_grid/shard_grid as # JAX arrays that are no longer needed. Free them before serializing # the result, so peak memory during serialization is lower. 
model_kwargs.clear() # T045: Log shard completion with elapsed time, acceptance rate, divergence count divergence_str = ( f", divergences: {stats.num_divergent}" if stats.num_divergent > 0 else "" ) worker_logger.info( f"Shard {shard_idx} completed in {duration:.2f}s: " f"{samples.n_samples} samples/chain x {samples.n_chains} chains{divergence_str}" ) if stats.num_divergent > 0: worker_logger.warning( f"Shard {shard_idx} had {stats.num_divergent} divergent transitions" ) # Serialize for return result = { "type": "result", "success": True, "shard_idx": shard_idx, "samples": {k: np.array(v) for k, v in samples.samples.items()}, "param_names": samples.param_names, "n_chains": samples.n_chains, "n_samples": samples.n_samples, "extra_fields": {k: np.array(v) for k, v in samples.extra_fields.items()}, "duration": duration, "stats": { "warmup_time": stats.warmup_time, "sampling_time": stats.sampling_time, "total_time": stats.total_time, "num_divergent": stats.num_divergent, "n_warmup": stats.plan.n_warmup if stats.plan else None, "n_samples": stats.plan.n_samples if stats.plan else None, }, } # Result is returned to _run_shard_in_process, which puts it on the queue. # Do NOT put it here — that would cause double-queuing. 
return result except Exception as e: duration = time.perf_counter() - start_time # Classify error type for diagnostics error_str = str(e).lower() if "nan" in error_str or "inf" in error_str or "singular" in error_str: error_category = "numerical" elif "convergence" in error_str or "diverge" in error_str: error_category = "convergence" else: error_category = "sampling" # T028: Log exception with structured context for debugging log_exception( worker_logger, e, context={ "shard_idx": shard_idx, "duration_s": round(duration, 2), "error_category": error_category, "n_points": n_points, }, ) result = { "type": "result", "success": False, "shard_idx": shard_idx, "error": str(e), "error_category": error_category, "duration": duration, } # Result is returned to _run_shard_in_process, which puts it on the queue. # Do NOT put it here — that would cause double-queuing. return result finally: stop_hb.set() hb_thread.join(timeout=1) def _log_bimodality_summary( run_logger: Any, summary: dict[str, Any], ) -> None: """Log a structured cross-shard bimodality analysis. Parameters ---------- run_logger Logger instance (supports .info()). summary : dict[str, Any] Output from summarize_cross_shard_bimodality(). """ per_param = summary.get("per_param", {}) co_occurrence = summary.get("co_occurrence", {}) n_detections = summary.get("n_detections", 0) n_shards = summary.get("n_shards", 0) if not per_param: run_logger.warning( f"Detected {n_detections} bimodal posteriors across shards, " f"but none exceeded the significance threshold." 
) return sep = "=" * 80 dash = "-" * 80 lines = [ sep, f"BIMODALITY ANALYSIS ({n_detections} detections across {n_shards} shards)", sep, f"{'Parameter':<14} {'Bimodal%':>8} {'Mode 1 (mean +/- std)':>24} " f"{'Mode 2 (mean +/- std)':>24} {'Sep.':>5}", dash, ] for param, stats in sorted(per_param.items()): pct = f"{stats['bimodal_fraction']:.1%}" m1 = f"{stats['lower_mean']:.3g} +/- {stats['lower_std']:.2g}" m2 = f"{stats['upper_mean']:.3g} +/- {stats['upper_std']:.2g}" sig = f"{stats['sep_significance']:.1f}x" lines.append(f"{param:<14} {pct:>8} {m1:>24} {m2:>24} {sig:>5}") lines.append(dash) # Consensus impact section impact_lines: list[str] = [] for param, stats in per_param.items(): if stats["consensus_in_trough"]: impact_lines.append( f" {param} consensus mean falls between modes (density trough)" ) d0_alpha_frac = co_occurrence.get("d0_alpha_fraction") if d0_alpha_frac is not None and "D0" in per_param: impact_lines.append( f" D0-alpha co-occurrence: {d0_alpha_frac:.0%} of D0-bimodal " f"shards also bimodal in alpha" ) if d0_alpha_frac > 0.3: impact_lines.append( " -> Likely parameter degeneracy: different (D0, alpha) " "pairs produce similar D(t)" ) if impact_lines: lines.append("CONSENSUS IMPACT:") lines.extend(impact_lines) lines.append("GUIDANCE:") lines.append(" - NLSQ result likely converged to one mode; CMC captures both") lines.append( " - Consider increasing shard size or using tighter NLSQ-informed priors" ) lines.append(sep) for line in lines: run_logger.info(line)
class MultiprocessingBackend(CMCBackend):
    """CMC backend using Python multiprocessing.

    Runs MCMC sampling in parallel across CPU cores using Python's
    multiprocessing module.

    The worker count is capped at the value returned by
    ``_estimate_physical_workers()`` and is never below 1. ``run()`` samples
    in-process when given no shards or a single shard; otherwise it dispatches
    shards either to a persistent ``WorkerPool`` or to one process per shard,
    using the configured start method.
    """
def __init__(
    self,
    n_workers: int | None = None,
    spawn_method: str = "spawn",
):
    """Set up the worker count and process start method.

    Parameters
    ----------
    n_workers : int | None
        Requested number of worker processes. ``None`` selects the value
        returned by ``_estimate_physical_workers()``; an explicit value is
        capped at that estimate. The stored count is never below 1.
    spawn_method : str
        Process start method: "spawn", "fork", or "forkserver".
    """
    from homodyne.optimization.cmc.backends.worker_pool import (
        _estimate_physical_workers,
    )

    worker_ceiling = _estimate_physical_workers()
    # An explicit request may lower the count but never exceed the estimate.
    requested = worker_ceiling if n_workers is None else min(n_workers, worker_ceiling)

    self.n_workers = max(1, requested)
    self.spawn_method = spawn_method
def get_name(self) -> str:
    """Return the backend label, including the configured worker count."""
    return "multiprocessing({} workers)".format(self.n_workers)
def run(
    self,
    model: Callable,
    model_kwargs: dict[str, Any],
    config: CMCConfig,
    shards: list[PreparedData] | None = None,
    initial_values: dict[str, float] | None = None,
    parameter_space: ParameterSpace | None = None,
    analysis_mode: str = "static",
    progress_bar: bool = True,
) -> MCMCSamples:
    """Run MCMC sampling across shards.

    With no shards or a single shard, sampling runs in-process. Otherwise
    shared inputs are placed in shared memory, shards are dispatched (LPT
    order) to a persistent worker pool or to one process per shard, and the
    parent loop enforces runtime and heartbeat timeouts. Successful shard
    posteriors are then quality-filtered, checked for heterogeneity and
    bimodality, and combined.

    Parameters
    ----------
    model : Callable
        NumPyro model function.
    model_kwargs : dict[str, Any]
        Common model arguments.
    config : CMCConfig
        CMC configuration.
    shards : list[PreparedData] | None
        Data shards.
    initial_values : dict[str, float] | None
        Initial parameter values.
    parameter_space : ParameterSpace
        Parameter space for priors.
    analysis_mode : str
        Analysis mode.
    progress_bar : bool
        Whether to show a progress bar (shard completion in the parallel
        path; forwarded to the sampler in the single-shard path).

    Returns
    -------
    MCMCSamples
        Combined samples from all shards.

    Raises
    ------
    RuntimeError
        If every shard fails, if every successful shard exceeds
        ``max_divergence_rate``, or if the heterogeneity abort triggers.
    """
    from homodyne.optimization.cmc.sampler import MCMCSamples, run_nuts_sampling

    run_logger = with_context(
        logger,
        run=getattr(config, "run_id", None),
        backend="multiprocessing",
    )

    if shards is None or len(shards) <= 1:
        # Single shard - run directly without multiprocessing
        run_logger.info("Running single-shard MCMC (no parallelization)")
        samples, _stats = run_nuts_sampling(
            model=model,
            model_kwargs=model_kwargs,
            config=config,
            initial_values=initial_values,
            parameter_space=parameter_space,
            n_phi=model_kwargs.get("n_phi", 1),
            analysis_mode=analysis_mode,
            # Honour the caller's choice instead of always forcing a bar.
            progress_bar=progress_bar,
            per_angle_mode=model_kwargs.get("per_angle_mode", "individual"),
        )
        return samples

    # Multiple shards - run in parallel with per-shard timeout enforcement
    n_shards = len(shards)
    actual_workers = min(self.n_workers, n_shards)

    # Calculate threads per worker to avoid over-subscription
    total_threads = mp.cpu_count()
    threads_per_worker = _compute_threads_per_worker(total_threads, actual_workers)
    if threads_per_worker < max(1, total_threads // max(1, actual_workers)):
        run_logger.info(
            f"Capping threads to avoid oversubscription: logical={total_threads}, "
            f"workers={actual_workers} -> {threads_per_worker} threads/worker"
        )
    run_logger.info(
        f"Running {n_shards} shards in parallel with {actual_workers} workers "
        f"({threads_per_worker} threads each)"
    )

    # Per-shard timeout - enforced per individual process
    per_shard_timeout = config.per_shard_timeout  # Default: 7200s (2 hours)
    run_logger.info(
        f"Per-shard timeout: {per_shard_timeout / 3600:.1f} hours "
        f"(processes will be terminated if exceeded)"
    )
    run_logger.info(
        f"Heartbeat timeout: {config.heartbeat_timeout}s "
        f"(unresponsive workers will be terminated)"
    )

    # Prepare shard data for workers
    # Separate per-shard data from shared data to reduce pickling overhead.
    # Shared data (config, parameter_space, time_grid, model kwargs) is placed
    # in shared memory once; only per-shard arrays are pickled per process.
    reparam_config = model_kwargs.get("reparam_config")
    shared_kwargs = {
        "q": model_kwargs["q"],
        "L": model_kwargs["L"],
        "dt": model_kwargs["dt"],
        "fixed_contrast": model_kwargs.get("fixed_contrast"),
        "fixed_offset": model_kwargs.get("fixed_offset"),
        "global_phi_unique": model_kwargs.get("global_phi_unique"),
        "per_angle_mode": model_kwargs.get("per_angle_mode", "individual"),
        "nlsq_prior_config": model_kwargs.get("nlsq_prior_config"),
        "num_shards": model_kwargs.get("num_shards", 1),
        "t_ref": model_kwargs.get("t_ref"),
        "reparam_config_dict": (
            {
                "enable_d_ref": reparam_config.enable_d_ref,
                "enable_gamma_ref": reparam_config.enable_gamma_ref,
                "t_ref": reparam_config.t_ref,
            }
            if reparam_config is not None
            else None
        ),
    }

    # np.asarray avoids a copy when shard arrays are already ndarrays.
    # shard_data_list is kept for LPT scheduling (size lookup).
    # Actual worker data will be served from shared memory (see below).
    shard_data_list = [
        {
            "data": np.asarray(shard.data),
            "t1": np.asarray(shard.t1),
            "t2": np.asarray(shard.t2),
            "phi_unique": np.asarray(shard.phi_unique),
            "phi_indices": np.asarray(shard.phi_indices),
            "noise_scale": shard.noise_scale,
        }
        for shard in shards
    ]

    # Serialize config and parameter_space
    config_dict = config.to_dict()
    if hasattr(parameter_space, "_config_dict"):
        ps_dict = parameter_space._config_dict
    else:
        ps_dict = model_kwargs.get("config_dict", {})
        if not ps_dict:
            run_logger.error(
                "ParameterSpace._config_dict is absent and no 'config_dict' in "
                "model_kwargs. Workers will reconstruct ParameterSpace from an "
                "empty dict (default bounds). This may produce unconstrained or "
                "incorrect NUTS proposals. Ensure ParameterSpace exposes "
                "_config_dict or pass config_dict in model_kwargs."
            )

    # Place shared data in shared memory to avoid per-shard pickling.
    # Wrap in try-except so partially-created blocks are cleaned up on failure.
    shared_mgr = SharedDataManager()
    try:
        shared_config_ref = shared_mgr.create_shared_dict("config", config_dict)
        shared_ps_ref = shared_mgr.create_shared_dict("ps", ps_dict)
        shared_kwargs_ref = shared_mgr.create_shared_dict("kwargs", shared_kwargs)

        # Share time_grid as array if present (can be large)
        time_grid_raw = model_kwargs.get("time_grid")
        if time_grid_raw is not None:
            shared_tg_ref: dict[str, Any] | None = shared_mgr.create_shared_array(
                "time_grid", np.array(time_grid_raw)
            )
        else:
            shared_tg_ref = None

        # Share initial_values as dict
        shared_iv_ref: dict[str, Any] | None = None
        if initial_values is not None:
            shared_iv_ref = shared_mgr.create_shared_dict(
                "init_vals", initial_values
            )

        # Place per-shard arrays in shared memory to avoid per-process
        # serialization overhead through spawn. Each shard's 5 arrays
        # (data, t1, t2, phi_unique, phi_indices) are stored once;
        # workers reconstruct them via _load_shared_shard_data().
        shared_shard_refs = shared_mgr.create_shared_shard_arrays(shard_data_list)
    except Exception:
        shared_mgr.cleanup()
        raise

    # Sentinel variables for the finally block — must be defined before
    # try so that cleanup never hits NameError on early exceptions.
    # `pool` belongs here too: both the KeyboardInterrupt handler and the
    # finally block read it, and an UnboundLocalError there would mask the
    # real error and skip env restoration and shared-memory cleanup.
    _saved_env: dict[str, str | None] = {}
    active_processes: dict[int, tuple[mp.Process, float]] = {}
    pbar = None
    pool = None

    # All setup from here through the main loop is wrapped in try/finally
    # to ensure shared_mgr.cleanup() runs even if _generate_shard_keys(),
    # ctx.Queue(), or any other pre-loop setup raises.
    try:
        run_logger.debug(
            f"Shared memory allocated: {len(shared_mgr._shared_blocks)} blocks"
        )

        # Pre-generate all shard PRNG keys in single JAX call (batch optimization)
        # This amortizes JAX compilation overhead across all shards
        run_logger.debug(f"Pre-generating {n_shards} PRNG keys...")
        key_gen_start = time.time()
        shard_keys = _generate_shard_keys(n_shards, seed=config.seed)
        key_gen_time = time.time() - key_gen_start
        run_logger.debug(f"PRNG key generation completed in {key_gen_time:.3f}s")

        # Use spawn context for clean process isolation
        ctx = mp.get_context(self.spawn_method)
        result_queue = ctx.Queue()

        # Temporarily adjust parent environment before spawning workers.
        # spawn'd children inherit the parent's env at Process.start() time.
        # configure_optimal_device() sets OMP_PROC_BIND=true and
        # OMP_NUM_THREADS=<physical_cores> for the parent process, but workers
        # must NOT inherit these — they cause massive thread oversubscription
        # (e.g. 9 workers × 14 OMP threads = 126 threads on 14 cores).
        _worker_env_overrides = {
            "OMP_NUM_THREADS": str(threads_per_worker),
            "MKL_NUM_THREADS": str(threads_per_worker),
            "OPENBLAS_NUM_THREADS": str(threads_per_worker),
            "VECLIB_MAXIMUM_THREADS": str(threads_per_worker),
            # Pass num_chains so workers set XLA device count dynamically
            # instead of the previous hardcoded value of 4.
            "HOMODYNE_CMC_NUM_CHAINS": str(config.num_chains),
        }
        _worker_env_clear = ["OMP_PROC_BIND", "OMP_PLACES"]
        for key in _worker_env_clear:
            _saved_env[key] = os.environ.pop(key, None)
        for key, val in _worker_env_overrides.items():
            _saved_env[key] = os.environ.get(key)
            os.environ[key] = val

        # LPT scheduling: dispatch highest-cost shards first to minimize
        # tail latency. See _compute_lpt_schedule() for cost model.
        pending_shards = _compute_lpt_schedule(shard_data_list)
        if n_shards > 1:
            sizes = [len(sd["data"]) for sd in shard_data_list]
            noises = [sd["noise_scale"] for sd in shard_data_list]
            run_logger.debug(
                f"LPT scheduling: shard sizes range "
                f"[{min(sizes):,}, {max(sizes):,}], "
                f"noise range [{min(noises):.4g}, {max(noises):.4g}], "
                f"dispatching highest-cost first"
            )

        # Determine dispatch strategy: persistent pool vs per-shard spawn
        from homodyne.optimization.cmc.backends.worker_pool import (
            should_use_pool,
        )

        use_pool = should_use_pool(n_shards=n_shards, n_workers=actual_workers)

        if use_pool:
            try:
                from homodyne.optimization.cmc.backends.worker_pool import (
                    WorkerPool,
                )

                pool = WorkerPool(
                    n_workers=actual_workers,
                    worker_fn=_pool_shard_worker,
                    worker_init_kwargs={
                        "model_fn": model,
                        "config_ref": shared_config_ref,
                        "iv_ref": shared_iv_ref,
                        "ps_ref": shared_ps_ref,
                        "kwargs_ref": shared_kwargs_ref,
                        "tg_ref": shared_tg_ref,
                        "analysis_mode": analysis_mode,
                        "threads_per_worker": threads_per_worker,
                        "result_queue": result_queue,
                    },
                    worker_init_fn=_pool_worker_init,
                )
                # Submit all shards (LPT-ordered via pending_shards)
                for shard_idx in list(pending_shards):
                    pool.submit(
                        {
                            "task_id": f"shard_{shard_idx}",
                            "shard_idx": shard_idx,
                            "shard_ref": shared_shard_refs[shard_idx],
                            "n_phi": shards[shard_idx].n_phi,
                            "rng_key_tuple": shard_keys[shard_idx],
                        }
                    )
                # Clear pending so per-shard spawn loop is a no-op
                pending_shards.clear()
                run_logger.info(
                    "WorkerPool dispatched %d shards to %d persistent workers",
                    n_shards,
                    actual_workers,
                )
            except (OSError, RuntimeError, MemoryError) as exc:
                run_logger.warning(
                    "WorkerPool creation failed (%s), "
                    "falling back to per-shard spawn",
                    exc,
                )
                if pool is not None:
                    # The pool was constructed but a submit failed: stop it
                    # before dropping the reference so that its workers are
                    # not orphaned. pending_shards is still intact, so the
                    # per-shard spawn loop below handles every shard.
                    pool.shutdown(timeout=5)
                pool = None
        else:
            run_logger.debug(
                "Per-shard spawn: %d shards < 3, pool not beneficial",
                n_shards,
            )

        # M1-parent: Free per-shard numpy arrays now that they have been
        # copied into shared memory (via create_shared_shard_arrays above).
        del shard_data_list

        results = []
        completed_count = 0
        recorded_shards: set[int] = set()
        last_heartbeat: dict[int, float] = {}
        # Shards whose process was seen dead once without a recorded result.
        # They get one extra queue drain before being declared crashed.
        exit_grace_given: set[int] = set()

        # EARLY ABORT TRACKING (Jan 2026): Monitor failure rate for early termination
        # If too many shards fail early, abort to save compute time
        early_abort_threshold = 0.5  # Abort if >50% of first N shards fail
        early_abort_sample_size = min(10, n_shards)  # Check first 10 shards
        failure_categories: dict[str, int] = {
            "timeout": 0,
            "heartbeat_timeout": 0,
            "crash": 0,
            "numerical": 0,
            "convergence": 0,
            "sampling": 0,
            "unknown": 0,
        }
        early_abort_triggered = False

        # Progress bar
        pbar = tqdm(
            total=n_shards,
            desc="CMC shards",
            disable=not progress_bar,
            unit="shard",
            position=0,
            leave=True,
            dynamic_ncols=True,
        )
        pbar.set_postfix_str("starting...")
        pbar.refresh()

        start_time = time.time()

        # Adaptive polling: start with faster polling, slow down as shards run longer
        poll_interval_min = 0.5  # Fast polling during startup
        poll_interval_max = 5.0  # Slow polling during long-running shards
        poll_interval = poll_interval_min
        last_completion_time = start_time  # Track when last shard completed
        status_log_interval = 300.0  # parent status log every 5 minutes
        last_status_log = start_time
        shards_launched = 0

        while completed_count < n_shards:
            # Drain queue first to capture heartbeats and completed shards
            while True:
                try:
                    message = result_queue.get_nowait()
                except queue.Empty:
                    break
                except Exception as exc:
                    run_logger.warning(f"Queue read error: {exc}")
                    break

                msg_type = message.get("type")
                shard_idx = message.get("shard_idx")

                if msg_type == "heartbeat" and shard_idx is not None:
                    last_heartbeat[shard_idx] = time.time()
                    continue

                if msg_type == "result" or message.get("success") is not None:
                    # Guard against duplicate results from timed-out shards
                    # whose results arrive late via the queue.
                    if shard_idx is not None and shard_idx in recorded_shards:
                        run_logger.debug(
                            f"Ignoring duplicate result for shard {shard_idx}"
                        )
                        continue

                    results.append(message)
                    if shard_idx is not None:
                        recorded_shards.add(shard_idx)
                    completed_count += 1
                    pbar.update(1)

                    # Reset to fast polling on completion (adaptive polling)
                    last_completion_time = time.time()
                    poll_interval = poll_interval_min

                    # Track success/failure for early abort logic
                    if message.get("success"):
                        pbar.set_postfix(
                            shard=message.get("shard_idx", "?"),
                            time=f"{message.get('duration', 0):.1f}s",
                        )
                    else:
                        # Track failure category
                        category = message.get("error_category", "unknown")
                        if category in failure_categories:
                            failure_categories[category] += 1
                        else:
                            failure_categories["unknown"] += 1
                        pbar.set_postfix(
                            shard=message.get("shard_idx", "?"),
                            status="failed",
                        )

                    # EARLY ABORT CHECK (Jan 2026): Abort if too many shards fail early
                    if (
                        not early_abort_triggered
                        and completed_count >= early_abort_sample_size
                        and completed_count <= early_abort_sample_size + 2
                    ):
                        total_failures = sum(failure_categories.values())
                        failure_rate = total_failures / completed_count
                        if failure_rate > early_abort_threshold:
                            early_abort_triggered = True
                            run_logger.error(
                                f"EARLY ABORT: {failure_rate:.1%} failure rate in first "
                                f"{completed_count} shards exceeds {early_abort_threshold:.0%} threshold.\n"
                                f"Failure breakdown: {failure_categories}\n"
                                f"Terminating remaining shards to save compute time."
                            )
                            # Clear pending shards and terminate active processes/pool
                            pending_shards.clear()
                            if pool is not None:
                                pool.shutdown(timeout=5)
                                pool = None
                            for idx, (proc, _) in list(active_processes.items()):
                                run_logger.info(
                                    f"Terminating shard {idx} due to early abort"
                                )
                                proc.terminate()
                                proc.join(timeout=2)
                                if proc.is_alive():
                                    proc.kill()
                                    proc.join(timeout=1)
                                active_processes.pop(idx, None)

                    if shard_idx in active_processes:
                        # Clean up completed process tracking
                        proc, _ = active_processes.pop(shard_idx)
                        if proc.is_alive():
                            proc.join(timeout=1)
                    continue

                if run_logger.isEnabledFor(logging.DEBUG):
                    run_logger.debug(
                        f"Ignoring unexpected queue message: {message}"
                    )

            # Launch new processes up to max workers
            while len(active_processes) < actual_workers and pending_shards:
                shard_idx = pending_shards.popleft()
                process = ctx.Process(
                    target=_run_shard_worker_with_queue,
                    args=(
                        shard_idx,
                        shared_shard_refs[shard_idx],
                        model,
                        shared_config_ref,
                        shared_iv_ref,
                        shared_ps_ref,
                        shared_kwargs_ref,
                        shared_tg_ref,
                        shards[shard_idx].n_phi,
                        analysis_mode,
                        threads_per_worker,
                        result_queue,
                        shard_keys[shard_idx],  # Pre-generated PRNG key
                    ),
                )
                process.start()
                now = time.time()
                active_processes[shard_idx] = (process, now)
                last_heartbeat[shard_idx] = now
                shards_launched += 1

            # Check for completed or timed-out processes
            for shard_idx, (process, proc_start_time) in list(
                active_processes.items()
            ):
                # Skip shards already recorded by queue drain (prevents
                # double-counting when queue result arrives in the same
                # loop iteration as the process exit detection).
                if shard_idx in recorded_shards:
                    del active_processes[shard_idx]
                    continue

                now = time.time()
                proc_elapsed = now - proc_start_time
                last_active = last_heartbeat.get(shard_idx, proc_start_time)
                inactive_elapsed = now - last_active

                if not process.is_alive():
                    process.join(timeout=1)
                    if shard_idx not in exit_grace_given:
                        # The worker may have queued its result right after
                        # this iteration's drain. Keep the entry for one more
                        # pass so the next drain can pick the result up
                        # instead of mislabelling the shard as crashed.
                        exit_grace_given.add(shard_idx)
                        continue

                    exit_code = process.exitcode
                    del active_processes[shard_idx]
                    run_logger.debug(
                        f"Shard {shard_idx} process exited after {proc_elapsed:.1f}s "
                        f"(exit_code={exit_code})"
                    )
                    # Build descriptive error with exit code context
                    if exit_code is not None and exit_code < 0:
                        import signal as _signal

                        try:
                            sig_name = _signal.Signals(-exit_code).name
                        except ValueError:
                            # Not every signal number is a Signals member
                            # (e.g. real-time signals); report the number.
                            sig_name = str(-exit_code)
                        error_msg = (
                            f"Process killed by signal {sig_name} "
                            f"(exit_code={exit_code})"
                        )
                    elif exit_code is not None and exit_code > 0:
                        error_msg = (
                            f"Process exited with error (exit_code={exit_code})"
                        )
                    else:
                        error_msg = "Process exited without returning a result"
                    results.append(
                        {
                            "type": "result",
                            "success": False,
                            "shard_idx": shard_idx,
                            "error": error_msg,
                            "error_category": "crash",
                            "duration": proc_elapsed,
                        }
                    )
                    recorded_shards.add(shard_idx)
                    failure_categories["crash"] += 1
                    completed_count += 1
                    pbar.update(1)
                    pbar.set_postfix(shard=shard_idx, status="no-result")

                elif proc_elapsed > per_shard_timeout:
                    # Total runtime exceeded - terminate regardless of heartbeats
                    run_logger.warning(
                        f"Shard {shard_idx} exceeded runtime limit: {proc_elapsed:.0f}s "
                        f"(limit: {per_shard_timeout}s), terminating process (pid={process.pid})"
                    )
                    process.terminate()
                    process.join(timeout=5)
                    if process.is_alive():
                        run_logger.warning(
                            f"Shard {shard_idx} did not terminate, killing"
                        )
                        process.kill()
                        process.join(timeout=2)
                    del active_processes[shard_idx]
                    results.append(
                        {
                            "type": "result",
                            "success": False,
                            "shard_idx": shard_idx,
                            "error": f"Runtime timeout after {proc_elapsed:.0f}s (limit: {per_shard_timeout}s)",
                            "error_category": "timeout",
                            "duration": proc_elapsed,
                        }
                    )
                    recorded_shards.add(shard_idx)
                    failure_categories["timeout"] += 1
                    completed_count += 1
                    pbar.update(1)
                    pbar.set_postfix(shard=shard_idx, status="timeout")

                elif inactive_elapsed > config.heartbeat_timeout:
                    # No heartbeat for configured timeout - process likely frozen
                    run_logger.warning(
                        f"Shard {shard_idx} unresponsive for {inactive_elapsed:.0f}s "
                        f"(heartbeat timeout: {config.heartbeat_timeout}s), "
                        f"terminating process (pid={process.pid})"
                    )
                    process.terminate()
                    process.join(timeout=5)
                    if process.is_alive():
                        run_logger.warning(
                            f"Shard {shard_idx} did not terminate, killing"
                        )
                        process.kill()
                        process.join(timeout=2)
                    del active_processes[shard_idx]
                    results.append(
                        {
                            "type": "result",
                            "success": False,
                            "shard_idx": shard_idx,
                            "error": f"Unresponsive after {inactive_elapsed:.0f}s (heartbeat timeout: {config.heartbeat_timeout}s)",
                            "error_category": "heartbeat_timeout",
                            "duration": proc_elapsed,
                        }
                    )
                    recorded_shards.add(shard_idx)
                    failure_categories["heartbeat_timeout"] += 1
                    completed_count += 1
                    pbar.update(1)
                    pbar.set_postfix(shard=shard_idx, status="frozen")

            # Update progress bar with elapsed time
            if completed_count < n_shards:
                elapsed = time.time() - start_time
                mins, secs = divmod(int(elapsed), 60)
                hrs, mins = divmod(mins, 60)
                if hrs > 0:
                    pbar.set_postfix_str(
                        f"active={len(active_processes)} elapsed={hrs}h{mins:02d}m"
                    )
                else:
                    pbar.set_postfix_str(
                        f"active={len(active_processes)} elapsed={mins}m{secs:02d}s"
                    )

                _now = time.time()
                if _now - last_status_log >= status_log_interval:
                    # Only show heartbeats for active processes
                    active_heartbeats = {
                        k: f"{_now - last_heartbeat.get(k, _now):.0f}s"
                        for k in active_processes
                    }
                    run_logger.info(
                        f"CMC status: {completed_count}/{n_shards} complete; "
                        f"active={len(active_processes)}; "
                        f"launched={shards_launched}; "
                        f"heartbeats={active_heartbeats}"
                    )
                    last_status_log = _now

                # Adaptive polling: gradually increase interval if no recent completions
                # This reduces CPU overhead during long-running shards
                time_since_completion = time.time() - last_completion_time
                if time_since_completion > 30.0:
                    # Gradually increase poll interval (10% per 30s of inactivity)
                    poll_interval = min(
                        poll_interval * 1.1,
                        poll_interval_max,
                    )

                # Only wait while work is outstanding; once every shard is
                # accounted for the loop exits without an extra sleep.
                time.sleep(poll_interval)

            # Pool stall detection: if no result has arrived for longer than
            # per_shard_timeout, the pool workers are likely stuck. Shut down
            # the pool and mark remaining shards as timed out.
            if (
                pool is not None
                and completed_count < n_shards
                and (time.time() - last_completion_time) > per_shard_timeout
            ):
                stall_elapsed = time.time() - last_completion_time
                run_logger.warning(
                    "WorkerPool stall detected: no result for %.0fs "
                    "(per_shard_timeout=%.0fs). Shutting down pool.",
                    stall_elapsed,
                    per_shard_timeout,
                )
                pool.shutdown(timeout=5)
                pool = None
                # Mark all remaining shards as timed out
                missing = set(range(n_shards)) - recorded_shards
                for shard_idx in sorted(missing):
                    results.append(
                        {
                            "type": "result",
                            "success": False,
                            "shard_idx": shard_idx,
                            "error": f"Pool stall timeout after {stall_elapsed:.0f}s",
                            "error_category": "timeout",
                            "duration": stall_elapsed,
                        }
                    )
                    recorded_shards.add(shard_idx)
                    completed_count += 1
                    pbar.update(1)
                    pbar.set_postfix(shard=shard_idx, status="pool-timeout")

            # If no processes remain and nothing is pending, mark any missing shards as failed
            if (
                not active_processes
                and not pending_shards
                and pool is None
                and completed_count < n_shards
            ):
                missing = set(range(n_shards)) - recorded_shards
                for shard_idx in sorted(missing):
                    results.append(
                        {
                            "success": False,
                            "shard_idx": shard_idx,
                            "error": "Shard exited without emitting a result",
                            "error_category": "crash",
                            "duration": None,
                        }
                    )
                    recorded_shards.add(shard_idx)
                    completed_count += 1
                    pbar.update(1)
                    pbar.set_postfix(shard=shard_idx, status="no-result")

    except KeyboardInterrupt:
        run_logger.warning("Interrupted - terminating all active processes")
        if pool is not None:
            pool.shutdown(timeout=2)
            pool = None
        for shard_idx, (process, _) in active_processes.items():
            run_logger.debug(f"Terminating shard {shard_idx} (pid={process.pid})")
            process.terminate()
            process.join(timeout=2)
        raise
    finally:
        if pbar is not None:
            pbar.close()
        # Shut down persistent worker pool if still active
        if pool is not None:
            pool.shutdown(timeout=5)
        # Clean up any remaining active processes (per-shard spawn mode)
        for shard_idx, (process, _) in list(active_processes.items()):
            if process.is_alive():
                run_logger.warning(
                    f"Cleaning up orphan process for shard {shard_idx}"
                )
                process.terminate()
                process.join(timeout=2)
        # Restore parent environment after all workers are done
        for key, val in _saved_env.items():
            if val is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = val
        # Release shared memory after all workers are done
        shared_mgr.cleanup()

    # Process results - collect successful samples with metadata for filtering
    successful_samples = []
    shard_metadata: list[dict] = []  # Track shard idx, divergences, total samples
    shard_timings: list[tuple[int | None, float | None]] = []

    for result in results:
        if result["success"]:
            # Reconstruct MCMCSamples
            samples = MCMCSamples(
                samples=result["samples"],
                param_names=result["param_names"],
                n_chains=result["n_chains"],
                n_samples=result["n_samples"],
                extra_fields=result["extra_fields"],
            )
            successful_samples.append(samples)
            shard_timings.append((result.get("shard_idx"), result.get("duration")))

            # Track divergence stats for quality filtering
            stats = result.get("stats", {})
            total_samples = result["n_chains"] * result["n_samples"]
            shard_metadata.append(
                {
                    "shard_idx": result.get("shard_idx"),
                    "num_divergent": stats.get("num_divergent", 0),
                    "total_samples": total_samples,
                    # NUTS divergent count <= total_samples; max() prevents div-by-zero
                    "divergence_rate": stats.get("num_divergent", 0)
                    / max(total_samples, 1),
                    "n_warmup": stats.get("n_warmup"),
                    "n_samples": stats.get("n_samples"),
                }
            )
        else:
            error_cat = result.get("error_category", "unknown")
            run_logger.warning(
                f"Shard {result.get('shard_idx', '?')} failed [{error_cat}]: "
                f"{result.get('error', 'unknown')}"
            )
            if result.get("traceback"):
                run_logger.debug(
                    f"Shard {result.get('shard_idx', '?')} traceback:\n"
                    f"{result['traceback']}"
                )

    if not successful_samples:
        # Aggregate error categories for better diagnostics
        error_categories: dict[str, int] = {}
        for result in results:
            if not result.get("success"):
                category = result.get("error_category", "unknown")
                error_categories[category] = error_categories.get(category, 0) + 1
        run_logger.error(
            f"All {n_shards} shards failed. Error breakdown: {error_categories}"
        )
        raise RuntimeError(
            f"All shards failed. Error categories: {error_categories}"
        )

    # Check success rate — warn first, then error for worse.
    # P2-A: Previously, warning (0.80) < min (0.90) made the elif unreachable.
    # Fixed: check warning threshold first (higher), then error threshold (lower).
    success_rate = len(successful_samples) / n_shards
    if success_rate < config.min_success_rate_warning:
        # Critical: below warning threshold (worst case)
        run_logger.error(
            f"Success rate {success_rate:.1%} below minimum threshold "
            f"{config.min_success_rate_warning:.1%} - analysis may be unreliable"
        )
    elif success_rate < config.min_success_rate:
        # Degraded: between warning and recommended thresholds
        run_logger.warning(
            f"Success rate {success_rate:.1%} below recommended threshold "
            f"{config.min_success_rate:.1%} - consider investigating failed shards"
        )

    valid_durations = [d for _, d in shard_timings if d is not None]
    if valid_durations:
        run_logger.debug(
            f"Shard timing summary: n={len(valid_durations)}, "
            f"min={min(valid_durations):.1f}s, max={max(valid_durations):.1f}s, "
            f"median={sorted(valid_durations)[len(valid_durations) // 2]:.1f}s"
        )

    # ──────────────────────────────────────────────────────────────────────
    # Jan 2026 FIX: Divergence-based shard quality filter
    # Filter out shards with divergence rate > max_divergence_rate
    # High-divergence shards have corrupted posteriors that bias the
    # consensus combination. This is especially critical for laminar_flow
    # where 28.4% overall divergence rate (from the C020 CMC run) indicated
    # severe sampling issues that propagated to parameter estimates.
    # ──────────────────────────────────────────────────────────────────────
    max_div_rate = getattr(config, "max_divergence_rate", 0.10)
    if max_div_rate < 1.0 and shard_metadata:
        # Identify high-divergence shards
        high_div_shards = []
        filtered_samples = []
        filtered_metadata = []
        for samples, meta in zip(successful_samples, shard_metadata, strict=True):
            div_rate = meta["divergence_rate"]
            if div_rate > max_div_rate:
                high_div_shards.append(
                    (meta["shard_idx"], div_rate, meta["num_divergent"])
                )
            else:
                filtered_samples.append(samples)
                filtered_metadata.append(meta)

        if high_div_shards:
            run_logger.warning(
                f"QUALITY FILTER: Excluding {len(high_div_shards)} shards with "
                f"divergence rate > {max_div_rate:.0%}:"
            )
            for shard_idx, div_rate, num_div in high_div_shards:
                run_logger.warning(
                    f" Shard {shard_idx}: {div_rate:.1%} divergence ({num_div} transitions)"
                )

            # Update samples list for combination
            n_before = len(successful_samples)
            successful_samples = filtered_samples
            shard_metadata = filtered_metadata
            run_logger.info(
                f"After quality filtering: {len(successful_samples)}/{n_before} shards retained"
            )

            # Re-check if we still have enough shards
            if not successful_samples:
                raise RuntimeError(
                    f"All {n_before} successful shards exceeded max_divergence_rate={max_div_rate:.0%}. "
                    "Consider: (1) reducing shard size, (2) adjusting priors, "
                    "(3) increasing max_divergence_rate threshold."
                )

            # Warn if filtered rate is too low
            filtered_rate = len(successful_samples) / n_shards
            if filtered_rate < config.min_success_rate:
                run_logger.error(
                    f"Post-filter success rate {filtered_rate:.1%} below minimum threshold "
                    f"{config.min_success_rate:.1%} - analysis may be unreliable"
                )

    # Log per-shard posterior statistics BEFORE combination
    # This helps diagnose why combined posteriors may differ from initial values
    # Also check for heterogeneity abort (Jan 2026 v2)
    heterogeneity_abort = getattr(config, "heterogeneity_abort", True)
    max_parameter_cv = getattr(config, "max_parameter_cv", 1.0)
    high_cv_params: list[tuple[str, float]] = []  # Track (param, cv) pairs
    # Parameters inspected by both the heterogeneity and bimodality reports.
    key_params = ["D0", "alpha", "D_offset", "gamma_dot_t0", "beta"]

    if len(successful_samples) > 1:
        run_logger.info(
            f"Per-shard posterior statistics ({len(successful_samples)} shards):"
        )
        if parameter_space is None:
            run_logger.warning(
                "parameter_space is None - bounds-aware CV disabled; "
                "heterogeneity detection may produce false positives for near-zero parameters"
            )
        for param in key_params:
            if param not in successful_samples[0].samples:
                continue
            means = [
                float(np.nanmean(s.samples[param])) for s in successful_samples
            ]
            run_logger.info(
                f" {param}: shard_means=[{np.nanmin(means):.4g}, {np.nanmax(means):.4g}], "
                f"range={np.nanmax(means) - np.nanmin(means):.4g}, "
                f"mean_of_means={np.nanmean(means):.4g}, "
                f"std_of_means={np.nanstd(means):.4g}"
            )

            # Check for high heterogeneity (bounds-aware CV)
            mean_val = abs(np.nanmean(means))
            if parameter_space is not None:
                try:
                    lo, hi = parameter_space.get_bounds(param)
                    param_range = hi - lo
                    # Distinguish inverted bounds (lo > hi) from degenerate (lo == hi)
                    if param_range < 0:
                        # Inverted bounds: use absolute range
                        run_logger.warning(
                            f" {param}: inverted bounds [{lo}, {hi}], "
                            f"using abs(range)={abs(param_range):.4g}"
                        )
                        param_range = abs(param_range)
                    elif param_range == 0:
                        # Degenerate bounds: fall back to mean-based scale
                        run_logger.warning(
                            f" {param}: degenerate bounds [{lo}, {hi}] "
                            f"(range=0), falling back to mean-based scale"
                        )
                    # For near-zero params, use bounds range as scale reference
                    scale = (
                        max(mean_val, param_range * 0.01)
                        if param_range > 0
                        else max(mean_val, 1e-10)
                    )
                except (KeyError, ValueError, TypeError):
                    scale = max(mean_val, 1e-10)
            else:
                scale = max(mean_val, 1e-10)

            cv = np.nanstd(means) / scale
            if not np.isfinite(cv):
                # NaN/Inf CV means shard posteriors contain non-finite values
                high_cv_params.append((param, float("inf")))
                run_logger.warning(
                    f" NON-FINITE CV: {param} has nan/inf CV "
                    f"(likely NaN samples in shard posteriors)"
                )
            elif cv > max_parameter_cv:
                high_cv_params.append((param, cv))
                run_logger.warning(
                    f" HIGH HETEROGENEITY: {param} has CV={cv:.2f} across shards! "
                    f"(threshold={max_parameter_cv:.2f})"
                )
            elif cv > 0.5:
                run_logger.warning(
                    f" MODERATE HETEROGENEITY: {param} has CV={cv:.2f} across shards. "
                    f"Combined posterior may be unreliable."
                )

    # HETEROGENEITY ABORT (Jan 2026 v2): Fail fast instead of silently bad results
    if heterogeneity_abort and high_cv_params:
        param_summary = ", ".join(
            f"{p} (CV={cv:.2f})" for p, cv in high_cv_params
        )
        raise RuntimeError(
            f"HETEROGENEITY ABORT: {len(high_cv_params)} parameter(s) exceed "
            f"max_parameter_cv={max_parameter_cv:.2f}: {param_summary}\n\n"
            f"This indicates shards are sampling from inconsistent posterior regions, "
            f"making consensus combination unreliable.\n\n"
            f"Recommended actions:\n"
            f" 1. Ensure NLSQ warm-start is active (--nlsq-result <path> or automatic)\n"
            f" 2. Increase min_points_per_shard (current: {getattr(config, 'min_points_per_shard', 10000):,})\n"
            f" 3. Check if data quality issues exist (outliers, missing values)\n"
            f" 4. Set validation.heterogeneity_abort=false to disable this check (not recommended)\n"
            f" 5. Increase max_parameter_cv threshold if heterogeneity is expected"
        )

    # Check for bimodal posteriors (per-shard) - Jan 2026
    # This helps detect local minima or model misspecification
    bimodal_detections: list[dict[str, Any]] = []
    for i, shard_result in enumerate(successful_samples):
        bimodal_results = check_shard_bimodality(shard_result.samples)
        for param, result in bimodal_results.items():
            if result.is_bimodal:
                bimodal_detections.append(
                    {
                        "shard": i,
                        "param": param,
                        "mode1": result.means[0],
                        "mode2": result.means[1],
                        "std1": result.stds[0],
                        "std2": result.stds[1],
                        "weights": result.weights,
                        "separation": result.separation,
                    }
                )
                run_logger.warning(
                    f"BIMODAL POSTERIOR: Shard {i}, {param}: "
                    f"modes at {result.means[0]:.4g} and {result.means[1]:.4g} "
                    f"(weights: {result.weights[0]:.2f}/{result.weights[1]:.2f})"
                )

    if bimodal_detections:
        # Compute pre-combine consensus means from per-shard posteriors
        consensus_means: dict[str, float] = {}
        for param in key_params:
            if param in successful_samples[0].samples:
                means = [
                    float(np.nanmean(s.samples[param])) for s in successful_samples
                ]
                consensus_means[param] = float(np.nanmean(means))

        bimodal_summary = summarize_cross_shard_bimodality(
            bimodal_detections,
            n_shards=len(successful_samples),
            consensus_means=consensus_means,
        )
        _log_bimodality_summary(run_logger, bimodal_summary)

        # Mode-aware consensus if significant bimodality detected
        if bimodal_summary["per_param"]:
            modal_params = sorted(bimodal_summary["per_param"].keys())

            # Get parameter bounds for range normalization
            param_bounds: dict[str, tuple[float, float]] = {}
            if parameter_space is not None:
                for param in modal_params:
                    try:
                        param_bounds[param] = parameter_space.get_bounds(param)
                    except (KeyError, ValueError):
                        pass

            mode_assignments = cluster_shard_modes(
                bimodal_detections=bimodal_detections,
                successful_samples=successful_samples,
                bimodal_summary=bimodal_summary,
                param_bounds=param_bounds,
            )
            run_logger.info(
                f"Mode-aware consensus: cluster sizes = "
                f"{len(mode_assignments[0])}, {len(mode_assignments[1])}"
            )

            combined, bimodal_result = combine_shard_samples_bimodal(
                shard_samples=successful_samples,
                cluster_assignments=mode_assignments,
                bimodal_detections=bimodal_detections,
                modal_params=modal_params,
                co_occurrence=bimodal_summary.get("co_occurrence", {}),
                method=config.combination_method,
            )
            combined.bimodal_consensus = bimodal_result

            # Log mode summary
            for i, mode in enumerate(bimodal_result.modes):
                mode_means = ", ".join(
                    f"{p}={mode.mean[p]:.4g}"
                    for p in modal_params
                    if p in mode.mean
                )
                run_logger.info(
                    f" Mode {i}: weight={mode.weight:.2f}, "
                    f"n_shards={mode.n_shards}, {mode_means}"
                )
        else:
            # Bimodal detections exist but below significance threshold
            combined = combine_shard_samples(
                successful_samples,
                method=config.combination_method,
            )
    else:
        # No bimodality detected — standard path
        combined = combine_shard_samples(
            successful_samples,
            method=config.combination_method,
        )

    # Explicitly set num_shards to surviving shard count for diagnostics.
    # Without this, per-shard MCMCSamples reconstruction defaults num_shards=1,
    # and the hierarchical combination may not accumulate correctly.
    combined.num_shards = len(successful_samples)

    # Propagate median adapted n_warmup from shard metadata.
    # Workers may adapt warmup independently (e.g., 500→140 for small shards).
    # Use median to represent the typical adapted value for CMCResult reporting.
    warmup_values = [
        m["n_warmup"] for m in shard_metadata if m.get("n_warmup") is not None
    ]
    if warmup_values:
        combined.shard_adapted_n_warmup = int(np.median(warmup_values))
    else:
        run_logger.info(
            "No shards reported adapted n_warmup; CMCResult will use config default"
        )

    # Log summary including divergence filtering
    total_divergences = sum(m.get("num_divergent", 0) for m in shard_metadata)
    total_transitions = sum(m.get("total_samples", 0) for m in shard_metadata)
    overall_div_rate = total_divergences / max(total_transitions, 1)
    run_logger.info(
        f"Combined {len(successful_samples)}/{n_shards} shards "
        f"(overall divergence rate: {overall_div_rate:.1%}, {total_divergences}/{total_transitions})"
    )

    return combined
def is_available(self) -> bool:
    """Check if multiprocessing is available.

    No runtime probing is performed here: the method unconditionally
    reports the backend as usable.

    Returns:
        Always ``True``.
    """
    return True