"""NumPyro model for XPCS C2 correlation function.
This module defines the probabilistic model for Bayesian inference
of XPCS parameters using NumPyro.
CRITICAL: Parameter sampling order must match:
1. Per-angle contrast: contrast_0, contrast_1, ... (individual mode only)
2. Per-angle offset: offset_0, offset_1, ... (individual mode only)
3. Physical parameters: D0, alpha, D_offset, [gamma_dot_t0, ...]
Per-Angle Modes (v2.18.0+):
- "individual": Independent contrast + offset per angle (2*n_phi + n_physical + 1 params)
- "constant": Fixed per-angle contrast/offset from quantile estimation (n_physical + 1 params)
- "auto": Selects based on n_phi threshold (constant if n_phi >= 3, else individual)
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from homodyne.core.physics_cmc import (
ShardGrid,
compute_g1_total,
compute_g1_total_with_precomputed,
)
from homodyne.optimization.cmc.scaling import (
compute_scaling_factors,
sample_scaled_parameter,
)
from homodyne.utils.logging import get_logger
if TYPE_CHECKING:
from homodyne.config.parameter_space import ParameterSpace
from homodyne.optimization.cmc.reparameterization import ReparamConfig
logger = get_logger(__name__)
[docs]
def validate_model_output(
    c2_theory: jnp.ndarray,
    params: jnp.ndarray,
) -> bool | jnp.ndarray:
    """Check that theoretical C2 values are finite and inside [-1, 10].

    Parameters
    ----------
    c2_theory : jnp.ndarray
        Theoretical C2 values to inspect.
    params : jnp.ndarray
        Parameter values. Currently unused; kept in the signature so
        parameter-aware checks can be added later.

    Returns
    -------
    bool or jnp.ndarray
        A Python ``bool`` when the verdict can be converted to one. If the
        conversion raises ``TypeError`` (e.g. the value is abstract while
        being traced), the boolean array itself is returned instead.
    """
    # Interface placeholder only: nothing below reads ``params``.
    del params

    all_finite = jnp.all(jnp.isfinite(c2_theory))
    above_floor = jnp.all(c2_theory >= -1.0)
    below_ceiling = jnp.all(c2_theory <= 10.0)
    verdict = all_finite & above_floor & below_ceiling

    try:
        concrete = bool(verdict)
    except TypeError:
        # No concrete truth value is available; hand back the array verdict.
        return verdict
    else:
        return concrete
[docs]
def get_model_param_count(
    n_phi: int, analysis_mode: str, per_angle_mode: str = "individual"
) -> int:
    """Get total number of sampled parameters.

    Parameters
    ----------
    n_phi : int
        Number of phi angles.
    analysis_mode : str
        Analysis mode. ``"laminar_flow"`` contributes 7 physical parameters;
        any other value is counted as the 3-parameter static case.
    per_angle_mode : str
        Per-angle scaling mode: "individual", "auto", "constant", or
        "constant_averaged". Any other value is counted like "individual".

    Returns
    -------
    int
        Total number of parameters (including sigma).

    Notes
    -----
    Mode semantics:

    - individual mode: 2*n_phi (contrast + offset) + physical + sigma
    - auto mode: 2 (averaged contrast + offset, SAMPLED) + physical + sigma
    - constant mode: 0 per-angle (FIXED from quantiles) + physical + sigma
    - constant_averaged mode: 0 per-angle (FIXED averaged scaling)
      + physical + sigma
    """
    # Per-angle scaling parameters that are actually sampled.
    if per_angle_mode in ("constant", "constant_averaged"):
        # Both fixed-scaling modes sample no contrast/offset at all; without
        # the "constant_averaged" entry that mode would be miscounted as
        # 2*n_phi sampled per-angle parameters.
        n_scaling = 0
    elif per_angle_mode == "auto":
        n_scaling = 2  # single averaged contrast + offset (SAMPLED)
    else:
        n_scaling = n_phi * 2  # contrast_0..n + offset_0..n

    # Physical parameters:
    #   laminar_flow -> D0, alpha, D_offset, gamma_dot_t0, beta,
    #                   gamma_dot_t_offset, phi0
    #   otherwise    -> D0, alpha, D_offset
    n_physical = 7 if analysis_mode == "laminar_flow" else 3

    # +1 for the noise parameter sigma.
    return n_scaling + n_physical + 1
[docs]
def xpcs_model_scaled(
    data: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    analysis_mode: str,
    parameter_space: ParameterSpace,
    n_phi: int,
    time_grid: jnp.ndarray | None = None,
    noise_scale: float = 0.1,
    num_shards: int = 1,
    shard_grid: ShardGrid | None = None,
    **kwargs,
) -> None:
    """NumPyro model sampling per-angle and physical parameters in z-space.

    Every contrast, offset and physical parameter is drawn through
    ``sample_scaled_parameter`` (non-centered parameterization: z ~ N(0, 1),
    P = center + scale * z), which keeps gradient magnitudes balanced between
    large parameters such as D0 (~10^4) and small ones such as gamma_dot_t0
    (~10^-3).

    Sample-site order is contrast_0..contrast_{n_phi-1}, then
    offset_0..offset_{n_phi-1}, then the physical parameters, then ``sigma``
    and the observed site ``obs``.

    Parameters
    ----------
    data : jnp.ndarray
        Observed C2 correlation data, shape (n_total,).
    t1, t2 : jnp.ndarray
        Time coordinates, shape (n_total,).
    phi_unique : jnp.ndarray
        Unique phi angles, shape (n_phi,).
    phi_indices : jnp.ndarray
        Index into per-angle arrays for each point, shape (n_total,).
    q : float
        Wavevector magnitude.
    L : float
        Stator-rotor gap length (nm).
    dt : float
        Time step.
    analysis_mode : str
        Analysis mode: "static" or "laminar_flow".
    parameter_space : ParameterSpace
        Parameter space with bounds and priors.
    n_phi : int
        Number of unique phi angles.
    time_grid : jnp.ndarray, optional
        Forwarded to ``compute_g1_total`` when no ``shard_grid`` is given.
    noise_scale : float
        Initial estimate of observation noise.
    num_shards : int
        Number of shards; priors are widened by ``sqrt(num_shards)``.
    shard_grid : ShardGrid, optional
        When provided, g1 is computed with
        ``compute_g1_total_with_precomputed`` instead of ``compute_g1_total``.
    **kwargs
        Optional pre-computed values read by this model: ``scalings``,
        ``wavevector_q_squared_half_dt``, ``sinc_prefactor``, ``point_idx``.
    """
    # ---------------------------------------------------------------------
    # 0. Scaling table and prior-tempering factor
    # ---------------------------------------------------------------------
    # Prefer the table supplied through kwargs; build it only as a fallback.
    precomputed_scalings = kwargs.get("scalings")
    scalings = precomputed_scalings or compute_scaling_factors(
        parameter_space, n_phi, analysis_mode
    )
    prior_scale = math.sqrt(num_shards)

    def _draw(site: str):
        """Sample one z-space site using its entry in ``scalings``."""
        return sample_scaled_parameter(
            site, scalings[site], prior_scale=prior_scale
        )

    # ---------------------------------------------------------------------
    # 1. Per-angle contrasts (all of them before any offset)
    # ---------------------------------------------------------------------
    contrast_arr = jnp.array([_draw(f"contrast_{k}") for k in range(n_phi)])

    # ---------------------------------------------------------------------
    # 2. Per-angle offsets
    # ---------------------------------------------------------------------
    offset_arr = jnp.array([_draw(f"offset_{k}") for k in range(n_phi)])

    # ---------------------------------------------------------------------
    # 3. Physical parameters, in the order the physics kernel expects
    # ---------------------------------------------------------------------
    physical_sites = ["D0", "alpha", "D_offset"]
    if analysis_mode == "laminar_flow":
        physical_sites += ["gamma_dot_t0", "beta", "gamma_dot_t_offset", "phi0"]
    params = jnp.array([_draw(site) for site in physical_sites])

    # ---------------------------------------------------------------------
    # 4. Theoretical g1 for every angle
    # ---------------------------------------------------------------------
    if shard_grid is None:
        g1_all_phi = compute_g1_total(
            params, t1, t2, phi_unique, q, L, dt, time_grid=time_grid
        )
    else:
        # Wavevector constants may arrive pre-computed; derive any missing one.
        wq2hdt = kwargs.get("wavevector_q_squared_half_dt")
        sp = kwargs.get("sinc_prefactor")
        if wq2hdt is None:
            wq2hdt = jnp.asarray(0.5 * (q**2) * dt)
        if sp is None:
            sp = jnp.asarray(0.5 / math.pi * q * L * dt)
        g1_all_phi = compute_g1_total_with_precomputed(
            params, phi_unique, shard_grid, wq2hdt, sp
        )

    # Pick, for each data point, the g1 value of its own angle.
    point_idx = kwargs.get("point_idx")
    if point_idx is None:
        point_idx = jnp.arange(phi_indices.shape[0], dtype=jnp.int32)
    g1_raw = g1_all_phi[phi_indices, point_idx]

    # ---------------------------------------------------------------------
    # 5. Per-angle scaling -> C2
    # ---------------------------------------------------------------------
    # Replace non-finite g1 entries with 1e-10 before squaring.
    # NOTE(review): jnp.where only masks the forward value; if the NaN is
    # produced inside the g1 computation the backward pass may still see it.
    # Confirm against the JAX guidance on NaN gradients with ``where``.
    g1_safe = jnp.where(jnp.isfinite(g1_raw), g1_raw, 1e-10)
    c2_theory = contrast_arr[phi_indices] * g1_safe**2 + offset_arr[phi_indices]

    # Counted after the sanitization above, so this only reflects non-finite
    # values that survive it.
    numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_theory)))

    # ---------------------------------------------------------------------
    # 6. Likelihood
    # ---------------------------------------------------------------------
    # HalfNormal scale is 1.5x the noise estimate, widened by sqrt(num_shards).
    sigma = numpyro.sample(
        "sigma", dist.HalfNormal(scale=noise_scale * 1.5 * prior_scale)
    )
    numpyro.sample("obs", dist.Normal(c2_theory, sigma), obs=data)
[docs]
def xpcs_model_constant(
    data: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    analysis_mode: str,
    parameter_space: ParameterSpace,
    n_phi: int,
    time_grid: jnp.ndarray | None = None,
    noise_scale: float = 0.1,
    fixed_contrast: jnp.ndarray | None = None,
    fixed_offset: jnp.ndarray | None = None,
    num_shards: int = 1,
    shard_grid: ShardGrid | None = None,
    **kwargs,
) -> None:
    """NumPyro model with FIXED per-angle contrast/offset (constant mode).

    The per-angle contrast and offset are supplied by the caller and are not
    sampled; only the physical parameters and ``sigma`` are sample sites.

    Parameter count comparison (laminar_flow, n_phi=23):

    - individual mode: 54 params (46 per-angle + 7 physical + 1 sigma)
    - constant mode: 8 params (7 physical + 1 sigma)

    Parameters
    ----------
    data : jnp.ndarray
        Observed C2 correlation data, shape (n_total,).
    t1, t2 : jnp.ndarray
        Time coordinates, shape (n_total,).
    phi_unique : jnp.ndarray
        Unique phi angles, shape (n_phi,).
    phi_indices : jnp.ndarray
        Index into per-angle arrays for each point, shape (n_total,).
    q : float
        Wavevector magnitude.
    L : float
        Stator-rotor gap length (nm).
    dt : float
        Time step.
    analysis_mode : str
        Analysis mode: "static" or "laminar_flow".
    parameter_space : ParameterSpace
        Parameter space with bounds and priors.
    n_phi : int
        Number of unique phi angles.
    time_grid : jnp.ndarray, optional
        Forwarded to ``compute_g1_total`` when no ``shard_grid`` is given.
    noise_scale : float
        Initial estimate of observation noise.
    fixed_contrast : jnp.ndarray, optional
        Fixed per-angle contrast values. Required despite the ``None``
        default. Indexed by global angle when ``global_phi_unique`` is
        supplied in ``kwargs``, otherwise used as-is.
    fixed_offset : jnp.ndarray, optional
        Fixed per-angle offset values; same conventions as ``fixed_contrast``.
    num_shards : int
        Number of shards; priors are widened by ``sqrt(num_shards)``.
    shard_grid : ShardGrid, optional
        When provided, g1 is computed with
        ``compute_g1_total_with_precomputed`` instead of ``compute_g1_total``.
    **kwargs
        Optional values read by this model: ``global_phi_unique``,
        ``scalings``, ``wavevector_q_squared_half_dt``, ``sinc_prefactor``,
        ``point_idx``.

    Raises
    ------
    ValueError
        If ``fixed_contrast`` or ``fixed_offset`` is ``None``.
    """
    # ---------------------------------------------------------------------
    # 0. Fixed scaling arrays: validate, then align with this shard's angles
    # ---------------------------------------------------------------------
    have_fixed_scaling = fixed_contrast is not None and fixed_offset is not None
    if not have_fixed_scaling:
        raise ValueError(
            "xpcs_model_constant requires fixed_contrast and fixed_offset arrays. "
            "These should be estimated from quantile analysis before calling."
        )

    # The fixed arrays follow the global angle ordering, while phi_indices are
    # shard-local. When the global angle list is available, look up each local
    # angle's global position by nearest value (tolerates float round-off).
    global_phi_unique = kwargs.get("global_phi_unique", None)
    can_remap = global_phi_unique is not None and phi_unique is not None
    if can_remap:
        angle_gap = jnp.abs(phi_unique[:, None] - global_phi_unique[None, :])
        global_indices = jnp.argmin(angle_gap, axis=1)
        contrast_arr = fixed_contrast[global_indices]  # (shard_n_phi,)
        offset_arr = fixed_offset[global_indices]  # (shard_n_phi,)
    else:
        # No global list supplied: use the arrays exactly as given.
        contrast_arr, offset_arr = fixed_contrast, fixed_offset

    # ---------------------------------------------------------------------
    # 1. Scaling table and prior-tempering factor
    # ---------------------------------------------------------------------
    precomputed_scalings = kwargs.get("scalings")
    scalings = precomputed_scalings or compute_scaling_factors(
        parameter_space, n_phi, analysis_mode
    )
    prior_scale = math.sqrt(num_shards)

    # ---------------------------------------------------------------------
    # 2. Physical parameters in z-space
    # ---------------------------------------------------------------------
    physical_sites = ["D0", "alpha", "D_offset"]
    if analysis_mode == "laminar_flow":
        physical_sites += ["gamma_dot_t0", "beta", "gamma_dot_t_offset", "phi0"]
    params = jnp.array(
        [
            sample_scaled_parameter(site, scalings[site], prior_scale=prior_scale)
            for site in physical_sites
        ]
    )

    # ---------------------------------------------------------------------
    # 3. Theoretical g1 for every angle
    # ---------------------------------------------------------------------
    if shard_grid is None:
        g1_all_phi = compute_g1_total(
            params, t1, t2, phi_unique, q, L, dt, time_grid=time_grid
        )
    else:
        # Wavevector constants may arrive pre-computed; derive any missing one.
        wq2hdt = kwargs.get("wavevector_q_squared_half_dt")
        sp = kwargs.get("sinc_prefactor")
        if wq2hdt is None:
            wq2hdt = jnp.asarray(0.5 * (q**2) * dt)
        if sp is None:
            sp = jnp.asarray(0.5 / math.pi * q * L * dt)
        g1_all_phi = compute_g1_total_with_precomputed(
            params, phi_unique, shard_grid, wq2hdt, sp
        )

    # Pick, for each data point, the g1 value of its own angle.
    point_idx = kwargs.get("point_idx")
    if point_idx is None:
        point_idx = jnp.arange(phi_indices.shape[0], dtype=jnp.int32)
    g1_raw = g1_all_phi[phi_indices, point_idx]

    # ---------------------------------------------------------------------
    # 4. Apply the FIXED per-angle scaling
    # ---------------------------------------------------------------------
    # Replace non-finite g1 entries with 1e-10 before squaring.
    # NOTE(review): jnp.where only masks the forward value; confirm that the
    # backward pass through the g1 computation is NaN-free as well.
    g1_safe = jnp.where(jnp.isfinite(g1_raw), g1_raw, 1e-10)
    c2_theory = contrast_arr[phi_indices] * g1_safe**2 + offset_arr[phi_indices]

    # Counted after the sanitization above.
    numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_theory)))

    # ---------------------------------------------------------------------
    # 5. Likelihood
    # ---------------------------------------------------------------------
    # HalfNormal scale is 1.5x the noise estimate, widened by sqrt(num_shards).
    sigma = numpyro.sample(
        "sigma", dist.HalfNormal(scale=noise_scale * 1.5 * prior_scale)
    )
    numpyro.sample("obs", dist.Normal(c2_theory, sigma), obs=data)
[docs]
def xpcs_model_averaged(
    data: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    analysis_mode: str,
    parameter_space: ParameterSpace,
    n_phi: int,
    time_grid: jnp.ndarray | None = None,
    noise_scale: float = 0.1,
    fixed_contrast: jnp.ndarray | None = None,
    fixed_offset: jnp.ndarray | None = None,
    nlsq_prior_config: dict | None = None,
    num_shards: int = 1,
    shard_grid: ShardGrid | None = None,
    **kwargs,
) -> None:
    """NumPyro model with one SAMPLED contrast and one SAMPLED offset (auto mode).

    A single ``contrast`` and a single ``offset`` are sampled and broadcast to
    all ``n_phi`` angles, followed by the physical parameters, ``sigma`` and
    the observed site ``obs``.

    Parameter count comparison (laminar_flow, n_phi=23):

    - individual mode: 54 params (46 per-angle + 7 physical + 1 sigma)
    - auto mode (this): 10 params (2 averaged scaling + 7 physical + 1 sigma)
    - constant mode: 8 params (7 physical + 1 sigma, scaling FIXED)

    Parameters
    ----------
    data : jnp.ndarray
        Observed C2 correlation data, shape (n_total,).
    t1, t2 : jnp.ndarray
        Time coordinates, shape (n_total,).
    phi_unique : jnp.ndarray
        Unique phi angles, shape (n_phi,).
    phi_indices : jnp.ndarray
        Index into per-angle arrays for each point, shape (n_total,).
    q : float
        Wavevector magnitude.
    L : float
        Stator-rotor gap length (nm).
    dt : float
        Time step.
    analysis_mode : str
        Analysis mode: "static" or "laminar_flow".
    parameter_space : ParameterSpace
        Parameter space with bounds and priors.
    n_phi : int
        Number of unique phi angles.
    time_grid : jnp.ndarray, optional
        Forwarded to ``compute_g1_total`` when no ``shard_grid`` is given.
    noise_scale : float
        Initial estimate of observation noise.
    fixed_contrast : jnp.ndarray, optional
        Ignored in this model. Present for API compatibility.
    fixed_offset : jnp.ndarray, optional
        Ignored in this model. Present for API compatibility.
    nlsq_prior_config : dict, optional
        NLSQ-informed prior settings. Keys read here: ``"values"``,
        ``"uncertainties"`` and ``"width_factor"`` (default 2.0). A physical
        parameter listed under ``"values"`` gets its prior from
        ``build_nlsq_informed_prior``; all others use the z-space prior.
    num_shards : int
        Number of shards; z-space and sigma priors are widened by
        ``sqrt(num_shards)``. ``num_shards=1`` means no tempering.
    shard_grid : ShardGrid, optional
        When provided, g1 is computed with
        ``compute_g1_total_with_precomputed`` instead of ``compute_g1_total``.
    **kwargs
        Optional pre-computed values read by this model: ``scalings``,
        ``wavevector_q_squared_half_dt``, ``sinc_prefactor``, ``point_idx``.

    Raises
    ------
    ValueError
        If the scalings table has no ``"contrast_0"`` entry.
    """
    # ---------------------------------------------------------------------
    # 0. Scaling table and prior-tempering factor
    # ---------------------------------------------------------------------
    precomputed_scalings = kwargs.get("scalings")
    scalings = precomputed_scalings or compute_scaling_factors(
        parameter_space, n_phi, analysis_mode
    )
    # Prior tempering (Scott et al. 2016): widening each shard's prior by
    # sqrt(K) leaves one prior contribution in the combined posterior.
    prior_scale = math.sqrt(num_shards)

    # ---------------------------------------------------------------------
    # 1. One shared contrast and one shared offset (SAMPLED)
    # ---------------------------------------------------------------------
    # The angle-0 scaling entries stand in for the averaged values.
    if "contrast_0" not in scalings:
        raise ValueError(
            f"scalings dict missing 'contrast_0' key. "
            f"Available keys: {list(scalings.keys())}. n_phi may be 0."
        )
    contrast = sample_scaled_parameter(
        "contrast", scalings["contrast_0"], prior_scale=prior_scale
    )
    offset = sample_scaled_parameter(
        "offset", scalings["offset_0"], prior_scale=prior_scale
    )
    contrast_arr = jnp.full(n_phi, contrast)
    offset_arr = jnp.full(n_phi, offset)

    # ---------------------------------------------------------------------
    # 2. Physical parameters
    # ---------------------------------------------------------------------
    def _sample_param(name: str) -> jnp.ndarray:
        """Sample a parameter, using NLSQ-informed prior if available."""
        nlsq_available = (
            nlsq_prior_config is not None
            and name in nlsq_prior_config.get("values", {})
        )
        if not nlsq_available:
            # Default z-space prior; this is the only branch that is tempered.
            return sample_scaled_parameter(
                name, scalings[name], prior_scale=prior_scale
            )

        # Imported lazily so the dependency is only touched when needed.
        from homodyne.optimization.cmc.priors import build_nlsq_informed_prior

        # The NLSQ-informed prior is deliberately NOT tempered: width_factor
        # is passed through unchanged rather than scaled by sqrt(K).
        base_width = nlsq_prior_config.get("width_factor", 2.0)
        nlsq_value = nlsq_prior_config["values"][name]
        uncertainties = nlsq_prior_config.get("uncertainties")
        nlsq_std = uncertainties.get(name) if uncertainties else None
        prior = build_nlsq_informed_prior(
            param_name=name,
            nlsq_value=nlsq_value,
            nlsq_std=nlsq_std,
            bounds=parameter_space.get_bounds(name),
            width_factor=base_width,
        )
        return numpyro.sample(name, prior)

    physical_sites = ["D0", "alpha", "D_offset"]
    if analysis_mode == "laminar_flow":
        physical_sites += ["gamma_dot_t0", "beta", "gamma_dot_t_offset", "phi0"]
    params = jnp.array([_sample_param(site) for site in physical_sites])

    # ---------------------------------------------------------------------
    # 3. Theoretical g1 for every angle
    # ---------------------------------------------------------------------
    # Wavevector constants may arrive pre-computed; derive any missing one.
    wq2hdt = kwargs.get("wavevector_q_squared_half_dt")
    sp = kwargs.get("sinc_prefactor")
    if wq2hdt is None:
        wq2hdt = jnp.asarray(0.5 * (q**2) * dt)
    if sp is None:
        sp = jnp.asarray(0.5 / math.pi * q * L * dt)

    if shard_grid is None:
        g1_all_phi = compute_g1_total(
            params, t1, t2, phi_unique, q, L, dt, time_grid=time_grid
        )
    else:
        g1_all_phi = compute_g1_total_with_precomputed(
            params, phi_unique, shard_grid, wq2hdt, sp
        )

    # Pick, for each data point, the g1 value of its own angle.
    point_idx = kwargs.get("point_idx")
    if point_idx is None:
        point_idx = jnp.arange(phi_indices.shape[0], dtype=jnp.int32)
    g1_raw = g1_all_phi[phi_indices, point_idx]

    # ---------------------------------------------------------------------
    # 4. Contrast and offset per data point -> C2
    # ---------------------------------------------------------------------
    # Replace non-finite g1 entries with 1e-10 before squaring.
    # NOTE(review): jnp.where only masks the forward value; confirm that the
    # backward pass through the g1 computation is NaN-free as well.
    g1 = jnp.where(jnp.isfinite(g1_raw), g1_raw, 1e-10)
    c2_theory = offset_arr[phi_indices] + contrast_arr[phi_indices] * g1**2

    # Counted after the sanitization above.
    numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_theory)))

    # ---------------------------------------------------------------------
    # 5. Likelihood
    # ---------------------------------------------------------------------
    # HalfNormal scale is 1.5x the noise estimate (kept tight so sigma does
    # not absorb systematic error), widened by sqrt(num_shards).
    sigma = numpyro.sample(
        "sigma", dist.HalfNormal(scale=noise_scale * 1.5 * prior_scale)
    )
    numpyro.sample("obs", dist.Normal(c2_theory, sigma), obs=data)
[docs]
def xpcs_model_constant_averaged(
    data: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    analysis_mode: str,
    parameter_space: ParameterSpace,
    n_phi: int,
    time_grid: jnp.ndarray | None = None,
    noise_scale: float = 0.1,
    fixed_contrast: jnp.ndarray | None = None,
    fixed_offset: jnp.ndarray | None = None,
    nlsq_prior_config: dict | None = None,
    num_shards: int = 1,
    shard_grid: ShardGrid | None = None,
    **kwargs,
) -> None:
    """NumPyro model with FIXED, angle-averaged contrast/offset.

    The supplied per-angle contrast and offset values are averaged and the
    two means are applied to every angle. Neither is sampled, so the sample
    sites are only the physical parameters and ``sigma``. The means are also
    recorded as the deterministic sites ``fixed_contrast_mean`` and
    ``fixed_offset_mean``.

    Parameter count comparison (laminar_flow):

    - individual mode: 54 params (46 per-angle + 7 physical + 1 sigma)
    - auto mode (xpcs_model_averaged): 10 params (2 sampled scaling
      + 7 physical + 1 sigma)
    - constant mode (xpcs_model_constant): 8 params (7 physical + 1 sigma,
      per-angle fixed)
    - constant_averaged mode (this): 8 params (7 physical + 1 sigma,
      averaged fixed)

    Parameters
    ----------
    data : jnp.ndarray
        Observed C2 correlation data, shape (n_total,).
    t1, t2 : jnp.ndarray
        Time coordinates, shape (n_total,).
    phi_unique : jnp.ndarray
        Unique phi angles, shape (n_phi,).
    phi_indices : jnp.ndarray
        Index into per-angle arrays for each point, shape (n_total,).
    q : float
        Wavevector magnitude.
    L : float
        Stator-rotor gap length (nm).
    dt : float
        Time step.
    analysis_mode : str
        Analysis mode: "static" or "laminar_flow".
    parameter_space : ParameterSpace
        Parameter space with bounds and priors.
    n_phi : int
        Number of unique phi angles.
    time_grid : jnp.ndarray, optional
        Forwarded to ``compute_g1_total`` when no ``shard_grid`` is given.
    noise_scale : float
        Initial estimate of observation noise.
    fixed_contrast : jnp.ndarray
        Fixed per-angle contrast values. Required despite the ``None``
        default. Will be averaged.
    fixed_offset : jnp.ndarray
        Fixed per-angle offset values. Required despite the ``None``
        default. Will be averaged.
    nlsq_prior_config : dict, optional
        NLSQ-informed prior settings. Keys read here: ``"values"``,
        ``"uncertainties"`` and ``"width_factor"`` (default 2.0).
    num_shards : int
        Number of shards; z-space and sigma priors are widened by
        ``sqrt(num_shards)``.
    shard_grid : ShardGrid, optional
        When provided, g1 is computed with
        ``compute_g1_total_with_precomputed`` instead of ``compute_g1_total``.
    **kwargs
        Optional values read by this model: ``global_phi_unique``,
        ``scalings``, ``wavevector_q_squared_half_dt``, ``sinc_prefactor``,
        ``point_idx``.

    Raises
    ------
    ValueError
        If ``fixed_contrast`` or ``fixed_offset`` is ``None``.
    """
    # ---------------------------------------------------------------------
    # 0. Averaged fixed scaling
    # ---------------------------------------------------------------------
    have_fixed_scaling = fixed_contrast is not None and fixed_offset is not None
    if not have_fixed_scaling:
        raise ValueError(
            "xpcs_model_constant_averaged requires fixed_contrast and fixed_offset arrays. "
            "These should be estimated from quantile analysis before calling."
        )

    # Restrict the global per-angle arrays to the angles present in this
    # shard before averaging, so absent angles do not bias the mean. Angles
    # are matched by nearest value to tolerate float round-off.
    global_phi_unique = kwargs.get("global_phi_unique", None)
    can_remap = global_phi_unique is not None and phi_unique is not None
    if can_remap:
        angle_gap = jnp.abs(phi_unique[:, None] - global_phi_unique[None, :])
        global_indices = jnp.argmin(angle_gap, axis=1)
        shard_contrast = fixed_contrast[global_indices]
        shard_offset = fixed_offset[global_indices]
    else:
        shard_contrast, shard_offset = fixed_contrast, fixed_offset

    # One mean per quantity, broadcast to every angle.
    avg_contrast = jnp.mean(shard_contrast)
    avg_offset = jnp.mean(shard_offset)
    contrast_arr = jnp.full(n_phi, avg_contrast)
    offset_arr = jnp.full(n_phi, avg_offset)

    # Expose the averaged values in the trace for diagnostics.
    numpyro.deterministic("fixed_contrast_mean", avg_contrast)
    numpyro.deterministic("fixed_offset_mean", avg_offset)

    # ---------------------------------------------------------------------
    # 1. Scaling table and prior-tempering factor
    # ---------------------------------------------------------------------
    precomputed_scalings = kwargs.get("scalings")
    scalings = precomputed_scalings or compute_scaling_factors(
        parameter_space, n_phi, analysis_mode
    )
    prior_scale = math.sqrt(num_shards)

    # ---------------------------------------------------------------------
    # 2. Physical parameters
    # ---------------------------------------------------------------------
    def _sample_param(name: str) -> jnp.ndarray:
        """Sample a parameter, using NLSQ-informed prior if available."""
        nlsq_available = (
            nlsq_prior_config is not None
            and name in nlsq_prior_config.get("values", {})
        )
        if not nlsq_available:
            # Default z-space prior; this is the only branch that is tempered.
            return sample_scaled_parameter(
                name, scalings[name], prior_scale=prior_scale
            )

        # Imported lazily so the dependency is only touched when needed.
        from homodyne.optimization.cmc.priors import build_nlsq_informed_prior

        # width_factor is passed through unchanged (no sqrt(K) scaling here).
        base_width = nlsq_prior_config.get("width_factor", 2.0)
        nlsq_value = nlsq_prior_config["values"][name]
        uncertainties = nlsq_prior_config.get("uncertainties")
        nlsq_std = uncertainties.get(name) if uncertainties else None
        prior = build_nlsq_informed_prior(
            param_name=name,
            nlsq_value=nlsq_value,
            nlsq_std=nlsq_std,
            bounds=parameter_space.get_bounds(name),
            width_factor=base_width,
        )
        return numpyro.sample(name, prior)

    physical_sites = ["D0", "alpha", "D_offset"]
    if analysis_mode == "laminar_flow":
        physical_sites += ["gamma_dot_t0", "beta", "gamma_dot_t_offset", "phi0"]
    params = jnp.array([_sample_param(site) for site in physical_sites])

    # ---------------------------------------------------------------------
    # 3. Theoretical g1 for every angle
    # ---------------------------------------------------------------------
    # Wavevector constants may arrive pre-computed; derive any missing one.
    wq2hdt = kwargs.get("wavevector_q_squared_half_dt")
    sp = kwargs.get("sinc_prefactor")
    if wq2hdt is None:
        wq2hdt = jnp.asarray(0.5 * (q**2) * dt)
    if sp is None:
        sp = jnp.asarray(0.5 / math.pi * q * L * dt)

    if shard_grid is None:
        g1_all_phi = compute_g1_total(
            params, t1, t2, phi_unique, q, L, dt, time_grid=time_grid
        )
    else:
        g1_all_phi = compute_g1_total_with_precomputed(
            params, phi_unique, shard_grid, wq2hdt, sp
        )

    # Pick, for each data point, the g1 value of its own angle.
    point_idx = kwargs.get("point_idx")
    if point_idx is None:
        point_idx = jnp.arange(phi_indices.shape[0], dtype=jnp.int32)
    g1_raw = g1_all_phi[phi_indices, point_idx]

    # ---------------------------------------------------------------------
    # 4. Apply the FIXED averaged scaling
    # ---------------------------------------------------------------------
    # Replace non-finite g1 entries with 1e-10 before squaring.
    # NOTE(review): jnp.where only masks the forward value; confirm that the
    # backward pass through the g1 computation is NaN-free as well.
    g1 = jnp.where(jnp.isfinite(g1_raw), g1_raw, 1e-10)
    c2_theory = offset_arr[phi_indices] + contrast_arr[phi_indices] * g1**2

    # Counted after the sanitization above.
    numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_theory)))

    # ---------------------------------------------------------------------
    # 5. Likelihood
    # ---------------------------------------------------------------------
    # HalfNormal scale is 1.5x the noise estimate, widened by sqrt(num_shards).
    sigma = numpyro.sample(
        "sigma", dist.HalfNormal(scale=noise_scale * 1.5 * prior_scale)
    )
    numpyro.sample("obs", dist.Normal(c2_theory, sigma), obs=data)
[docs]
def xpcs_model_reparameterized(
    data: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    analysis_mode: str,
    parameter_space: ParameterSpace,
    n_phi: int,
    time_grid: jnp.ndarray | None = None,
    noise_scale: float = 0.1,
    fixed_contrast: jnp.ndarray | None = None,
    fixed_offset: jnp.ndarray | None = None,
    reparam_config: ReparamConfig | None = None,
    nlsq_prior_config: dict | None = None,
    num_shards: int = 1,
    t_ref: float = 1.0,
    shard_grid: ShardGrid | None = None,
    **kwargs,
) -> None:
    """NumPyro model with reference-time reparameterized sampling space.

    This model transforms correlated parameters to an orthogonal sampling
    space:

    - D0, alpha → log_D_ref, alpha where D_ref = D0 * t_ref^alpha (decorrelates)
    - D_offset → D_offset_ratio = D_offset / D_ref (linear, handles negative
      D_offset)
    - gamma_dot_t0, beta → log_gamma_ref, beta where
      gamma_ref = gamma_dot_t0 * t_ref^beta (``"laminar_flow"`` only)

    The original physics parameters (D0, D_offset, gamma_dot_t0) are computed
    as deterministic transforms and included in the trace for output.

    D_offset_ratio uses a TruncatedNormal prior (low=-1+ε), supporting negative
    D_offset for jammed/arrested systems while enforcing D_ref + D_offset > 0 at
    t_ref. Inverse: D_offset = D_ref * ratio.

    Sample-site order (must not be changed; see the module docstring):
    ``contrast``, ``offset``, ``alpha``, then either
    (``log_D_ref``, ``D_offset_ratio``) or (``D0``, ``D_offset``); for
    ``"laminar_flow"`` additionally ``beta``, then ``log_gamma_ref`` or
    ``gamma_dot_t0``, ``gamma_dot_t_offset``, ``phi0``; finally ``sigma`` and
    the observed site ``obs``.

    Parameters
    ----------
    data : jnp.ndarray
        Observed C2 values; passed as ``obs`` to the Normal likelihood.
    t1, t2 : jnp.ndarray
        Time coordinates, forwarded to ``compute_g1_total`` when
        ``shard_grid`` is None.
    phi_unique : jnp.ndarray
        Unique angles, forwarded to the g1 computation.
    phi_indices : jnp.ndarray
        Per-point index used to gather g1 and the per-angle contrast/offset.
    q, L, dt : float
        Forwarded to ``compute_g1_total``; also used to build the wavevector
        constants when they are not supplied through ``kwargs``.
    analysis_mode : str
        ``"laminar_flow"`` samples the shear parameters and builds a
        7-element parameter vector; any other value builds
        ``[D0, alpha, D_offset]``.
    parameter_space : ParameterSpace
        Used for ``compute_scaling_factors`` (when no scalings are supplied)
        and for ``get_bounds`` when building NLSQ-informed priors.
    n_phi : int
        Length of the broadcast contrast/offset arrays.
    time_grid : jnp.ndarray, optional
        Forwarded to ``compute_g1_total``.
    noise_scale : float
        The ``sigma`` prior is ``HalfNormal(noise_scale * 1.5 * sqrt(num_shards))``.
    fixed_contrast, fixed_offset : jnp.ndarray, optional
        Not read by this model (contrast and offset are sampled here).
        Presumably accepted for signature parity with the other models --
        TODO confirm.
    reparam_config : ReparamConfig, optional
        Reparameterization configuration. If None, ``ReparamConfig()`` is
        used. The fields read here are ``t_ref``, ``enable_d_ref`` and
        ``enable_gamma_ref``.
    nlsq_prior_config : dict, optional
        NLSQ-informed prior configuration with keys:

        - "values": dict of NLSQ parameter estimates
        - "uncertainties": dict of NLSQ standard errors
        - "width_factor": prior width multiplier (2.0 if absent)
        - "reparam_values": dict of reparameterized NLSQ values (log_D_ref, etc.)
        - "reparam_uncertainties": dict of reparameterized uncertainties
    num_shards : int
        Number of shards K; priors not built from NLSQ estimates are
        tempered with ``sqrt(K)``.
    t_ref : float
        Reference time for reparameterization (default: 1.0). Only used when
        ``reparam_config.t_ref`` equals 1.0; otherwise the config value wins.
    shard_grid : ShardGrid, optional
        If given, g1 is computed with ``compute_g1_total_with_precomputed``
        instead of ``compute_g1_total``.
    **kwargs
        Optional pre-computed inputs, each recomputed locally when missing:
        ``"scalings"``, ``"wavevector_q_squared_half_dt"``,
        ``"sinc_prefactor"`` and ``"point_idx"``.

    Returns
    -------
    None
        The model only registers NumPyro sample/deterministic sites.
    """
    # Runtime import: the module-level import of ReparamConfig is guarded by
    # TYPE_CHECKING, so the class is not otherwise available at call time.
    from homodyne.optimization.cmc.reparameterization import ReparamConfig

    if reparam_config is None:
        reparam_config = ReparamConfig()

    # Use t_ref from reparam_config if set, otherwise from kwarg.
    # 1.0 doubles as the "not set" sentinel, so a config value of exactly 1.0
    # defers to the ``t_ref`` argument.
    effective_t_ref = reparam_config.t_ref if reparam_config.t_ref != 1.0 else t_ref

    # =========================================================================
    # 0. Compute scaling factors and prior tempering scale
    # =========================================================================
    # P0-1: Use pre-computed scalings from model_kwargs.
    # ``or`` means a missing, None or empty "scalings" entry triggers the
    # local computation.
    scalings = kwargs.get("scalings") or compute_scaling_factors(
        parameter_space, n_phi, analysis_mode
    )
    # Tempering factor sqrt(K) applied to prior widths below.
    prior_scale = math.sqrt(num_shards)

    # =========================================================================
    # 1. Sample SINGLE averaged contrast and offset (same as auto mode)
    # =========================================================================
    # The sites are named "contrast"/"offset" but reuse the scaling entries
    # registered under "contrast_0"/"offset_0".
    contrast = sample_scaled_parameter(
        "contrast", scalings["contrast_0"], prior_scale=prior_scale
    )
    offset = sample_scaled_parameter(
        "offset", scalings["offset_0"], prior_scale=prior_scale
    )
    # Broadcast the single values so they can be indexed per angle in step 4.
    contrast_arr = jnp.full(n_phi, contrast)
    offset_arr = jnp.full(n_phi, offset)

    # =========================================================================
    # Helper: sample parameter with NLSQ-informed prior if available
    # =========================================================================
    def _sample_param(name: str) -> jnp.ndarray:
        """Sample site ``name``.

        Uses an NLSQ-informed prior when ``nlsq_prior_config["values"]``
        contains ``name``; otherwise falls back to the tempered scaled prior.
        """
        if nlsq_prior_config is not None and name in nlsq_prior_config.get(
            "values", {}
        ):
            from homodyne.optimization.cmc.priors import build_nlsq_informed_prior

            # P1-3: Do NOT temper NLSQ-informed priors (see averaged model).
            base_width = nlsq_prior_config.get("width_factor", 2.0)
            prior = build_nlsq_informed_prior(
                param_name=name,
                nlsq_value=nlsq_prior_config["values"][name],
                # None when "uncertainties" is missing/empty or lacks ``name``.
                nlsq_std=nlsq_prior_config.get("uncertainties", {}).get(name)
                if nlsq_prior_config.get("uncertainties")
                else None,
                bounds=parameter_space.get_bounds(name),
                width_factor=base_width,
            )
            return numpyro.sample(name, prior)
        return sample_scaled_parameter(name, scalings[name], prior_scale=prior_scale)

    # =========================================================================
    # 2. Sample REPARAMETERIZED physical parameters
    # =========================================================================
    # Alpha is sampled directly (not reparameterized, nearly orthogonal to D_ref)
    alpha = _sample_param("alpha")

    # Initialize reparam dicts unconditionally so both enable_d_ref and
    # enable_gamma_ref branches can access them without NameError.
    reparam_vals = (
        nlsq_prior_config.get("reparam_values", {}) if nlsq_prior_config else {}
    )
    reparam_uncs = (
        nlsq_prior_config.get("reparam_uncertainties", {}) if nlsq_prior_config else {}
    )

    if reparam_config.enable_d_ref:
        # --- Reference-time diffusion reparameterization ---
        # Sample log_D_ref where D_ref = D0 * t_ref^alpha (well-constrained by data)
        log_D_ref_loc = reparam_vals.get("log_D_ref", 10.0)  # ~exp(10) ≈ 22K
        log_D_ref_scale = reparam_uncs.get("log_D_ref", 1.0)
        # Prior tempering: widen by sqrt(K), capped at 10 log-units.
        # D0 = exp(log_D_ref) * t_ref^(-alpha); a ±10 deviation in log space
        # spans exp(-10)..exp(10) ≈ 4.5e-5..22026 relative to the reference,
        # covering any physically plausible diffusion coefficient. Without a
        # cap, large K (e.g. 1000) gives scale≈31.6 log-units, spanning 27
        # orders of magnitude and destroying NUTS warmup efficiency.
        # NOTE(review): the builtin ``min`` assumes the scale is a concrete
        # Python/NumPy scalar rather than a traced array -- confirm.
        log_D_ref_scale_tempered = min(log_D_ref_scale * prior_scale, 10.0)
        log_D_ref = numpyro.sample(
            "log_D_ref",
            dist.Normal(loc=log_D_ref_loc, scale=log_D_ref_scale_tempered),
        )
        D_ref = jnp.exp(log_D_ref)

        # D_offset_ratio = D_offset / D_ref (TruncatedNormal prior)
        # Handles negative D_offset for jammed/arrested systems.
        # Physical lower bound: ratio > -1 ensures D_ref + D_offset > 0 at t_ref.
        # Values <= -1 make total diffusion non-positive, which is unphysical and
        # causes physics_cmc.py to enter the zero-gradient clamp region under NUTS.
        # Inverse: D_offset = D_ref * ratio (exact, no clipping needed).
        ratio_loc = reparam_vals.get("D_offset_ratio", 0.0)
        ratio_scale = reparam_uncs.get("D_offset_ratio", 0.5)
        # Cap at 2.0: avoids near-uniform prior under large K shard counts.
        ratio_scale_tempered = min(ratio_scale * prior_scale, 2.0)
        D_offset_ratio = numpyro.sample(
            "D_offset_ratio",
            dist.TruncatedNormal(
                loc=ratio_loc,
                scale=ratio_scale_tempered,
                low=-1.0 + 1e-4,  # physical floor: D_ref + D_offset > 0 at t_ref
            ),
        )

        # Recover physics parameters as deterministics
        D0 = numpyro.deterministic("D0", D_ref * effective_t_ref ** (-alpha))
        D_offset = numpyro.deterministic("D_offset", D_ref * D_offset_ratio)
    else:
        # Standard sampling
        D0 = _sample_param("D0")
        D_offset = _sample_param("D_offset")

    if analysis_mode == "laminar_flow":
        # Beta is sampled directly (nearly orthogonal to gamma_ref)
        beta = _sample_param("beta")

        if reparam_config.enable_gamma_ref:
            # --- Reference-time shear reparameterization ---
            # Sample log_gamma_ref where gamma_ref = gamma_dot_t0 * t_ref^beta
            log_gamma_ref_loc = reparam_vals.get("log_gamma_ref", -5.0)
            log_gamma_ref_scale = reparam_uncs.get("log_gamma_ref", 1.0)
            # Cap at 10 log-units (same rationale as log_D_ref_scale_tempered above).
            log_gamma_ref_scale_tempered = min(log_gamma_ref_scale * prior_scale, 10.0)
            log_gamma_ref = numpyro.sample(
                "log_gamma_ref",
                dist.Normal(loc=log_gamma_ref_loc, scale=log_gamma_ref_scale_tempered),
            )
            # Recover gamma_dot_t0 = exp(log_gamma_ref) * t_ref^(-beta)
            gamma_dot_t0 = numpyro.deterministic(
                "gamma_dot_t0", jnp.exp(log_gamma_ref) * effective_t_ref ** (-beta)
            )
        else:
            gamma_dot_t0 = _sample_param("gamma_dot_t0")

        gamma_dot_t_offset = _sample_param("gamma_dot_t_offset")
        phi0 = _sample_param("phi0")

        # Parameter vector layout passed to the g1 computation (7 entries).
        params = jnp.array(
            [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
        )
    else:
        # Any mode other than "laminar_flow": diffusion-only vector (3 entries).
        params = jnp.array([D0, alpha, D_offset])

    # =========================================================================
    # 3. Compute theoretical C2
    # =========================================================================
    # P0-2: Use pre-computed wavevector constants from model_kwargs.
    # Each constant is computed locally from q, L and dt when not supplied.
    wq2hdt = kwargs.get("wavevector_q_squared_half_dt")
    sp = kwargs.get("sinc_prefactor")
    if wq2hdt is None:
        wq2hdt = jnp.asarray(0.5 * (q**2) * dt)
    if sp is None:
        sp = jnp.asarray(0.5 / math.pi * q * L * dt)

    if shard_grid is not None:
        # Precomputed path: the shard grid is used in place of t1/t2/time_grid.
        g1_all_phi = compute_g1_total_with_precomputed(
            params, phi_unique, shard_grid, wq2hdt, sp
        )
    else:
        g1_all_phi = compute_g1_total(
            params, t1, t2, phi_unique, q, L, dt, time_grid=time_grid
        )

    # P1-3: Use pre-computed point_idx from model_kwargs.
    point_idx = kwargs.get("point_idx")
    if point_idx is None:
        point_idx = jnp.arange(phi_indices.shape[0], dtype=jnp.int32)
    # Pick, for every data point, the g1 value of its own angle.
    # Presumably g1_all_phi is (n_phi, n_points) -- TODO confirm.
    g1 = g1_all_phi[phi_indices, point_idx]

    # =========================================================================
    # 4. Apply per-point contrast and offset
    # =========================================================================
    contrast_per_point = contrast_arr[phi_indices]
    offset_per_point = offset_arr[phi_indices]

    # P1-1: Sanitize g1 BEFORE squaring to prevent NaN gradient contamination.
    # Non-finite entries are replaced by 1e-10.
    g1 = jnp.where(jnp.isfinite(g1), g1, 1e-10)
    c2_theory = offset_per_point + contrast_per_point * g1**2

    # Diagnostic count of non-finite model values. It is taken after the g1
    # sanitization above, so g1 entries that were replaced are not counted.
    n_nan = jnp.sum(~jnp.isfinite(c2_theory))
    numpyro.deterministic("n_numerical_issues", n_nan)

    # =========================================================================
    # 5. Likelihood with noise model (with prior tempering)
    # =========================================================================
    sigma_scale = noise_scale * 1.5 * prior_scale
    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))
    numpyro.sample("obs", dist.Normal(c2_theory, sigma), obs=data)
[docs]
def get_xpcs_model(
    per_angle_mode: str = "individual", use_reparameterization: bool = False
):
    """Return the NumPyro model function matching a per-angle scaling mode.

    Parameters
    ----------
    per_angle_mode : str
        One of "individual", "auto", "constant" or "constant_averaged".
        The comparison is exact (case-sensitive).
    use_reparameterization : bool
        Only consulted when ``per_angle_mode`` is "auto". If True, the
        reparameterized model is returned, which samples ``log_D_ref`` and
        ``D_offset_ratio`` in place of D0 / D_offset and ``log_gamma_ref`` in
        place of gamma_dot_t0. The flag is ignored for every other mode.

    Returns
    -------
    callable
        NumPyro model function.

    Raises
    ------
    ValueError
        If ``per_angle_mode`` is not one of the four supported modes
        (for example the NLSQ-only "fourier" mode).

    Notes
    -----
    Mode semantics (same as NLSQ):

    - individual: ``xpcs_model_scaled``, sampling per-angle contrast/offset
      (n_phi*2 + 7 physical + 1 sigma params for laminar_flow).
    - auto: ``xpcs_model_averaged``, sampling a SINGLE averaged
      contrast/offset (2 averaged + 7 physical + 1 sigma = 10 params for
      laminar_flow), or ``xpcs_model_reparameterized`` when
      ``use_reparameterization`` is True.
    - constant: ``xpcs_model_constant``, which requires fixed_contrast /
      fixed_offset arrays (NOT sampled, 7 physical + 1 sigma = 8 params for
      laminar_flow).
    - constant_averaged: ``xpcs_model_constant_averaged`` with FIXED averaged
      scaling (NOT sampled, 7 physical + 1 sigma = 8 params). Provides exact
      NLSQ parity.
    """
    # Flat guard clauses: each supported mode logs its choice and returns.
    if per_angle_mode == "auto":
        if not use_reparameterization:
            logger.info(
                "CMC: Using auto mode model (sampled averaged scaling, 10 params)"
            )
            return xpcs_model_averaged
        logger.info(
            "CMC: Using reparameterized auto mode model "
            "(log_D_ref + log_gamma_ref sampling)"
        )
        return xpcs_model_reparameterized

    if per_angle_mode == "constant":
        logger.info(
            "CMC: Using constant mode model (fixed per-angle scaling, 8 params)"
        )
        return xpcs_model_constant

    if per_angle_mode == "constant_averaged":
        logger.info(
            "CMC: Using constant_averaged mode model (fixed averaged scaling, 8 params, NLSQ parity)"
        )
        return xpcs_model_constant_averaged

    if per_angle_mode == "individual":
        logger.info("CMC: Using individual mode model (sampled per-angle scaling)")
        return xpcs_model_scaled

    # T3-2: Reject unsupported modes (e.g., "fourier") instead of silently
    # falling through to individual mode with wrong parameterization.
    raise ValueError(
        f"Unsupported CMC per_angle_mode: '{per_angle_mode}'. "
        f"Valid modes: 'auto', 'constant', 'constant_averaged', 'individual'. "
        f"NLSQ 'fourier' mode has no CMC equivalent."
    )