# Source code for homodyne.optimization.cmc.config

"""CMC configuration dataclass and validation.

This module provides the CMCConfig dataclass for parsing and validating
CMC-specific configuration settings from the YAML config file.

Config Precedence (Important)
-----------------------------
The CLI reads base `optimization.mcmc` settings and applies them to
`per_shard_mcmc`. This means if base mcmc differs from per_shard_mcmc
in your YAML config, the CLI will overwrite per_shard_mcmc with base
values. To avoid surprises, keep base mcmc and per_shard_mcmc aligned.

Example aligned config::

    optimization:
      mcmc:
        num_warmup: 500
        num_samples: 1500
        num_chains: 4
      cmc:
        per_shard_mcmc:
          num_warmup: 500
          num_samples: 1500
          num_chains: 4
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from homodyne.utils.logging import get_logger

logger = get_logger(__name__)


@dataclass
class CMCConfig:
    """Configuration for Consensus Monte Carlo (CMC) analysis.

    Attributes
    ----------
    enable : bool | str
        Whether to enable CMC. "auto" enables based on data size.
    min_points_for_cmc : int
        Minimum data points to trigger CMC mode.
    sharding_strategy : str
        How to partition data: "stratified", "random", "contiguous".
    num_shards : int | str
        Number of data shards. "auto" calculates from data size.
    max_points_per_shard : int | str
        Maximum points per shard. "auto" calculates optimally based on
        dataset size, analysis mode, and angle count (see
        ``_resolve_max_points_per_shard``). Default: "auto".
        Typical auto values: 5–20K for laminar_flow, 10–20K for static
        (scales with dataset size).
    backend_name : str
        Execution backend: "auto", "multiprocessing", "pjit", "pbs",
        "slurm". The legacy value "jax" is remapped to "multiprocessing"
        by ``from_dict``.
    enable_checkpoints : bool
        Whether to save checkpoints during sampling.
    checkpoint_dir : str
        Directory for checkpoint files.
    num_warmup : int
        Number of warmup/burn-in samples per chain.
    num_samples : int
        Number of posterior samples per chain.
    num_chains : int
        Number of MCMC chains.
    chain_method : str
        MCMC chain execution method. ``"parallel"`` (default) runs chains
        concurrently via JAX vectorization. ``"sequential"`` runs chains one
        at a time. Parallel is faster on multi-core CPUs but adds ~5-15%
        overhead on very small shards (<500 points); the sampler
        auto-falls-back to sequential in that case.
    target_accept_prob : float
        Target acceptance probability for NUTS.
    dense_mass : bool
        Use dense mass matrix for NUTS. When True, learns parameter
        correlations for more efficient sampling. Default: True.
    max_r_hat : float
        Maximum R-hat for convergence.
    min_ess : float
        Minimum effective sample size.
    combination_method : str
        How to combine shard posteriors. Options:

        - ``"robust_consensus_mc"`` (default): Consensus Monte Carlo with
          MAD-based outlier filtering of shard posteriors.
        - ``"consensus_mc"``: Correct Consensus Monte Carlo
          (precision-weighted means). Combines per-shard posterior moments,
          then generates new samples from the combined Gaussian.
        - ``"weighted_gaussian"``: Legacy element-wise weighted averaging
          (deprecated).
        - ``"simple_average"``: Simple element-wise averaging (deprecated).
        - ``"auto"``: Accepted by ``validate()``; how it is resolved is
          decided outside this module.
    min_success_rate : float
        Minimum fraction of shards that must succeed.
    run_id : str | None
        Optional identifier used for structured logging across shards.
    per_angle_mode : str
        Per-angle scaling mode for anti-degeneracy defense (v2.18.0+):

        - ``"auto"``: Auto-selects based on n_phi threshold (recommended).
          When n_phi >= threshold: Estimates per-angle values, AVERAGES
          them, broadcasts single value to all angles (matches NLSQ
          behavior). When n_phi < threshold: Uses individual mode.
        - ``"constant"``: Per-angle contrast/offset from quantile
          estimation, used DIRECTLY (different fixed value per angle, NOT
          averaged). Reduces to 8 params (7 physical + 1 sigma).
        - ``"constant_averaged"``: Fixed averaged contrast/offset for NLSQ
          parity. Also chosen automatically by
          ``get_effective_per_angle_mode`` when an NLSQ warm-start is
          present and both CMC and NLSQ use "auto".
        - ``"individual"``: Independent contrast + offset per angle, all
          sampled. May suffer from parameter absorption degeneracy with
          many angles.
    constant_scaling_threshold : int
        n_phi threshold for auto mode's per-angle strategy. When
        n_phi >= threshold, auto mode samples averaged contrast/offset
        (single value broadcast to all angles). When n_phi < threshold,
        auto mode falls back to individual per-angle sampling. Default: 3.

    The remaining fields (adaptive sampling, profiling, timeouts, NLSQ
    priors, prior tempering, heterogeneity detection, reparameterization,
    seed) are documented by the inline comments on their definitions.
    """

    # Enable settings
    enable: bool | str = "auto"
    min_points_for_cmc: int = 100000

    # Anti-degeneracy: Per-angle scaling mode (v2.18.0+)
    per_angle_mode: str = "auto"
    constant_scaling_threshold: int = 3

    # Sharding
    sharding_strategy: str = "random"
    num_shards: int | str = "auto"
    max_points_per_shard: int | str = "auto"

    # Backend
    backend_name: str = "auto"
    enable_checkpoints: bool = True
    checkpoint_dir: str = "./checkpoints/cmc"

    # Sampling
    # NOTE: Defaults are intentionally conservative for laminar_flow
    # workloads to avoid per-shard timeouts on large pooled datasets
    # (millions of points across a handful of angles).
    #
    # Effective work per shard scales roughly as:
    #   O(num_chains * (num_warmup + num_samples) * max_points_per_shard)
    #
    # Values here are chosen to keep typical laminar_flow CMC shards
    # well under the 2 hour timeout on modest CPU nodes while still
    # providing usable posteriors and R-hat diagnostics.
    #
    # DEPRECATION NOTE: Do not use config.num_warmup or config.num_samples
    # directly in sampling hot paths. Use SamplingPlan.from_config() instead,
    # which applies adaptive scaling. Direct access is only appropriate for:
    # - Logging configured (pre-adaptation) values
    # - Timeout estimation (safe upper bound)
    # - Config serialization/validation
    num_warmup: int = 500
    num_samples: int = 1500
    num_chains: int = 4  # Increased from 2 for better R-hat convergence diagnostics
    chain_method: str = "parallel"  # "parallel" (default) or "sequential"
    target_accept_prob: float = 0.85
    dense_mass: bool = True  # Dense mass matrix for NUTS (learns parameter correlations)

    # Adaptive sampling (Feb 2026): Scale warmup/samples based on shard size
    # Small datasets benefit from fewer samples to reduce NUTS overhead while
    # maintaining statistical validity. Profiling showed 1310s for 50 points
    # with default settings - adaptive scaling reduces this by 60-80%.
    adaptive_sampling: bool = True  # Enable adaptive sample count based on shard size
    max_tree_depth: int = 10  # NUTS tree depth (max 2^depth leapfrog steps per sample)
    min_warmup: int = 100  # Minimum warmup even for small datasets
    min_samples: int = 200  # Minimum samples even for small datasets

    # JAX profiling (Feb 2026): Capture XLA-level performance data
    # py-spy can only profile Python code; XLA runs native code invisible to py-spy.
    # Enable this to trace XLA operations and export to TensorBoard-compatible format.
    enable_jax_profiling: bool = False  # Enable jax.profiler tracing
    jax_profile_dir: str = "./profiles/jax"  # Directory for JAX profile output

    # Validation thresholds
    max_r_hat: float = 1.1
    min_ess: float = 400.0
    max_divergence_rate: float = 0.10  # Filter shards with >10% divergence rate

    # Combination
    combination_method: str = (
        "robust_consensus_mc"  # Robust CMC with MAD outlier filtering (v2.22.2)
    )
    min_success_rate: float = 0.90
    run_id: str | None = None

    # Timeout (Jan 2026: reduced from 2h to 1h to fail faster on problematic shards)
    per_shard_timeout: int = 3600  # 1 hour per shard in seconds
    heartbeat_timeout: int = 600  # 10 minutes - terminate unresponsive workers

    # Warning thresholds
    min_success_rate_warning: float = 0.80  # Warn if success rate below this

    # Warm-start requirements (Jan 2026)
    require_nlsq_warmstart: bool = False  # Require NLSQ warm-start for laminar_flow

    # NLSQ-informed priors (Feb 2026): Use NLSQ estimates to build tighter priors
    use_nlsq_informed_priors: bool = True  # Build TruncatedNormal priors from NLSQ
    nlsq_prior_width_factor: float = 2.0  # Width = NLSQ_std * factor (~95.4% coverage)

    # Prior tempering (Feb 2026): Scale priors by 1/K per shard (Scott et al. 2016)
    # Without tempering, K shards each apply the full prior → combined posterior = prior^K × likelihood.
    # With tempering, each shard uses prior^(1/K) → combined posterior = prior × likelihood (correct).
    # For Normal(μ,σ): prior^(1/K) ∝ Normal(μ, σ√K), i.e., widen std by √num_shards.
    prior_tempering: bool = True  # Enable prior tempering for multi-shard CMC

    # Heterogeneity detection (Jan 2026 v2)
    # Abort early if shard posteriors are too heterogeneous (high CV)
    max_parameter_cv: float = 1.0  # Abort if any parameter has CV > 1.0 across shards
    heterogeneity_abort: bool = True  # Enable heterogeneity abort (fail fast)
    min_points_per_shard: int = 10000  # Enforced minimum for laminar_flow
    min_points_per_param: int = 1500  # Minimum points per parameter per shard

    # Reparameterization (Jan 2026 v3)
    # Transform to break D0/D_offset degeneracy
    reparameterization_d_total: bool = True  # Sample D_total = D0 + D_offset
    reparameterization_log_gamma: bool = True  # Sample log(gamma_dot_t0)
    bimodal_min_weight: float = 0.2  # Minimum weight for GMM bimodal detection
    bimodal_min_separation: float = 0.5  # Minimum relative separation for bimodal

    # Reproducibility
    seed: int = 42  # Base seed for PRNG key generation

    # Computed fields
    _validation_errors: list[str] = field(default_factory=list, repr=False)

    @staticmethod
    def _section(config_dict: dict[str, Any], key: str) -> Any:
        """Return ``config_dict[key]``, or ``{}`` when it is missing or None.

        A YAML section header with no body (e.g. ``sharding:``) loads as
        None; mapping it to ``{}`` keeps the ``.get`` lookups in
        ``from_dict`` safe. Any non-None value (including the legacy string
        form of ``backend``) is returned unchanged.
        """
        value = config_dict.get(key)
        return {} if value is None else value

    @staticmethod
    def _coerce_count(value: Any) -> Any:
        """Convert a stringified number ("10" or "10.0") to ``int``.

        "auto" and strings that cannot be parsed as a finite number are
        returned unchanged so that ``validate()`` can report them. Non-string
        values are returned as-is.
        """
        if isinstance(value, str) and value != "auto":
            try:
                return int(float(value))
            except (ValueError, OverflowError):
                # ValueError: not numeric ("abc", "nan");
                # OverflowError: "inf" cannot become an int.
                pass  # Leave as string; validate() will catch it
        return value

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> CMCConfig:
        """Create CMCConfig from configuration dictionary.

        Parameters
        ----------
        config_dict : dict
            CMC configuration dictionary from ConfigManager.get_cmc_config().
            Nested sections (``sharding``, ``backend``, ``backend_config``,
            ``per_shard_mcmc``, ``validation``, ``combination``,
            ``reparameterization``) may be absent or null; missing values
            fall back to the dataclass defaults.

        Returns
        -------
        CMCConfig
            Configuration object. ``validate()`` has already been run on it;
            any problems are logged as warnings rather than raised, so callers
            that need a hard failure should check ``is_valid()``.
        """
        # P2-4: Warn about unrecognized top-level keys that may be typos.
        _known_keys = {
            "sharding",
            "backend",
            "backend_config",
            "per_shard_mcmc",
            "validation",
            "combination",
            "reparameterization",
            "enable",
            "min_points_for_cmc",
            "per_angle_mode",
            "constant_scaling_threshold",
            "run_id",
            "per_shard_timeout",
            "heartbeat_timeout",
            "prior_tempering",
            "seed",
        }
        unknown_keys = set(config_dict.keys()) - _known_keys
        if unknown_keys:
            logger.warning(
                f"CMC config contains unrecognized keys (possible typos): "
                f"{sorted(unknown_keys)}"
            )

        # Extract nested sections (null YAML sections become {})
        sharding = cls._section(config_dict, "sharding")
        backend = cls._section(config_dict, "backend")
        backend_config = cls._section(config_dict, "backend_config")
        per_shard = cls._section(config_dict, "per_shard_mcmc")
        validation = cls._section(config_dict, "validation")
        combination = cls._section(config_dict, "combination")
        reparameterization = cls._section(config_dict, "reparameterization")

        # Handle multiple schema variants for backend configuration:
        # 1. New schema: backend_config.name (preferred)
        # 2. Old schema: backend.name (dict)
        # 3. Legacy: backend as string (computational backend, separate from
        #    the parallel backend); parallel settings come from backend_config
        # The old ``backend`` dict is used only when it names a backend and
        # backend_config does not. In every other case backend_config is the
        # source, so its enable_checkpoints/checkpoint_dir are honored even
        # when it omits ``name``.
        if (
            isinstance(backend, dict)
            and backend.get("name")
            and not backend_config.get("name")
        ):
            backend_section = backend
        else:
            backend_section = backend_config
        backend_name = backend_section.get("name") or "auto"
        enable_checkpoints = backend_section.get("enable_checkpoints", True)
        checkpoint_dir = backend_section.get("checkpoint_dir", "./checkpoints/cmc")

        # Backward compatibility: map legacy "jax" backend name to multiprocessing
        # NOTE: Map to multiprocessing, NOT pjit, because pjit backend is sequential
        # (it processes shards one at a time in a for loop, not in parallel)
        if backend_name == "jax":
            logger.warning(
                "CMC backend 'jax' is deprecated; mapping to 'multiprocessing' for parallel execution. "
                "Set backend_config.name to 'multiprocessing' or 'auto' instead."
            )
            backend_name = "multiprocessing"

        # T3-6: Normalize possibly stringified numbers (handles "10" and "10.0")
        num_shards_val = cls._coerce_count(sharding.get("num_shards", "auto"))
        max_points_val = cls._coerce_count(
            sharding.get("max_points_per_shard", "auto")
        )

        config = cls(
            # Enable settings
            enable=config_dict.get("enable", "auto"),
            min_points_for_cmc=config_dict.get("min_points_for_cmc", 100000),
            # Anti-degeneracy: Per-angle scaling mode (v2.18.0+)
            per_angle_mode=config_dict.get("per_angle_mode", "auto"),
            constant_scaling_threshold=config_dict.get("constant_scaling_threshold", 3),
            # Sharding
            sharding_strategy=sharding.get("strategy", "random"),
            num_shards=num_shards_val,
            max_points_per_shard=max_points_val,
            # Backend
            backend_name=backend_name,
            enable_checkpoints=enable_checkpoints,
            checkpoint_dir=checkpoint_dir,
            # Sampling
            num_warmup=per_shard.get("num_warmup", 500),
            num_samples=per_shard.get("num_samples", 1500),
            num_chains=per_shard.get("num_chains", 4),
            chain_method=per_shard.get("chain_method", "parallel"),
            target_accept_prob=per_shard.get("target_accept_prob", 0.85),
            dense_mass=per_shard.get("dense_mass", True),
            # Adaptive sampling (Feb 2026)
            adaptive_sampling=per_shard.get("adaptive_sampling", True),
            max_tree_depth=per_shard.get("max_tree_depth", 10),
            min_warmup=per_shard.get("min_warmup", 100),
            min_samples=per_shard.get("min_samples", 200),
            # JAX profiling (Feb 2026)
            enable_jax_profiling=per_shard.get("enable_jax_profiling", False),
            jax_profile_dir=per_shard.get("jax_profile_dir", "./profiles/jax"),
            # Validation ("min_ess" preferred; "min_per_shard_ess" is what
            # to_dict() writes)
            max_r_hat=validation.get("max_per_shard_rhat", 1.1),
            min_ess=validation.get(
                "min_ess", validation.get("min_per_shard_ess", 400.0)
            ),
            max_divergence_rate=validation.get("max_divergence_rate", 0.10),
            # Combination
            combination_method=combination.get("method", "robust_consensus_mc"),
            min_success_rate=combination.get("min_success_rate", 0.90),
            run_id=config_dict.get("run_id"),
            # Timeout
            per_shard_timeout=config_dict.get("per_shard_timeout", 3600),
            heartbeat_timeout=config_dict.get("heartbeat_timeout", 600),
            # Warning thresholds
            min_success_rate_warning=combination.get("min_success_rate_warning", 0.80),
            # Warm-start requirements
            require_nlsq_warmstart=validation.get("require_nlsq_warmstart", False),
            # NLSQ-informed priors (Feb 2026)
            use_nlsq_informed_priors=validation.get("use_nlsq_informed_priors", True),
            nlsq_prior_width_factor=validation.get("nlsq_prior_width_factor", 2.0),
            # Prior tempering (Feb 2026)
            prior_tempering=config_dict.get("prior_tempering", True),
            # Heterogeneity detection (Jan 2026 v2)
            max_parameter_cv=validation.get("max_parameter_cv", 1.0),
            heterogeneity_abort=validation.get("heterogeneity_abort", True),
            min_points_per_shard=sharding.get("min_points_per_shard", 10000),
            min_points_per_param=sharding.get("min_points_per_param", 1500),
            # Reparameterization (Jan 2026 v3)
            reparameterization_d_total=reparameterization.get("enable_d_total", True),
            reparameterization_log_gamma=reparameterization.get(
                "enable_log_gamma", True
            ),
            bimodal_min_weight=reparameterization.get("bimodal_min_weight", 0.2),
            bimodal_min_separation=reparameterization.get(
                "bimodal_min_separation", 0.5
            ),
            # Reproducibility
            seed=config_dict.get("seed", 42),
        )

        # Validate and log any issues (warnings only; nothing is raised)
        errors = config.validate()
        if errors:
            for error in errors:
                logger.warning(f"CMC config validation: {error}")

        # Warn about config precedence (CLI overwrites per_shard_mcmc with base mcmc)
        if per_shard:
            logger.debug(
                "Note: CLI applies base mcmc settings to per_shard_mcmc. "
                "If using CLI, ensure base mcmc and per_shard_mcmc are aligned."
            )

        return config

    def validate(self) -> list[str]:
        """Validate configuration values.

        Each failed check appends a message instead of raising, so all
        problems are reported together. Two soft issues are only logged as
        warnings and not counted as errors: a positive but very small
        ``max_points_per_shard``, and ``min_success_rate_warning`` above
        ``min_success_rate``. The result is also stored on
        ``self._validation_errors``.

        Returns
        -------
        list[str]
            List of validation error messages (empty if valid).
        """
        errors: list[str] = []

        # Validate enable
        if self.enable not in [True, False, "auto"]:
            errors.append(f"enable must be True, False, or 'auto', got: {self.enable}")

        # Validate min_points_for_cmc
        if not isinstance(self.min_points_for_cmc, int) or self.min_points_for_cmc < 0:
            errors.append(
                f"min_points_for_cmc must be non-negative int, got: {self.min_points_for_cmc}"
            )

        # Validate per_angle_mode (v2.18.0+)
        valid_per_angle_modes = ["auto", "constant", "constant_averaged", "individual"]
        if self.per_angle_mode not in valid_per_angle_modes:
            errors.append(
                f"per_angle_mode must be one of {valid_per_angle_modes}, "
                f"got: {self.per_angle_mode}"
            )

        # Validate constant_scaling_threshold
        if (
            not isinstance(self.constant_scaling_threshold, int)
            or self.constant_scaling_threshold < 1
        ):
            errors.append(
                f"constant_scaling_threshold must be positive int, "
                f"got: {self.constant_scaling_threshold}"
            )

        # Validate sharding_strategy
        valid_strategies = ["stratified", "random", "contiguous"]
        if self.sharding_strategy not in valid_strategies:
            errors.append(
                f"sharding_strategy must be one of {valid_strategies}, got: {self.sharding_strategy}"
            )

        # Validate num_shards
        if self.num_shards != "auto":
            if not isinstance(self.num_shards, int) or self.num_shards <= 0:
                errors.append(
                    f"num_shards must be 'auto' or positive int, got: {self.num_shards}"
                )

        # Validate max_points_per_shard. from_dict() leaves unparseable
        # strings untouched and relies on this check to report them.
        if self.max_points_per_shard != "auto":
            if (
                not isinstance(self.max_points_per_shard, int)
                or self.max_points_per_shard <= 0
            ):
                errors.append(
                    f"max_points_per_shard must be 'auto' or positive int, "
                    f"got: {self.max_points_per_shard}"
                )
            elif self.max_points_per_shard < 3000:
                # Warn about small max_points_per_shard (creates excessive shards)
                logger.warning(
                    f"max_points_per_shard={self.max_points_per_shard:,} is very small. "
                    "This will create many shards with high overhead. "
                    "Recommended: 3000-10000 for laminar_flow, 50000-100000 for static."
                )

        # Validate backend_name (allow legacy 'jax' but normalize earlier)
        valid_backends = ["auto", "multiprocessing", "pjit", "pbs", "slurm", "jax"]
        if self.backend_name not in valid_backends:
            errors.append(
                f"backend_name must be one of {valid_backends}, got: {self.backend_name}"
            )

        # Validate sampling parameters
        if not isinstance(self.num_warmup, int) or self.num_warmup <= 0:
            errors.append(f"num_warmup must be positive int, got: {self.num_warmup}")
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            errors.append(f"num_samples must be positive int, got: {self.num_samples}")
        if not isinstance(self.num_chains, int) or self.num_chains <= 0:
            errors.append(f"num_chains must be positive int, got: {self.num_chains}")

        # Validate chain_method
        valid_chain_methods = ["parallel", "sequential"]
        if self.chain_method not in valid_chain_methods:
            errors.append(
                f"chain_method must be one of {valid_chain_methods}, "
                f"got: {self.chain_method}"
            )

        # Validate target_accept_prob
        if not 0.0 < self.target_accept_prob < 1.0:
            errors.append(
                f"target_accept_prob must be in (0, 1), got: {self.target_accept_prob}"
            )

        # Validate adaptive sampling settings (Feb 2026)
        if (
            not isinstance(self.max_tree_depth, int)
            or not 1 <= self.max_tree_depth <= 15
        ):
            errors.append(
                f"max_tree_depth must be int in [1, 15], got: {self.max_tree_depth}"
            )
        if not isinstance(self.min_warmup, int) or self.min_warmup < 10:
            errors.append(f"min_warmup must be int >= 10, got: {self.min_warmup}")
        if not isinstance(self.min_samples, int) or self.min_samples < 50:
            errors.append(f"min_samples must be int >= 50, got: {self.min_samples}")

        # Validate convergence thresholds
        if not isinstance(self.max_r_hat, (int, float)) or self.max_r_hat < 1.0:
            errors.append(f"max_r_hat must be >= 1.0, got: {self.max_r_hat}")
        if not isinstance(self.min_ess, (int, float)) or self.min_ess < 0:
            errors.append(f"min_ess must be non-negative, got: {self.min_ess}")
        if not 0.0 <= self.max_divergence_rate <= 1.0:
            errors.append(
                f"max_divergence_rate must be in [0, 1], got: {self.max_divergence_rate}"
            )

        # Validate combination settings
        valid_methods = [
            "consensus_mc",
            "robust_consensus_mc",
            "weighted_gaussian",
            "simple_average",
            "auto",
        ]
        if self.combination_method not in valid_methods:
            errors.append(
                f"combination_method must be one of {valid_methods}, got: {self.combination_method}"
            )
        if not 0.0 <= self.min_success_rate <= 1.0:
            errors.append(
                f"min_success_rate must be in [0, 1], got: {self.min_success_rate}"
            )

        # Validate timeout settings
        if not isinstance(self.heartbeat_timeout, int) or self.heartbeat_timeout < 60:
            errors.append(
                f"heartbeat_timeout must be int >= 60 seconds, got: {self.heartbeat_timeout}"
            )

        # Validate warning threshold
        if not 0.0 <= self.min_success_rate_warning <= 1.0:
            errors.append(
                f"min_success_rate_warning must be in [0, 1], got: {self.min_success_rate_warning}"
            )
        if self.min_success_rate_warning > self.min_success_rate:
            logger.warning(
                f"min_success_rate_warning ({self.min_success_rate_warning}) > "
                f"min_success_rate ({self.min_success_rate}); warning will never trigger"
            )

        # Validate heterogeneity detection settings (Jan 2026 v2)
        if (
            not isinstance(self.max_parameter_cv, (int, float))
            or self.max_parameter_cv <= 0
        ):
            errors.append(
                f"max_parameter_cv must be positive number, got: {self.max_parameter_cv}"
            )
        if not isinstance(
            self.heterogeneity_abort, bool
        ):  # runtime check for YAML input
            errors.append(  # type: ignore[unreachable]
                f"heterogeneity_abort must be bool, got: {self.heterogeneity_abort}"
            )
        if (
            not isinstance(self.min_points_per_shard, int)
            or self.min_points_per_shard < 1000
        ):
            errors.append(
                f"min_points_per_shard must be int >= 1000, got: {self.min_points_per_shard}"
            )

        # Validate bimodal detection thresholds (Jan 2026 v3)
        if not (0.0 < self.bimodal_min_weight <= 0.5):
            errors.append(
                f"bimodal_min_weight must be in (0, 0.5], got: {self.bimodal_min_weight}"
            )
        if not (0.0 < self.bimodal_min_separation <= 2.0):
            errors.append(
                f"bimodal_min_separation must be in (0, 2.0], got: {self.bimodal_min_separation}"
            )

        # Validate seed (reproducibility)
        if not isinstance(self.seed, int) or self.seed < 0:
            errors.append(f"seed must be a non-negative integer, got: {self.seed}")

        # Validate NLSQ-informed priors (Feb 2026)
        if not (1.0 <= self.nlsq_prior_width_factor <= 10.0):
            errors.append(
                f"nlsq_prior_width_factor must be in [1.0, 10.0], got: {self.nlsq_prior_width_factor}"
            )

        self._validation_errors = errors
        return errors

    def is_valid(self) -> bool:
        """Check if configuration is valid.

        Re-runs ``validate()``, which also refreshes ``_validation_errors``
        and may emit the same log warnings again.

        Returns
        -------
        bool
            True if configuration has no validation errors.
        """
        return len(self.validate()) == 0

    def should_enable_cmc(
        self, n_points: int, analysis_mode: str | None = None
    ) -> bool:
        """Determine if CMC should be enabled for given data size.

        Parameters
        ----------
        n_points : int
            Number of data points.
        analysis_mode : str | None
            Deprecated — ignored. Kept for backward compatibility.

        Returns
        -------
        bool
            True if CMC should be enabled.

        Notes
        -----
        Threshold is ``min_points_for_cmc`` (default 100,000) for all modes.
        An explicit ``enable=True``/``False`` overrides the threshold.
        """
        if self.enable is True:
            return True
        if self.enable is False:
            return False
        # "auto" mode
        return n_points >= self.min_points_for_cmc

    def get_num_shards(self, n_points: int, n_phi: int, n_params: int = 7) -> int:
        """Calculate number of shards with param-aware sizing.

        Parameters
        ----------
        n_points : int
            Total number of data points.
        n_phi : int
            Number of phi angles.
        n_params : int
            Number of model parameters (default: 7 for static).

        Returns
        -------
        int
            Number of shards to use. Computed values are always >= 1; an
            explicit integer ``num_shards`` is returned unchanged.
        """
        if isinstance(self.num_shards, int):
            return self.num_shards

        # Auto calculation: stratified by phi angle (at least one shard, so a
        # degenerate n_phi == 0 cannot yield zero shards)
        if self.sharding_strategy == "stratified":
            return max(1, n_phi)

        # For other strategies, calculate based on max_points_per_shard
        if isinstance(self.max_points_per_shard, int):
            base_max = self.max_points_per_shard
        else:
            # Default: ~100k points per shard
            base_max = 100000

        # Param-aware adjustment: scale up for n_params > 7
        param_factor = max(1.0, n_params / 7.0)
        min_required = int(self.min_points_per_param * n_params)
        adjusted_max = max(
            int(base_max * param_factor),
            min_required,
        )

        if param_factor > 1.0:
            logger.debug(
                f"Param-aware shard sizing: {n_params} params detected. "
                f"Adjusted max_points_per_shard: {base_max:,} -> {adjusted_max:,} "
                f"(factor={param_factor:.2f})"
            )

        return max(1, n_points // adjusted_max)

    def get_adaptive_sample_counts(
        self, shard_size: int, n_params: int = 7
    ) -> tuple[int, int]:
        """Calculate adaptive warmup/samples based on shard size.

        Small datasets benefit from fewer NUTS samples because:
        1. JIT compilation overhead is amortized over fewer samples
        2. Step size adaptation converges faster with simple likelihoods
        3. Mass matrix estimation requires fewer warmup iterations

        Profiling showed 1310s for 50 points with 500 warmup + 1500 samples.
        Adaptive scaling reduces this by 60-80% while maintaining statistical
        validity (ESS targets are reduced proportionally).

        Parameters
        ----------
        shard_size : int
            Number of data points in the shard.
        n_params : int
            Number of model parameters (affects minimum samples).

        Returns
        -------
        tuple[int, int]
            (num_warmup, num_samples) adjusted for shard size. Never larger
            than the configured ``num_warmup``/``num_samples``.
        """
        if not self.adaptive_sampling:
            return self.num_warmup, self.num_samples

        # Reference point: 10k points → full samples
        # Scale down for smaller datasets
        reference_size = 10000
        scale_factor = min(1.0, shard_size / reference_size)

        # Compute scaled counts
        scaled_warmup = int(self.num_warmup * scale_factor)
        scaled_samples = int(self.num_samples * scale_factor)

        # Ensure minimum viable sampling (ESS requires ~50 samples per param).
        # P2-B: Cap at configured defaults — adaptive scaling should only reduce,
        # never exceed, the user's configured num_warmup/num_samples.
        min_samples_for_params = min(
            max(self.min_samples, 50 * n_params), self.num_samples
        )
        min_warmup_for_params = min(
            max(self.min_warmup, 20 * n_params), self.num_warmup
        )

        # Apply bounds
        final_warmup = max(min_warmup_for_params, scaled_warmup)
        final_samples = max(min_samples_for_params, scaled_samples)

        # Log if different from defaults
        if final_warmup != self.num_warmup or final_samples != self.num_samples:
            logger.debug(
                f"Adaptive sampling: {shard_size:,} points, {n_params} params -> "
                f"warmup={final_warmup} (was {self.num_warmup}), "
                f"samples={final_samples} (was {self.num_samples})"
            )

        return final_warmup, final_samples

    def get_effective_per_angle_mode(
        self,
        n_phi: int,
        nlsq_per_angle_mode: str | None = None,
        has_nlsq_warmstart: bool = False,
    ) -> str:
        """Determine effective per-angle mode based on configuration and data.

        Parameters
        ----------
        n_phi : int
            Number of phi angles in the dataset.
        nlsq_per_angle_mode : str | None
            Optional per-angle mode from NLSQ result. When provided (from
            warm-start), CMC will use this mode to ensure parameterization
            parity with NLSQ. This prevents CMC vs NLSQ divergence from
            different model structures.
        has_nlsq_warmstart : bool
            Whether an NLSQ warm-start result is available. When True and
            both CMC and NLSQ use "auto" mode, upgrades to
            "constant_averaged" for fewer sampled parameters and better
            stability.

        Returns
        -------
        str
            Effective mode: "auto", "constant", "constant_averaged", or
            "individual".

        Notes
        -----
        Mode semantics (same as NLSQ):
        - auto: Sample single averaged contrast/offset (10 params for
          laminar_flow). Only activated when n_phi >= threshold (many angles).
        - constant: Use FIXED per-angle values from quantile estimation
          (8 params).
        - constant_averaged: Use FIXED averaged scaling for NLSQ parity.
        - individual: Sample per-angle contrast/offset (n_phi*2 + 7 + 1
          params).

        Priority: nlsq_per_angle_mode > explicit config > auto-selection

        When NLSQ warm-start is present and both sides use "auto", upgrades
        to "constant_averaged" to fix scaling values and reduce parameter
        count. This prevents contrast/offset sampling from absorbing physical
        parameter signal, which was the root cause of heterogeneous shard
        posteriors.
        """
        # Jan 2026 v2: When NLSQ warm-start provides per-angle mode, match it
        # This ensures CMC and NLSQ use identical parameterizations
        if nlsq_per_angle_mode is not None:
            # Feb 2026: When NLSQ warm-start present and both sides use "auto",
            # upgrade to constant_averaged for fewer params and better stability
            if (
                has_nlsq_warmstart
                and nlsq_per_angle_mode == "auto"
                and self.per_angle_mode == "auto"
            ):
                logger.info(
                    "CMC per-angle mode: auto -> constant_averaged "
                    "(NLSQ warm-start present, fixing scaling for stability)"
                )
                return "constant_averaged"
            logger.info(
                f"CMC per-angle mode: Using NLSQ warm-start mode '{nlsq_per_angle_mode}' "
                f"for parameterization parity"
            )
            return nlsq_per_angle_mode

        if self.per_angle_mode == "auto":
            if n_phi >= self.constant_scaling_threshold:
                # Return "auto" - this uses the xpcs_model_averaged which samples
                # single averaged contrast/offset (10 params for laminar_flow)
                logger.info(
                    f"CMC anti-degeneracy: Using 'auto' mode (sampled averaged scaling) "
                    f"(n_phi={n_phi} >= threshold={self.constant_scaling_threshold})"
                )
                return "auto"
            else:
                # Few angles - use individual per-angle sampling
                logger.info(
                    f"CMC anti-degeneracy: Auto-selected 'individual' mode "
                    f"(n_phi={n_phi} < threshold={self.constant_scaling_threshold})"
                )
                return "individual"
        else:
            # Explicit mode (constant, constant_averaged, or individual)
            return self.per_angle_mode

    def to_dict(self) -> dict[str, Any]:
        """Convert configuration to dictionary.

        The layout mirrors the sections read by ``from_dict`` so that
        ``CMCConfig.from_dict(cfg.to_dict())`` reproduces every public field.

        Returns
        -------
        dict
            Configuration as dictionary.
        """
        return {
            "enable": self.enable,
            "min_points_for_cmc": self.min_points_for_cmc,
            "per_angle_mode": self.per_angle_mode,
            "constant_scaling_threshold": self.constant_scaling_threshold,
            "run_id": self.run_id,
            "sharding": {
                "strategy": self.sharding_strategy,
                "num_shards": self.num_shards,
                "max_points_per_shard": self.max_points_per_shard,
                "min_points_per_shard": self.min_points_per_shard,
                "min_points_per_param": self.min_points_per_param,
            },
            "backend_config": {
                "name": self.backend_name,
                "enable_checkpoints": self.enable_checkpoints,
                "checkpoint_dir": self.checkpoint_dir,
            },
            "per_shard_mcmc": {
                "num_warmup": self.num_warmup,
                "num_samples": self.num_samples,
                "num_chains": self.num_chains,
                "chain_method": self.chain_method,
                "target_accept_prob": self.target_accept_prob,
                # Previously omitted, so dense_mass=False was lost on round trip
                "dense_mass": self.dense_mass,
                "adaptive_sampling": self.adaptive_sampling,
                "max_tree_depth": self.max_tree_depth,
                "min_warmup": self.min_warmup,
                "min_samples": self.min_samples,
                "enable_jax_profiling": self.enable_jax_profiling,
                "jax_profile_dir": self.jax_profile_dir,
            },
            "validation": {
                "max_per_shard_rhat": self.max_r_hat,
                "min_per_shard_ess": self.min_ess,
                "max_divergence_rate": self.max_divergence_rate,
                "require_nlsq_warmstart": self.require_nlsq_warmstart,
                "use_nlsq_informed_priors": self.use_nlsq_informed_priors,
                "nlsq_prior_width_factor": self.nlsq_prior_width_factor,
                "max_parameter_cv": self.max_parameter_cv,
                "heterogeneity_abort": self.heterogeneity_abort,
            },
            "combination": {
                "method": self.combination_method,
                "min_success_rate": self.min_success_rate,
                "min_success_rate_warning": self.min_success_rate_warning,
            },
            "prior_tempering": self.prior_tempering,
            "per_shard_timeout": self.per_shard_timeout,
            "heartbeat_timeout": self.heartbeat_timeout,
            "reparameterization": {
                "enable_d_total": self.reparameterization_d_total,
                "enable_log_gamma": self.reparameterization_log_gamma,
                "bimodal_min_weight": self.bimodal_min_weight,
                "bimodal_min_separation": self.bimodal_min_separation,
            },
            "seed": self.seed,
        }