Source code for homodyne.optimization.cmc.priors

"""Prior distribution builders for CMC analysis.

This module provides utilities for building NumPyro prior distributions
from the ParameterSpace configuration.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import numpyro.distributions as dist

from homodyne.core.scaling_utils import (
    estimate_per_angle_scaling as _estimate_per_angle_scaling_canonical,
)
from homodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from homodyne.config.parameter_space import ParameterSpace, PriorDistribution

logger = get_logger(__name__)


# =============================================================================
# DATA-DRIVEN INITIAL VALUE ESTIMATION
# =============================================================================


def estimate_contrast_offset_from_data(
    c2_data: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    contrast_bounds: tuple[float, float] = (0.0, 1.0),
    offset_bounds: tuple[float, float] = (0.5, 1.5),
    lag_floor_quantile: float = 0.80,
    lag_ceiling_quantile: float = 0.20,
    value_quantile_low: float = 0.10,
    value_quantile_high: float = 0.90,
) -> tuple[float, float]:
    """Estimate contrast and offset from C2 data using physics-informed quantile analysis.

    Uses the correlation decay structure: C2 = contrast × g1² + offset

    - At large time lags, g1² → 0, so C2 → offset (the "floor")
    - At small time lags, g1² ≈ 1, so C2 ≈ contrast + offset (the "ceiling")

    Parameters
    ----------
    c2_data : np.ndarray
        C2 correlation values. The input is flattened before use.
    t1 : np.ndarray
        First time coordinate array (same shape as c2_data).
    t2 : np.ndarray
        Second time coordinate array (same shape as c2_data).
    contrast_bounds : tuple[float, float]
        Valid bounds for contrast parameter.
    offset_bounds : tuple[float, float]
        Valid bounds for offset parameter.
    lag_floor_quantile : float
        Quantile threshold for "large lag" region (default: 0.80 = top 20% of lags).
    lag_ceiling_quantile : float
        Quantile threshold for "small lag" region (default: 0.20 = bottom 20% of lags).
    value_quantile_low : float
        Quantile for robust floor estimation (default: 0.10).
    value_quantile_high : float
        Quantile for robust ceiling estimation (default: 0.90).

    Returns
    -------
    tuple[float, float]
        (contrast_est, offset_est) - Estimated values clipped to bounds.
        When fewer than 100 usable points are available the midpoints of
        the respective bounds are returned instead.

    Notes
    -----
    The estimation is robust to outliers by using quantiles instead of
    min/max. The lag-based segmentation ensures we're sampling from the
    appropriate regions of the correlation decay curve.

    Samples whose C2 value or time lag is NaN/Inf are discarded before any
    quantile is taken. ``np.percentile`` propagates NaN, so a single bad
    sample would otherwise turn both estimates into NaN and defeat the
    clipping to bounds.
    """
    min_points = 100  # below this the quantiles are too noisy to trust
    min_region_points = 10  # minimum samples required inside a lag region

    c2 = np.asarray(c2_data, dtype=float).ravel()
    delta_t = np.abs(
        np.asarray(t1, dtype=float) - np.asarray(t2, dtype=float)
    ).ravel()

    # Drop non-finite samples (only worth doing when we have enough data to
    # attempt an estimate at all).
    if c2.size >= min_points:
        finite = np.isfinite(c2) & np.isfinite(delta_t)
        if not finite.all():
            c2 = c2[finite]
            delta_t = delta_t[finite]

    if c2.size < min_points:
        # Not enough data for robust estimation - return midpoints
        contrast_mid = (contrast_bounds[0] + contrast_bounds[1]) / 2.0
        offset_mid = (offset_bounds[0] + offset_bounds[1]) / 2.0
        logger.debug(
            f"Insufficient data ({len(c2)} points) for quantile estimation, "
            f"using midpoint defaults: contrast={contrast_mid:.3f}, offset={offset_mid:.3f}"
        )
        return contrast_mid, offset_mid

    # Lag thresholds that delimit the "floor" and "ceiling" regions
    lag_threshold_high = np.percentile(delta_t, lag_floor_quantile * 100)
    lag_threshold_low = np.percentile(delta_t, lag_ceiling_quantile * 100)

    # OFFSET: large-lag region where g1² ≈ 0, so C2 → offset.
    # A low quantile guards against noise spikes; if the region is too
    # sparse, fall back to the low quantile of the whole data set.
    large_lag_mask = delta_t >= lag_threshold_high
    if np.count_nonzero(large_lag_mask) >= min_region_points:
        floor_sample = c2[large_lag_mask]
    else:
        floor_sample = c2
    offset_est = float(
        np.clip(
            np.percentile(floor_sample, value_quantile_low * 100),
            offset_bounds[0],
            offset_bounds[1],
        )
    )

    # CONTRAST: small-lag region where g1² ≈ 1, so C2 ≈ contrast + offset.
    small_lag_mask = delta_t <= lag_threshold_low
    if np.count_nonzero(small_lag_mask) >= min_region_points:
        ceiling_sample = c2[small_lag_mask]
    else:
        ceiling_sample = c2
    c2_ceiling = np.percentile(ceiling_sample, value_quantile_high * 100)

    # contrast ≈ ceiling - offset (using the already-clipped offset)
    contrast_est = float(
        np.clip(c2_ceiling - offset_est, contrast_bounds[0], contrast_bounds[1])
    )

    logger.debug(
        f"Quantile-based estimation: offset={offset_est:.4f} (from large-lag floor), "
        f"contrast={contrast_est:.4f} (from small-lag ceiling - floor)"
    )
    return contrast_est, offset_est
def estimate_per_angle_scaling(
    c2_data: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    phi_indices: np.ndarray,
    n_phi: int,
    contrast_bounds: tuple[float, float],
    offset_bounds: tuple[float, float],
) -> dict[str, float]:
    """Estimate contrast and offset initial values for each phi angle.

    Backward-compatibility shim: every argument is forwarded unchanged to
    the canonical implementation imported from
    ``homodyne.core.scaling_utils``, together with this module's logger.

    Parameters
    ----------
    c2_data : np.ndarray
        Pooled C2 correlation values.
    t1 : np.ndarray
        Pooled first time coordinates.
    t2 : np.ndarray
        Pooled second time coordinates.
    phi_indices : np.ndarray
        Index mapping each data point to its phi angle (0 to n_phi-1).
    n_phi : int
        Number of unique phi angles.
    contrast_bounds : tuple[float, float]
        Valid bounds for contrast.
    offset_bounds : tuple[float, float]
        Valid bounds for offset.

    Returns
    -------
    dict[str, float]
        Whatever the canonical implementation returns; callers in this
        module read keys of the form 'contrast_0', 'offset_0',
        'contrast_1', 'offset_1', etc.
    """
    forwarded = {
        "c2_data": c2_data,
        "t1": t1,
        "t2": t2,
        "phi_indices": phi_indices,
        "n_phi": n_phi,
        "contrast_bounds": contrast_bounds,
        "offset_bounds": offset_bounds,
    }
    return _estimate_per_angle_scaling_canonical(**forwarded, log=logger)
# Physical parameter names in canonical order.
# The laminar-flow list is the static list followed by the shear parameters.
STATIC_PARAMS = ["D0", "alpha", "D_offset"]
LAMINAR_PARAMS = [
    *STATIC_PARAMS,
    "gamma_dot_t0",
    "beta",
    "gamma_dot_t_offset",
    "phi0",
]
def build_prior_from_spec(
    prior_spec: PriorDistribution,
) -> dist.Distribution:
    """Translate a ``PriorDistribution`` specification into a NumPyro distribution.

    The distribution family is selected by ``prior_spec.dist_type``
    (case-insensitive). Supported values are ``truncatednormal``,
    ``uniform``, ``lognormal``, ``halfnormal``, ``normal`` and
    ``betascaled``.

    Parameters
    ----------
    prior_spec : PriorDistribution
        Prior specification from ParameterSpace. Depending on the family,
        its ``mu``, ``sigma``, ``min_val`` and ``max_val`` attributes are
        read; ``betascaled`` additionally reads optional ``alpha`` and
        ``beta`` attributes.

    Returns
    -------
    dist.Distribution
        NumPyro distribution object.

    Raises
    ------
    ValueError
        If distribution type is not supported.
    """
    kind = prior_spec.dist_type.lower()

    if kind == "truncatednormal":
        return dist.TruncatedNormal(
            loc=prior_spec.mu,
            scale=prior_spec.sigma,
            low=prior_spec.min_val,
            high=prior_spec.max_val,
        )

    if kind == "uniform":
        return dist.Uniform(low=prior_spec.min_val, high=prior_spec.max_val)

    if kind == "lognormal":
        return dist.LogNormal(loc=prior_spec.mu, scale=prior_spec.sigma)

    if kind == "halfnormal":
        return dist.HalfNormal(scale=prior_spec.sigma)

    if kind == "normal":
        return dist.Normal(loc=prior_spec.mu, scale=prior_spec.sigma)

    if kind == "betascaled":
        # Beta on [0, 1], then shifted/stretched onto [min_val, max_val].
        # Missing shape attributes default to the symmetric Beta(2, 2).
        shape_a = getattr(prior_spec, "alpha", 2.0)
        shape_b = getattr(prior_spec, "beta", 2.0)
        unit_beta = dist.Beta(concentration1=shape_a, concentration0=shape_b)
        rescale = dist.transforms.AffineTransform(
            loc=prior_spec.min_val,
            scale=prior_spec.max_val - prior_spec.min_val,
        )
        return dist.TransformedDistribution(unit_beta, rescale)

    raise ValueError(f"Unsupported distribution type: {kind}")
def _get_base_param_name(param_name: str) -> str:
    """Strip the per-angle suffix from a scaling parameter name.

    Each name in ``ParameterRegistry().scaling_names`` is tried in order;
    the first one for which ``param_name`` starts with ``"<name>_"`` is
    returned (e.g. ``'contrast_0'`` -> ``'contrast'``). Any other name is
    returned unchanged.

    Parameters
    ----------
    param_name : str
        Parameter name (possibly with angle suffix).

    Returns
    -------
    str
        Base parameter name.
    """
    from homodyne.config.parameter_registry import ParameterRegistry

    scaling_names = ParameterRegistry().scaling_names
    matches = (
        candidate
        for candidate in scaling_names
        if param_name.startswith(f"{candidate}_")
    )
    return next(matches, param_name)
def build_prior(
    param_name: str,
    parameter_space: ParameterSpace,
) -> dist.Distribution:
    """Build the NumPyro prior used to sample ``param_name``.

    Per-angle names are first reduced to their base name
    (``contrast_0`` -> ``contrast``). The prior registered for that base
    name is then converted with ``build_prior_from_spec``. If the lookup
    or the conversion raises ``KeyError`` or ``AttributeError``, a
    ``Uniform`` prior over the parameter's bounds is returned instead.

    Parameters
    ----------
    param_name : str
        Parameter name (e.g., "D0", "alpha", "contrast", "contrast_0").
    parameter_space : ParameterSpace
        Parameter space with bounds and priors.

    Returns
    -------
    dist.Distribution
        NumPyro distribution for sampling.
    """
    lookup_name = _get_base_param_name(param_name)

    try:
        return build_prior_from_spec(parameter_space.get_prior(lookup_name))
    except (KeyError, AttributeError):
        # No usable prior specification: fall back to a flat prior on the bounds.
        limits = parameter_space.get_bounds(lookup_name)
        lower = limits[0]
        upper = limits[1]
        logger.debug(
            f"No prior spec for {param_name}, using Uniform({lower}, {upper})"
        )
        return dist.Uniform(low=lower, high=upper)
def get_init_value(
    param_name: str,
    initial_values: dict[str, float] | None,
    parameter_space: ParameterSpace,
) -> float:
    """Resolve the starting value for a single parameter.

    Lookup order:

    1. ``initial_values[param_name]`` (exact match)
    2. ``initial_values[<base name>]`` (e.g. 'contrast' for 'contrast_0')
    3. Midpoint of the bounds registered for the base name

    Parameters
    ----------
    param_name : str
        Parameter name.
    initial_values : dict[str, float] | None
        Initial values from config.
    parameter_space : ParameterSpace
        Parameter space with bounds.

    Returns
    -------
    float
        Initial value for the parameter.

    Notes
    -----
    Per-angle parameter handling (scalar broadcast):

    If only 'contrast' is present in ``initial_values`` (and not
    'contrast_0', 'contrast_1', ...), that single value is returned for
    every per-angle contrast name, i.e. it is broadcast to all phi angles.
    To use different values per angle, supply explicit keys such as
    ``{'contrast_0': 0.4, 'contrast_1': 0.5, 'contrast_2': 0.45}``.
    The same applies to 'offset' parameters.

    Examples
    --------
    >>> # Scalar broadcast: same value for all angles
    >>> initial_values = {'contrast': 0.5, 'offset': 1.0}
    >>> get_init_value('contrast_0', initial_values, param_space)  # Returns 0.5
    >>> get_init_value('contrast_1', initial_values, param_space)  # Returns 0.5

    >>> # Explicit per-angle values
    >>> initial_values = {'contrast_0': 0.4, 'contrast_1': 0.6}
    >>> get_init_value('contrast_0', initial_values, param_space)  # Returns 0.4
    >>> get_init_value('contrast_1', initial_values, param_space)  # Returns 0.6
    """
    supplied = {} if initial_values is None else initial_values

    # 1. Exact name wins.
    if param_name in supplied:
        return float(supplied[param_name])

    # 2. Base name (scalar broadcast for per-angle parameters).
    base = _get_base_param_name(param_name)
    if base in supplied:
        return float(supplied[base])

    # 3. Nothing supplied: centre of the allowed range.
    limits = parameter_space.get_bounds(base)
    return (limits[0] + limits[1]) / 2.0
def validate_initial_value_bounds(
    param_name: str,
    value: float,
    parameter_space: ParameterSpace,
) -> tuple[float, bool]:
    """Check an initial value against its parameter bounds, repairing it if needed.

    Bounds are looked up under the parameter's base name. Three repairs are
    possible, each logged as a warning:

    - non-finite value -> midpoint of the bounds
    - value below the lower bound -> lower bound plus 1% of the range
    - value above the upper bound -> upper bound minus 1% of the range

    Parameters
    ----------
    param_name : str
        Parameter name.
    value : float
        Initial value to validate.
    parameter_space : ParameterSpace
        Parameter space with bounds.

    Returns
    -------
    tuple[float, bool]
        (validated_value, was_clipped) - The value (clipped if needed) and
        whether clipping occurred.
    """
    import math

    limits = parameter_space.get_bounds(_get_base_param_name(param_name))
    lower = limits[0]
    upper = limits[1]

    # NaN/Inf cannot be clipped meaningfully; restart from the centre.
    if not math.isfinite(value):
        midpoint = (lower + upper) / 2.0
        logger.warning(
            f"Initial value for '{param_name}' is {value} (non-finite), "
            f"resetting to midpoint {midpoint:.4g}"
        )
        return midpoint, True

    too_low = value < lower
    too_high = (not too_low) and value > upper
    if not (too_low or too_high):
        return value, False

    # Pull the value slightly inside the interval rather than onto the edge.
    margin = 0.01 * (upper - lower)
    if too_low:
        logger.warning(
            f"Initial value for '{param_name}' ({value:.4g}) is below lower bound ({lower:.4g}), "
            f"clipping to lower bound + 1% margin"
        )
        return lower + margin, True

    logger.warning(
        f"Initial value for '{param_name}' ({value:.4g}) is above upper bound ({upper:.4g}), "
        f"clipping to upper bound - 1% margin"
    )
    return upper - margin, True
def _bounds_midpoint(parameter_space: ParameterSpace, base_name: str) -> float:
    """Return the midpoint of the bounds registered for ``base_name``."""
    bounds = parameter_space.get_bounds(base_name)
    return (bounds[0] + bounds[1]) / 2.0


def build_init_values_dict(
    n_phi: int,
    analysis_mode: str,
    initial_values: dict[str, float] | None,
    parameter_space: ParameterSpace,
    *,
    c2_data: np.ndarray | None = None,
    t1: np.ndarray | None = None,
    t2: np.ndarray | None = None,
    phi_indices: np.ndarray | None = None,
    per_angle_mode: str = "individual",
) -> dict[str, float]:
    """Build complete initial values dictionary in sampling order.

    CRITICAL: Parameter order must match NumPyro model sampling order:

    1. Scaling parameters, depending on ``per_angle_mode``:

       - "auto": ``contrast`` then ``offset`` (single averaged values)
       - "constant" / "constant_averaged": none (contrast/offset are not
         sampled, so nothing is initialized for them)
       - any other value (default "individual"): ``contrast_0`` ...
         ``contrast_{n_phi-1}`` followed by ``offset_0`` ...
         ``offset_{n_phi-1}``

    2. Physical parameters in canonical order.

    Parameters
    ----------
    n_phi : int
        Number of phi angles.
    analysis_mode : str
        Analysis mode ("static" or "laminar_flow").
    initial_values : dict[str, float] | None
        Initial values from config. Supports both scalar (broadcast) and
        per-angle specifications for contrast/offset. See Notes for details.
    parameter_space : ParameterSpace
        Parameter space with bounds.
    c2_data : np.ndarray | None
        Optional C2 correlation data for quantile-based estimation of
        contrast/offset.
    t1 : np.ndarray | None
        Optional time coordinates (required if c2_data provided).
    t2 : np.ndarray | None
        Optional time coordinates (required if c2_data provided).
    phi_indices : np.ndarray | None
        Optional phi angle indices for per-angle estimation.
    per_angle_mode : str
        Per-angle scaling mode: "individual", "auto", "constant" or
        "constant_averaged".

    Returns
    -------
    dict[str, float]
        Initial values dictionary in sampling order.

    Raises
    ------
    ValueError
        If the keys produced do not match ``get_param_names_in_order`` for
        the same arguments (raised by ``validate_init_values_order``).

    Notes
    -----
    Sources for contrast/offset starting values (individual mode):

    1. **Explicit per-angle**: indexed names such as 'contrast_0' in
       ``initial_values``.
       Example: ``{'contrast_0': 0.4, 'contrast_1': 0.6}``
    2. **Scalar broadcast**: base names 'contrast' / 'offset' in
       ``initial_values`` are applied to ALL phi angles.
       Example: ``{'contrast': 0.5}`` → contrast_0=0.5, contrast_1=0.5, ...
    3. **Data-driven estimation**: if c2_data, t1, t2 and phi_indices are
       provided (at least 100 points) and contrast or offset is absent from
       ``initial_values``, quantile-based per-angle estimates are used.
    4. **Midpoint** of the parameter bounds.

    Priority: explicit per-angle > scalar broadcast > data-driven > midpoint

    In "auto" mode the averaged value comes from the base name in
    ``initial_values``, else the mean of the per-angle data-driven
    estimates, else the midpoint.

    Data-driven estimation is skipped entirely in the constant modes,
    because no scaling parameter is initialized there and the estimates
    would never be read.

    Bounds validation: all initial values are validated against parameter
    bounds. Out-of-bounds values are clipped to bounds ± 1% margin with a
    warning logged.
    """
    init_dict: dict[str, float] = {}
    clipped_params: list[str] = []

    def _store(name: str, raw_value: float) -> float:
        """Validate ``raw_value``, record it under ``name`` and return it."""
        validated, was_clipped = validate_initial_value_bounds(
            name, raw_value, parameter_space
        )
        init_dict[name] = validated
        if was_clipped:
            clipped_params.append(name)
        return validated

    def _user_supplied(base_name: str) -> bool:
        """Whether ``initial_values`` holds ``base_name`` or any per-angle variant."""
        return initial_values is not None and (
            base_name in initial_values
            or any(_get_base_param_name(k) == base_name for k in initial_values)
        )

    physical_params = (
        LAMINAR_PARAMS if analysis_mode == "laminar_flow" else STATIC_PARAMS
    )
    scaling_is_fixed = per_angle_mode in ("constant", "constant_averaged")

    # Data-driven estimates are only worth computing when scaling parameters
    # are actually initialized, the data arrays are all present and large
    # enough, and the user has not already supplied both contrast and offset.
    data_estimates: dict[str, float] = {}
    if (
        not scaling_is_fixed
        and c2_data is not None
        and t1 is not None
        and t2 is not None
        and phi_indices is not None
        and len(c2_data) >= 100
        and not (_user_supplied("contrast") and _user_supplied("offset"))
    ):
        data_estimates = estimate_per_angle_scaling(
            c2_data=c2_data,
            t1=t1,
            t2=t2,
            phi_indices=phi_indices,
            n_phi=n_phi,
            contrast_bounds=parameter_space.get_bounds("contrast"),
            offset_bounds=parameter_space.get_bounds("offset"),
        )
        logger.info(
            f"Using data-driven quantile estimation for contrast/offset "
            f"(n_phi={n_phi}, n_data={len(c2_data):,})"
        )

    if scaling_is_fixed:
        # CONSTANT/CONSTANT_AVERAGED MODE: contrast/offset are not sampled,
        # so only the physical parameters need initialization.
        logger.info(
            f"{per_angle_mode.upper()} mode: contrast/offset are FIXED (not sampled). "
            f"Only initializing {len(physical_params)} physical parameters."
        )
    elif per_angle_mode == "auto":
        # AUTO MODE: a single averaged "contrast" and "offset" are sampled.
        # The second tuple item is the per-angle default used when an angle
        # is missing from the data-driven estimates.
        averaged: dict[str, float] = {}
        for base_name, missing_default in (("contrast", 0.5), ("offset", 1.0)):
            if initial_values is not None and base_name in initial_values:
                raw_value = float(initial_values[base_name])
            elif data_estimates:
                raw_value = float(
                    np.mean(
                        [
                            data_estimates.get(f"{base_name}_{i}", missing_default)
                            for i in range(n_phi)
                        ]
                    )
                )
            else:
                raw_value = _bounds_midpoint(parameter_space, base_name)
            averaged[base_name] = _store(base_name, raw_value)

        logger.info(
            f"Auto mode: initializing SAMPLED averaged contrast={averaged['contrast']:.4f}, "
            f"offset={averaged['offset']:.4f}"
        )
    else:
        # INDIVIDUAL MODE: all contrast_i first, then all offset_i.
        for base_name in ("contrast", "offset"):
            for i in range(n_phi):
                param_name = f"{base_name}_{i}"
                # Priority: explicit > broadcast > data-driven > midpoint
                if initial_values is not None and param_name in initial_values:
                    raw_value = float(initial_values[param_name])
                elif initial_values is not None and base_name in initial_values:
                    raw_value = float(initial_values[base_name])
                elif param_name in data_estimates:
                    raw_value = data_estimates[param_name]
                else:
                    raw_value = _bounds_midpoint(parameter_space, base_name)
                _store(param_name, raw_value)

    # Physical parameters come last, in canonical order.
    for param_name in physical_params:
        _store(param_name, get_init_value(param_name, initial_values, parameter_space))

    if clipped_params:
        logger.warning(
            f"{len(clipped_params)} initial values were outside bounds and clipped: "
            f"{clipped_params}. This may indicate NLSQ fit issues or mismatched bounds."
        )

    logger.debug(
        f"Built init values for {len(init_dict)} params: {list(init_dict.keys())}"
    )

    # Defensive validation: the keys must match the expected sampling order,
    # which catches parameter-ordering bugs early.
    expected_names = get_param_names_in_order(n_phi, analysis_mode, per_angle_mode)
    validate_init_values_order(init_dict, expected_names)

    return init_dict
def get_param_names_in_order(
    n_phi: int, analysis_mode: str, per_angle_mode: str = "individual"
) -> list[str]:
    """Return the parameter names in NumPyro sampling order.

    CRITICAL: This order must match the model sampling order exactly.

    Parameters
    ----------
    n_phi : int
        Number of phi angles.
    analysis_mode : str
        Analysis mode ("static" or "laminar_flow").
    per_angle_mode : str
        Per-angle scaling mode: "individual", "auto", or "constant".

    Returns
    -------
    list[str]
        Scaling parameter names (if any) followed by the physical
        parameter names. A new list is returned on every call.

    Notes
    -----
    Mode semantics (same as NLSQ):

    - individual mode: per-angle contrast then per-angle offset
      (2*n_phi names)
    - auto mode: single averaged ``contrast`` and ``offset`` (2 names)
    - any other mode (e.g. constant / constant_averaged): no scaling
      names, because contrast/offset are not sampled
    """
    if per_angle_mode == "auto":
        scaling_names = ["contrast", "offset"]
    elif per_angle_mode == "individual":
        contrast_names = [f"contrast_{i}" for i in range(n_phi)]
        offset_names = [f"offset_{i}" for i in range(n_phi)]
        scaling_names = contrast_names + offset_names
    else:
        scaling_names = []

    physical_names = (
        LAMINAR_PARAMS if analysis_mode == "laminar_flow" else STATIC_PARAMS
    )
    return [*scaling_names, *physical_names]
[docs] def validate_init_values_order( init_values: dict[str, float], expected_names: list[str], ) -> None: """Validate that init values dictionary keys match expected order. This is a defensive check to catch parameter ordering bugs early. In Python 3.7+, dict preserves insertion order, so key order matters for functions that assume positional correspondence. Parameters ---------- init_values : dict[str, float] Initial values dictionary. expected_names : list[str] Expected parameter names in order. Raises ------ ValueError If parameter order doesn't match. """ actual_names = list(init_values.keys()) if actual_names != expected_names: # P2-R5-04: Check length BEFORE per-element comparison. # Previously used zip(..., strict=True) which raises a generic # "zip() has arguments with different lengths" error on length mismatch, # obscuring the informative message below. Check length first. if len(actual_names) != len(expected_names): raise ValueError( f"Parameter count mismatch!\n" f"Expected {len(expected_names)} params: {expected_names}\n" f"Actual {len(actual_names)} params: {actual_names}" ) # Same length: find first positional mismatch for i, (actual, expected) in enumerate( zip(actual_names, expected_names, strict=False) ): if actual != expected: raise ValueError( f"Parameter order mismatch at position {i}!\n" f"Expected: {expected}\n" f"Actual: {actual}\n" f"Full expected: {expected_names}\n" f"Full actual: {actual_names}" )
# =============================================================================
# NLSQ WARM-START PRIORS (Jan 2026)
# =============================================================================
def build_nlsq_informed_prior(
    param_name: str,
    nlsq_value: float,
    nlsq_std: float | None,
    bounds: tuple[float, float],
    width_factor: float = 2.0,
) -> dist.Distribution:
    """Build a TruncatedNormal prior centered on NLSQ estimate.

    This provides informative priors for CMC that leverage NLSQ's point
    estimates. The resulting priors:

    1. Center at the NLSQ estimate (faster warmup, better mixing)
    2. Have width based on NLSQ uncertainty or parameter range
    3. Are truncated to respect parameter bounds
    4. Enable posterior contraction metrics (comparing prior vs posterior width)

    Parameters
    ----------
    param_name : str
        Parameter name for logging.
    nlsq_value : float
        NLSQ point estimate (mean of the prior). If it is NaN or infinite,
        the midpoint of ``bounds`` is used instead, ``nlsq_std`` is ignored
        and a warning is logged.
    nlsq_std : float | None
        NLSQ standard error estimate. If None or not strictly positive,
        10% of the bounds range is used.
    bounds : tuple[float, float]
        Parameter bounds (low, high).
    width_factor : float
        Multiplier for NLSQ std to get prior width. Default 2.0 gives ~95%
        coverage assuming Gaussian posterior.

    Returns
    -------
    dist.Distribution
        TruncatedNormal distribution on ``[low, high]`` whose scale is
        kept between 1% and 50% of the bounds range.
    """
    low, high = bounds
    span = high - low

    center = nlsq_value
    std_hint = nlsq_std
    if not np.isfinite(nlsq_value):
        # A NaN/Inf location would make every log-density evaluation
        # meaningless. Fall back to an uninformed centre, and distrust the
        # accompanying uncertainty as well.
        center = (low + high) / 2.0
        std_hint = None
        logger.warning(
            f"NLSQ estimate for '{param_name}' is {nlsq_value} (non-finite), "
            f"centering prior at bounds midpoint {center:.4g} with default width"
        )

    if std_hint is not None and std_hint > 0:
        # Use NLSQ uncertainty scaled by width_factor
        prior_std = std_hint * width_factor
    else:
        # Fall back to 10% of range (weak informative prior)
        prior_std = span * 0.1

    # Keep the scale reasonable: at least 1% and at most 50% of the range.
    prior_std = np.clip(prior_std, span * 0.01, span * 0.5)

    logger.debug(
        f"NLSQ-informed prior for {param_name}: "
        f"TruncatedNormal(loc={center:.4g}, scale={prior_std:.4g}, "
        f"bounds=[{low:.4g}, {high:.4g}])"
    )

    return dist.TruncatedNormal(
        loc=center,
        scale=prior_std,
        low=low,
        high=high,
    )
def build_nlsq_informed_priors(
    nlsq_result: dict[str, float],
    nlsq_uncertainties: dict[str, float] | None,
    parameter_space: ParameterSpace,
    analysis_mode: str,
    n_phi: int,
    width_factor: float = 2.0,
) -> dict[str, dist.Distribution]:
    """Build NLSQ-centred priors for every physical parameter present in the fit.

    Only the physical parameters of the selected analysis mode are
    considered; per-angle scaling parameters (contrast/offset) are not.
    A physical parameter missing from ``nlsq_result`` is skipped with a
    warning and does not appear in the returned dictionary.

    Parameters
    ----------
    nlsq_result : dict[str, float]
        NLSQ parameter estimates (e.g., {"D0": 1e10, "alpha": -0.5, ...}).
    nlsq_uncertainties : dict[str, float] | None
        NLSQ standard errors for each parameter. If None or empty, every
        prior falls back to the default width.
    parameter_space : ParameterSpace
        Parameter space with bounds.
    analysis_mode : str
        Analysis mode: "static" or "laminar_flow".
    n_phi : int
        Number of phi angles. Not used by this function.
    width_factor : float
        Width multiplier for priors. Default 2.0.

    Returns
    -------
    dict[str, dist.Distribution]
        Dictionary of informative priors keyed by parameter name.
    """
    physical_names = (
        LAMINAR_PARAMS if analysis_mode == "laminar_flow" else STATIC_PARAMS
    )

    priors: dict[str, dist.Distribution] = {}
    for param_name in physical_names:
        if param_name not in nlsq_result:
            logger.warning(
                f"Parameter {param_name} not found in NLSQ result, using default prior"
            )
            continue

        std_estimate = (
            nlsq_uncertainties.get(param_name) if nlsq_uncertainties else None
        )
        priors[param_name] = build_nlsq_informed_prior(
            param_name=param_name,
            nlsq_value=nlsq_result[param_name],
            nlsq_std=std_estimate,
            bounds=parameter_space.get_bounds(param_name),
            width_factor=width_factor,
        )

    logger.info(
        f"Built NLSQ-informed priors for {len(priors)} parameters: "
        f"{list(priors.keys())}"
    )
    return priors
# Keys of a flat NLSQ result dict that are bookkeeping, not fitted parameters.
_NLSQ_NON_PARAMETER_KEYS = frozenset(
    {
        "success",
        "message",
        "iterations",
        "chi_squared",
        "r_squared",
        "residuals",
        "jacobian",
        "covariance",
        "uncertainties",
        "std_errors",
    }
)


def _is_real_scalar(value: Any) -> bool:
    """Return True for Python/NumPy real numbers, excluding booleans."""
    return isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(
        value, bool
    )


def _infer_nlsq_param_names(n_params: int, analysis_mode_hint: str | None) -> list[str]:
    """Infer parameter names for a flat NLSQ parameter array of length ``n_params``.

    The array layout is assumed to be scaling parameters first, physical
    parameters last. Laminar flow (7 physical parameters) is chosen when
    the array has at least 9 entries and the hint is not "static";
    otherwise static (3 physical parameters) is used.
    """
    static_physical = ["D0", "alpha", "D_offset"]
    laminar_physical = [
        *static_physical,
        "gamma_dot_t0",
        "beta",
        "gamma_dot_t_offset",
        "phi0",
    ]
    if analysis_mode_hint != "static" and n_params >= 9:
        physical_names = laminar_physical
    else:
        physical_names = static_physical

    n_scaling = n_params - len(physical_names)
    if n_scaling == 2:
        # Single contrast/offset pair
        scaling_names = ["contrast", "offset"]
    elif n_scaling > 2 and n_scaling % 2 == 0:
        # Per-angle scaling: all contrasts first, then all offsets
        n_angles = n_scaling // 2
        scaling_names = [f"contrast_{i}" for i in range(n_angles)] + [
            f"offset_{i}" for i in range(n_angles)
        ]
    else:
        # Layout not recognised: fall back to generic names
        scaling_names = [f"scaling_{i}" for i in range(n_scaling)]

    return scaling_names + physical_names


def extract_nlsq_values_for_cmc(
    nlsq_result: dict | Any,
) -> tuple[dict[str, float], dict[str, float] | None]:
    """Extract parameter values and uncertainties from an NLSQ result.

    This utility handles various NLSQ result formats and extracts the
    information needed for CMC warm-start priors.

    Parameters
    ----------
    nlsq_result : dict or OptimizationResult
        NLSQ result, either:

        - an object with a sized ``parameters`` attribute (and optionally
          ``uncertainties`` and ``analysis_mode`` attributes)
        - dict with "params"/"parameters"/"best_params" keys
        - dict with flat structure (parameter names as keys)

    Returns
    -------
    tuple[dict[str, float], dict[str, float] | None]
        Tuple of (parameter_values, uncertainties). uncertainties may be
        None if not available.

    Raises
    ------
    TypeError
        If ``nlsq_result`` is neither a dict nor an object exposing a sized
        ``parameters`` attribute.

    Notes
    -----
    Array-style results carry no names, so names are inferred from the
    array length (see ``_infer_nlsq_param_names``). A length of 9 is
    ambiguous (static with 3 angles vs. laminar flow with one
    contrast/offset pair). Without an ``analysis_mode`` attribute on the
    result, laminar flow is assumed and a warning is logged.

    For flat dicts, only real-valued scalars are treated as parameters.
    Boolean flags (e.g. a ``converged`` field) are excluded even though
    ``bool`` is a subclass of ``int``, and NumPy scalar types such as
    ``np.float32`` are accepted.
    """
    # Array-style result (e.g. OptimizationResult dataclass)
    if hasattr(nlsq_result, "parameters") and hasattr(
        nlsq_result.parameters, "__len__"
    ):
        params_array = np.asarray(nlsq_result.parameters)
        n_params = len(params_array)
        analysis_mode_hint = getattr(nlsq_result, "analysis_mode", None)

        if n_params == 9 and analysis_mode_hint is None:
            logger.warning(
                "Ambiguous n_params=9: could be static-individual (3 angles) or "
                "laminar_flow auto_averaged. Defaulting to laminar_flow. "
                "Pass analysis_mode_hint to disambiguate."
            )

        param_names = _infer_nlsq_param_names(n_params, analysis_mode_hint)
        values = {name: float(params_array[i]) for i, name in enumerate(param_names)}

        # Uncertainties are only usable when they align one-to-one with the parameters.
        uncertainties = None
        raw_uncertainties = getattr(nlsq_result, "uncertainties", None)
        if raw_uncertainties is not None:
            unc_array = np.asarray(raw_uncertainties)
            if len(unc_array) == n_params:
                uncertainties = {
                    name: float(unc_array[i]) for i, name in enumerate(param_names)
                }

        return values, uncertainties

    # Dict-based result formats
    if isinstance(nlsq_result, dict):
        for container_key in ("params", "parameters", "best_params"):
            if container_key in nlsq_result:
                values = nlsq_result[container_key]
                break
        else:
            # Flat structure: keep real-valued scalars that are not bookkeeping keys.
            values = {
                key: value
                for key, value in nlsq_result.items()
                if key not in _NLSQ_NON_PARAMETER_KEYS and _is_real_scalar(value)
            }

        uncertainties = None
        if "uncertainties" in nlsq_result:
            uncertainties = nlsq_result["uncertainties"]
        elif "std_errors" in nlsq_result:
            uncertainties = nlsq_result["std_errors"]

        # Ensure values are plain floats
        values = {k: float(v) for k, v in values.items()}
        if uncertainties:
            uncertainties = {k: float(v) for k, v in uncertainties.items()}

        return values, uncertainties

    # Unknown format - raise informative error
    raise TypeError(
        f"extract_nlsq_values_for_cmc expects dict or OptimizationResult, "
        f"got {type(nlsq_result).__name__}"
    )