# Source code for homodyne.optimization.cmc.results

"""CMC result dataclass and ArviZ integration.

This module provides the CMCResult dataclass that encapsulates MCMC
posterior samples and diagnostics in a format compatible with ArviZ
and the existing CLI save functions.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

try:
    import arviz as az

    HAS_ARVIZ = True
except ImportError:
    HAS_ARVIZ = False
    az = None  # type: ignore

import numpy as np

from homodyne.config.parameter_registry import ParameterRegistry
from homodyne.optimization.cmc.diagnostics import DEFAULT_MIN_ESS
from homodyne.optimization.cmc.sampler import MCMCSamples, SamplingStats
from homodyne.utils.logging import get_logger

# Module-level logger obtained from the project logging helper.
logger = get_logger(__name__)

# Scaling base names derived from registry is_scaling flag.
# Used for prefix matching ("<base>_...") in _get_scaling_base and, through its
# first element, for counting per-angle sites in CMCResult.validate_parameters.
# NOTE(review): code in this module compares bases against "contrast"/"offset";
# presumably those are the names the registry returns here — confirm.
_SCALING_NAMES = ParameterRegistry().scaling_names


def _get_scaling_base(name: str) -> str | None:
    """Look up which scaling base, if any, *name* is an indexed variant of.

    A name counts as indexed when it starts with ``"<base>_"`` for one of the
    entries in ``_SCALING_NAMES``.  The first matching base (in
    ``_SCALING_NAMES`` iteration order) is returned; ``None`` when no base
    matches.
    """
    candidates = (base for base in _SCALING_NAMES if name.startswith(f"{base}_"))
    return next(candidates, None)


[docs] class ParameterStats(dict): """Hybrid mapping/sequence for posterior summaries. Supports dict-style access by name (for tests/back-compat) and list/array-style access by index (for plotting utilities). """
[docs] def __init__(self, ordered_names: list[str], values: list[float]) -> None: super().__init__(zip(ordered_names, values, strict=True)) self._ordered_names = list(ordered_names) self._ordered_values = list(values)
def __getitem__(self, key: int | str) -> float: if isinstance(key, int): return self._ordered_values[key] return super().__getitem__(key) def __len__(self) -> int: # sequence semantics return len(self._ordered_values) def __array__(self, dtype=None) -> np.ndarray: # numpy friendliness return np.asarray(self._ordered_values, dtype=dtype) @property def as_array(self) -> np.ndarray: """Return ordered values as numpy array.""" return np.asarray(self._ordered_values, dtype=float)
[docs] def tolist(self) -> list[float]: """Return ordered values as list (numpy compatibility).""" return list(self._ordered_values)
@dataclass
class CMCResult:
    """CMC analysis result with posterior samples and diagnostics.

    This dataclass is compatible with save_mcmc_results() in cli/commands.py.

    Attributes
    ----------
    parameters : np.ndarray
        Posterior mean values, shape (n_params,).
    uncertainties : np.ndarray
        Posterior standard deviations, shape (n_params,).
    param_names : list[str]
        Parameter names in sampling order.
    samples : dict[str, np.ndarray]
        Raw samples, {name: (n_chains, n_samples)}.
    convergence_status : str
        "converged" | "divergences" | "not_converged".
    r_hat : dict[str, float]
        Per-parameter R-hat values.
    ess_bulk : dict[str, float]
        Per-parameter bulk ESS.
    ess_tail : dict[str, float]
        Per-parameter tail ESS.
    divergences : int
        Total number of divergent transitions.
    inference_data : az.InferenceData
        ArviZ InferenceData for plotting.
    execution_time : float
        Total sampling time in seconds.
    warmup_time : float
        Warmup time in seconds.
    n_chains : int
        Number of MCMC chains.
    n_samples : int
        Samples per chain.
    n_warmup : int
        Warmup samples.
    analysis_mode : str
        Analysis mode used.
    per_angle_mode : str
        Per-angle scaling mode (auto/constant/constant_averaged/individual).
    num_shards : int
        Number of shards combined (used for the divergence rate).
    covariance : np.ndarray
        Parameter covariance matrix (from samples).
    chi_squared : float
        Placeholder for compatibility (not directly computed in MCMC).
    reduced_chi_squared : float
        Placeholder for compatibility.
    device_info : dict[str, Any]
        Device used for computation.
    recovery_actions : list[str]
        Recovery actions attached to the result (empty by default; not
        populated by this class).
    quality_flag : str
        Quality label, "good" by default (not modified by this class).
    mean_params, std_params : ParameterStats
        Posterior means / standard deviations of the non-scaling parameters,
        plus name-only "contrast"/"offset" entries averaged over angles when
        indexed scaling sites were sampled.
    mean_contrast, std_contrast, mean_offset, std_offset : float | None
        Convenience copies of the "contrast"/"offset" entries above, or None.
    """

    # Core results
    parameters: np.ndarray
    uncertainties: np.ndarray
    param_names: list[str]

    # MCMC-specific
    samples: dict[str, np.ndarray]
    convergence_status: str
    r_hat: dict[str, float]
    ess_bulk: dict[str, float]
    ess_tail: dict[str, float]
    divergences: int

    # ArviZ
    inference_data: az.InferenceData

    # Timing
    execution_time: float
    warmup_time: float

    # Config
    n_chains: int = 4
    n_samples: int = 2000
    n_warmup: int = 500
    analysis_mode: str = "static"
    per_angle_mode: str = (
        "auto"  # Per-angle scaling mode (auto/constant/constant_averaged/individual)
    )
    num_shards: int = 1  # Number of shards combined (for correct divergence rate)

    # Compatibility fields
    covariance: np.ndarray = field(default_factory=lambda: np.array([]))
    chi_squared: float = 0.0
    reduced_chi_squared: float = 0.0
    device_info: dict[str, Any] = field(default_factory=dict)
    recovery_actions: list[str] = field(default_factory=list)
    quality_flag: str = "good"

    # Legacy/CLI plot compatibility
    mean_params: ParameterStats = field(default_factory=lambda: ParameterStats([], []))
    std_params: ParameterStats = field(default_factory=lambda: ParameterStats([], []))
    mean_contrast: float | None = None
    std_contrast: float | None = None
    mean_offset: float | None = None
    std_offset: float | None = None

    def is_cmc_result(self) -> bool:
        """Return True - required by CLI for diagnostic generation."""
        return True

    @property
    def success(self) -> bool:
        """Return True if sampling converged (backward compatibility)."""
        return self.convergence_status == "converged"

    @property
    def message(self) -> str:
        """Return descriptive message about result."""
        if self.convergence_status == "converged":
            return f"CMC sampling converged. {self.divergences} divergences."
        elif self.convergence_status == "divergences":
            # Account for num_shards in divergence rate calculation
            total_transitions = self.num_shards * self.n_chains * self.n_samples
            rate = self.divergences / total_transitions if total_transitions > 0 else 0
            return f"CMC completed with {rate:.1%} divergence rate."
        else:
            return f"CMC did not converge: {self.convergence_status}"

    @staticmethod
    def _sample_covariance(
        all_samples: np.ndarray, max_rows: int = 10_000
    ) -> np.ndarray:
        """Return the (n_params, n_params) covariance of a (n_draws, n_params) array.

        Rows containing any non-finite value are dropped first.  When fewer
        than two usable rows remain, a zero matrix is returned.  Inputs with
        more than ``max_rows`` finite rows are subsampled (fixed seed) before
        calling ``np.cov``.
        """
        n_cols = all_samples.shape[1]
        if all_samples.shape[0] < 2:
            # Not enough samples for covariance - return zeros
            return np.zeros((n_cols, n_cols))

        # Filter rows with any NaN before computing covariance.
        # NaN samples arise from failed shards or NUTS divergences that
        # produce non-finite values; np.cov propagates them to the full matrix.
        finite_rows = all_samples[np.all(np.isfinite(all_samples), axis=1)]
        if finite_rows.shape[0] < 2:
            return np.zeros((n_cols, n_cols))

        # Q4: Subsample before computing covariance.
        # np.cov is O(N*P^2); a bounded number of rows keeps this cheap.
        if finite_rows.shape[0] > max_rows:
            # Fixed seed for reproducible covariance subsampling.
            rng = np.random.default_rng(seed=0)
            idx = rng.choice(finite_rows.shape[0], size=max_rows, replace=False)
            finite_rows = finite_rows[idx]

        # np.cov collapses the single-parameter case to a 0-d array; keep the
        # result a matrix so its shape matches the zero-matrix fallbacks above.
        return np.atleast_2d(np.cov(finite_rows, rowvar=False))

    @classmethod
    def from_mcmc_samples(
        cls,
        mcmc_samples: MCMCSamples,
        stats: SamplingStats,
        analysis_mode: str,
        n_warmup: int = 500,
        min_ess: float | None = None,
    ) -> CMCResult:
        """Create CMCResult from MCMC samples.

        Parameters
        ----------
        mcmc_samples : MCMCSamples
            Raw MCMC samples.
        stats : SamplingStats
            Sampling statistics.
        analysis_mode : str
            Analysis mode used.
        n_warmup : int
            Number of warmup samples.
        min_ess : float | None
            Minimum effective sample size for convergence checks.
            If None, uses ``DEFAULT_MIN_ESS`` from diagnostics module.

        Returns
        -------
        CMCResult
            Complete result object.
        """
        from homodyne.optimization.cmc.diagnostics import (
            DEFAULT_MIN_ESS,
            check_convergence,
            compute_ess,
            compute_r_hat,
        )

        if min_ess is None:
            min_ess = DEFAULT_MIN_ESS

        # Compute diagnostics
        r_hat = compute_r_hat(mcmc_samples.samples)
        ess_bulk, ess_tail = compute_ess(mcmc_samples.samples)

        # Check convergence
        # Pass num_shards for correct divergence rate calculation in CMC
        convergence_status, warnings = check_convergence(
            r_hat=r_hat,
            ess_bulk=ess_bulk,
            divergences=stats.num_divergent,
            n_samples=mcmc_samples.n_samples,
            n_chains=mcmc_samples.n_chains,
            min_ess=min_ess,
            num_shards=getattr(mcmc_samples, "num_shards", 1),
        )

        if warnings:
            for warning in warnings:
                logger.warning(f"Convergence warning: {warning}")

        # Compute posterior statistics
        param_names = mcmc_samples.param_names
        parameters = np.zeros(len(param_names))
        uncertainties = np.zeros(len(param_names))

        # Aggregate convenience stats for legacy consumers (CLI plots, writers)
        contrast_values: list[float] = []
        contrast_stds: list[float] = []
        offset_values: list[float] = []
        offset_stds: list[float] = []
        physical_param_names: list[str] = []
        mean_params_physical: list[float] = []
        std_params_physical: list[float] = []

        for i, name in enumerate(param_names):
            if name in mcmc_samples.samples:
                samples_flat = mcmc_samples.samples[name].flatten()
                parameters[i] = np.nanmean(samples_flat)
                uncertainties[i] = np.nanstd(samples_flat)

                # CRITICAL FIX (Dec 2025): Exclude _z (z-space) parameters from legacy stats
                # The scaled model samples contrast_0_z ~ N(0,1) and registers contrast_0 as
                # deterministic. Only use original-space values (without _z suffix).
                # Classify using registry is_scaling flag
                _base = _get_scaling_base(name)
                if _base == "contrast" and not name.endswith("_z"):
                    contrast_values.append(float(parameters[i]))
                    contrast_stds.append(float(uncertainties[i]))
                elif _base == "offset" and not name.endswith("_z"):
                    offset_values.append(float(parameters[i]))
                    offset_stds.append(float(uncertainties[i]))
                elif not name.endswith("_z"):
                    # Physical parameters (D0, alpha, etc.) - exclude _z variants
                    physical_param_names.append(name)
                    mean_params_physical.append(float(parameters[i]))
                    std_params_physical.append(float(uncertainties[i]))
                # Skip _z parameters - they are z-space samples, not original-space values

        # Compute covariance (requires at least 2 samples)
        # P2-R6-03: Guard against param_names entries absent from samples
        # (e.g. deterministic sites not returned by get_samples, failed shards).
        present_names = [n for n in param_names if n in mcmc_samples.samples]
        all_samples = (
            np.column_stack(
                [mcmc_samples.samples[name].flatten() for name in present_names]
            )
            if present_names
            else np.zeros((0, 0))
        )
        covariance = cls._sample_covariance(all_samples)

        # Create ArviZ InferenceData
        inference_data = create_inference_data(mcmc_samples)

        mean_params_stats = ParameterStats(physical_param_names, mean_params_physical)
        std_params_stats = ParameterStats(physical_param_names, std_params_physical)

        # Angle-averaged scaling summaries are stored by name only; they do not
        # become part of the ordered (positional) view of ParameterStats.
        if contrast_values:
            mean_params_stats["contrast"] = float(np.nanmean(contrast_values))
            std_params_stats["contrast"] = float(np.nanmean(contrast_stds))
        if offset_values:
            mean_params_stats["offset"] = float(np.nanmean(offset_values))
            std_params_stats["offset"] = float(np.nanmean(offset_stds))

        return cls(
            parameters=parameters,
            uncertainties=uncertainties,
            param_names=param_names,
            samples=mcmc_samples.samples,
            convergence_status=convergence_status,
            r_hat=r_hat,
            ess_bulk=ess_bulk,
            ess_tail=ess_tail,
            divergences=stats.num_divergent,
            inference_data=inference_data,
            execution_time=stats.total_time,
            warmup_time=stats.warmup_time,
            n_chains=mcmc_samples.n_chains,
            n_samples=mcmc_samples.n_samples,
            n_warmup=n_warmup,
            analysis_mode=analysis_mode,
            num_shards=getattr(mcmc_samples, "num_shards", 1),
            covariance=covariance,
            device_info={"platform": "cpu", "device": "CPU"},
            # Legacy/compat fields expected by CLI writers/plots
            mean_params=mean_params_stats,
            std_params=std_params_stats,
            mean_contrast=mean_params_stats.get("contrast"),
            std_contrast=std_params_stats.get("contrast"),
            mean_offset=mean_params_stats.get("offset"),
            std_offset=std_params_stats.get("offset"),
        )

    def get_posterior_stats(self) -> dict[str, dict[str, float]]:
        """Get posterior statistics for each parameter.

        Parameters listed in ``param_names`` but absent from ``samples`` are
        skipped.  The ``hdi_5%``/``hdi_95%`` entries are the 5th and 95th
        percentiles of the flattened samples.

        Returns
        -------
        dict[str, dict[str, float]]
            Statistics per parameter: mean, std, median, hdi_5%, hdi_95%,
            r_hat, ess_bulk, ess_tail (NaN when a diagnostic is missing).
        """
        stats: dict[str, dict[str, float]] = {}

        for name in self.param_names:
            if name not in self.samples:
                continue
            samples_flat = self.samples[name].flatten()
            stats[name] = {
                "mean": float(np.nanmean(samples_flat)),
                "std": float(np.nanstd(samples_flat)),
                "median": float(np.nanmedian(samples_flat)),
                "hdi_5%": float(np.nanpercentile(samples_flat, 5)),
                "hdi_95%": float(np.nanpercentile(samples_flat, 95)),
                "r_hat": self.r_hat.get(name, np.nan),
                "ess_bulk": self.ess_bulk.get(name, np.nan),
                "ess_tail": self.ess_tail.get(name, np.nan),
            }

        return stats

    def get_samples_array(self) -> np.ndarray:
        """Get samples as 3D array.

        Parameters missing from ``samples`` are left as zeros in the output.

        Returns
        -------
        np.ndarray
            Shape (n_chains, n_samples, n_params).
        """
        n_params = len(self.param_names)
        samples_3d = np.zeros((self.n_chains, self.n_samples, n_params))

        for i, name in enumerate(self.param_names):
            if name in self.samples:
                samples_3d[:, :, i] = self.samples[name]

        return samples_3d

    def validate_parameters(self, n_phi: int | None = None) -> list[str]:
        """Validate that result contains expected parameters.

        Parameters
        ----------
        n_phi : int | None
            Number of phi angles expected. If None, infers from samples.

        Returns
        -------
        list[str]
            List of validation warnings (empty if all valid).
        """
        warnings: list[str] = []

        # Check required physical parameters for analysis mode
        if self.analysis_mode == "laminar_flow":
            required_physical = [
                "D0",
                "alpha",
                "D_offset",
                "gamma_dot_t0",
                "beta",
                "gamma_dot_t_offset",
                "phi0",
            ]
        else:
            required_physical = ["D0", "alpha", "D_offset"]

        for param in required_physical:
            if param not in self.samples:
                warnings.append(f"Missing required parameter: {param}")
            elif not np.all(np.isfinite(self.samples[param])):
                warnings.append(f"Non-finite values in parameter: {param}")

        # Check for contrast/offset parameters — only in individual mode.
        # auto mode uses "contrast"/"offset" sites (not indexed), and
        # constant/constant_averaged modes have no sampled contrast/offset sites.
        # Checking for contrast_0..N in those modes produces spurious warnings.
        if n_phi is None:
            # Infer from samples: count only indexed per-angle sites
            # ("<base>_0", "<base>_1", ...).  Z-space companions such as
            # "<base>_0_z" share the prefix but are not additional angles, so
            # counting them would inflate n_phi and trigger false
            # "Missing per-angle parameter" warnings below.
            _prefix = f"{_SCALING_NAMES[0]}_"
            n_phi = sum(
                1
                for p in self.param_names
                if p.startswith(_prefix) and p[len(_prefix) :].isdigit()
            )

        if n_phi > 0:
            # Individual mode: verify all indexed per-angle sites are present
            for i in range(n_phi):
                contrast_name = f"contrast_{i}"
                offset_name = f"offset_{i}"
                if contrast_name not in self.samples:
                    warnings.append(f"Missing per-angle parameter: {contrast_name}")
                if offset_name not in self.samples:
                    warnings.append(f"Missing per-angle parameter: {offset_name}")
        elif "contrast" not in self.samples and "contrast_0" not in self.samples:
            # Neither auto-mode site nor individual-mode site found;
            # only warn if the mode actually expects sampled contrast/offset.
            # analysis_mode vocabulary is "static"/"laminar_flow" (not "constant"),
            # so check per_angle_mode (which tracks the scaling strategy) instead.
            if self.analysis_mode not in ("constant", "constant_averaged") and getattr(
                self, "per_angle_mode", None
            ) not in ("constant", "constant_averaged"):
                warnings.append(
                    "No contrast/offset sites found in posterior samples. "
                    "Expected 'contrast'/'offset' (auto mode) or 'contrast_0'/'offset_0' "
                    "(individual mode). Use constant/constant_averaged mode if scaling is fixed."
                )

        # Check diagnostic values (non-finite entries are ignored)
        _r_hat_finite = [v for v in self.r_hat.values() if np.isfinite(v)]
        _ess_finite = [v for v in self.ess_bulk.values() if np.isfinite(v)]
        max_r_hat = max(_r_hat_finite) if _r_hat_finite else float("nan")
        min_ess = min(_ess_finite) if _ess_finite else 0.0

        # A NaN max_r_hat compares False here, so no R-hat warning is emitted
        # when there are no finite R-hat values.
        if max_r_hat > 1.1:
            warnings.append(f"High R-hat detected: {max_r_hat:.3f} > 1.1")
        if min_ess < DEFAULT_MIN_ESS:
            warnings.append(f"Low ESS detected: {min_ess:.0f} < {DEFAULT_MIN_ESS}")

        # Check for divergences
        if self.divergences > 0:
            total_transitions = self.num_shards * self.n_chains * self.n_samples
            div_rate = (
                self.divergences / total_transitions if total_transitions > 0 else 0
            )
            if div_rate > 0.01:
                warnings.append(f"High divergence rate: {div_rate:.1%}")

        return warnings
def create_inference_data(mcmc_samples: MCMCSamples) -> az.InferenceData:
    """Create ArviZ InferenceData from MCMC samples.

    Parameters
    ----------
    mcmc_samples : MCMCSamples
        Raw MCMC samples.  Reads ``param_names``, ``samples`` and
        ``extra_fields``.

    Returns
    -------
    az.InferenceData
        ArviZ-compatible data structure with a ``posterior`` group and, when
        extra fields are present, a ``sample_stats`` group.

    Raises
    ------
    ImportError
        If ArviZ could not be imported.
    """
    if not HAS_ARVIZ:
        raise ImportError("ArviZ is required to create InferenceData")

    def _stat_name(field_name: str) -> str:
        # ArviZ plot_energy expects "energy"; xarray doesn't allow dots in
        # variable names (e.g. adapt_state.step_size), so dots become "_".
        if field_name == "potential_energy":
            return "energy"
        return field_name.replace(".", "_")

    # Posterior group: only names that were actually sampled.
    # ArviZ expects (n_chains, n_samples) per variable.
    posterior_group: dict[str, np.ndarray] = {
        site: mcmc_samples.samples[site]
        for site in mcmc_samples.param_names
        if site in mcmc_samples.samples
    }

    # ArviZ 1.0+ uses nested dict API: {group_name: {var_name: array}}
    groups: dict[str, dict[str, np.ndarray]] = {"posterior": posterior_group}

    # Map NumPyro extra_fields to ArviZ sample_stats conventions
    if mcmc_samples.extra_fields:
        sample_stats: dict[str, np.ndarray] = {
            _stat_name(field_name): field_value
            for field_name, field_value in mcmc_samples.extra_fields.items()
        }
        if sample_stats:
            groups["sample_stats"] = sample_stats

    return az.from_dict(groups)
def samples_dict_from_array(
    samples_array: np.ndarray,
    param_names: list[str],
) -> dict[str, np.ndarray]:
    """Split a stacked samples array into one entry per parameter.

    Parameters
    ----------
    samples_array : np.ndarray
        Shape (n_chains, n_samples, n_params).
    param_names : list[str]
        Parameter names; the i-th name is paired with ``samples_array[:, :, i]``.

    Returns
    -------
    dict[str, np.ndarray]
        Samples dictionary, {name: (n_chains, n_samples)}.
    """
    return {
        label: samples_array[:, :, column]
        for column, label in enumerate(param_names)
    }
def compute_fitted_c2(
    result: CMCResult,
    t1: np.ndarray,
    t2: np.ndarray,
    phi: np.ndarray,
    q: float,
    L: float,
    dt: float,
    analysis_mode: str,
    fixed_contrasts: np.ndarray | None = None,
    fixed_offsets: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute fitted C2 values from posterior mean.

    The fitted value per data point is ``contrast * g1**2 + offset`` evaluated
    at the posterior-mean parameters.  The uncertainty is the standard
    deviation of the same expression over up to 100 posterior draws.

    Parameters
    ----------
    result : CMCResult
        CMC result with posterior samples.
    t1, t2, phi : np.ndarray
        Coordinates (pooled 1D).
    q, L, dt : float
        Physics parameters.
    analysis_mode : str
        Analysis mode ("laminar_flow" selects the laminar parameter set; any
        other value selects the static set).
    fixed_contrasts : np.ndarray | None
        Per-angle contrast array of shape (n_phi,) for ``constant`` and
        ``constant_averaged`` modes where contrast is not sampled.
        Required when neither ``contrast_0`` nor ``contrast`` appears in
        posterior samples.
    fixed_offsets : np.ndarray | None
        Per-angle offset array of shape (n_phi,) paired with ``fixed_contrasts``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (c2_fitted_mean, c2_fitted_std) from posterior.

    Raises
    ------
    KeyError
        If no contrast/offset values are available from either the posterior
        or the ``fixed_*`` arguments, or a required parameter is missing.
    ValueError
        If ``fixed_contrasts``/``fixed_offsets`` do not have shape (n_phi,).
    """
    import jax
    import jax.numpy as jnp

    from homodyne.core.physics_cmc import compute_g1_total
    from homodyne.optimization.cmc.priors import LAMINAR_PARAMS, STATIC_PARAMS

    # Get posterior mean parameters
    stats = result.get_posterior_stats()

    # Extract physical parameters
    param_names = LAMINAR_PARAMS if analysis_mode == "laminar_flow" else STATIC_PARAMS
    params = np.array([stats[name]["mean"] for name in param_names])

    # Prepare unique phi for physics call (compute_g1_total expects unique phi)
    # and, in the same pass, the index of each data point's angle.  Every phi
    # value is by construction an exact member of phi_unique, so the inverse
    # mapping is the nearest-neighbor mapping without the O(N * n_phi)
    # broadcast temporary.
    phi_unique, phi_inverse = np.unique(np.asarray(phi), return_inverse=True)
    phi_indices = phi_inverse.reshape(-1).astype(np.int32)
    n_phi = len(phi_unique)
    n_points = len(phi_indices)
    point_ids = np.arange(n_points)

    # Device arrays shared by the mean evaluation and the batched evaluation.
    _t1_jnp = jnp.array(t1)
    _t2_jnp = jnp.array(t2)
    _phi_jnp = jnp.array(phi_unique)

    # Compute g1 with posterior mean
    g1 = compute_g1_total(jnp.array(params), _t1_jnp, _t2_jnp, _phi_jnp, q, L, dt)

    # Get per-angle contrast/offset.  Handle all per-angle modes:
    #   individual: contrast_0, contrast_1, ... are sampled
    #   auto: contrast (single) is sampled and broadcast
    #   constant/constant_averaged: not sampled; caller supplies fixed_contrasts
    if "contrast_0" in stats:
        contrasts = np.array([stats[f"contrast_{i}"]["mean"] for i in range(n_phi)])
        offsets = np.array([stats[f"offset_{i}"]["mean"] for i in range(n_phi)])
    elif "contrast" in stats:
        # Auto mode: single sampled contrast/offset broadcast to all angles
        contrasts = np.full(n_phi, stats["contrast"]["mean"])
        offsets = np.full(n_phi, stats["offset"]["mean"])
    elif fixed_contrasts is not None and fixed_offsets is not None:
        # Constant/constant_averaged mode: contrast/offset are fixed, not sampled.
        # Caller must supply the pre-computed fixed arrays.
        contrasts = np.asarray(fixed_contrasts, dtype=float)
        offsets = np.asarray(fixed_offsets, dtype=float)
        if contrasts.shape != (n_phi,) or offsets.shape != (n_phi,):
            raise ValueError(
                f"fixed_contrasts/fixed_offsets must have shape ({n_phi},), "
                f"got {contrasts.shape} and {offsets.shape}"
            )
    else:
        raise KeyError(
            f"Cannot find contrast parameters in posterior stats "
            f"(available keys: {sorted(stats.keys())}). "
            f"For constant/constant_averaged mode, pass fixed_contrasts and "
            f"fixed_offsets arrays (shape ({n_phi},)) from the original model_kwargs."
        )

    # Apply scaling: gather the right phi row from g1 per data point.
    # g1 shape is (n_phi, n_points); phi_indices maps each point to its phi row.
    contrast_per_point = contrasts[phi_indices]
    offset_per_point = offsets[phi_indices]
    g1_arr = np.array(g1)  # (n_phi, n_points)
    g1_at_phi = g1_arr[phi_indices, point_ids]  # (n_points,)
    c2_fitted = contrast_per_point * g1_at_phi**2 + offset_per_point

    # D4: Compute uncertainty by batched vmap instead of a Python loop.
    n_posterior_samples = min(100, result.n_samples)

    # Each row of the batch arrays is one posterior draw.
    # Index draws round-robin across chains.
    draw_ids = np.arange(n_posterior_samples)
    chain_indices = draw_ids % result.n_chains
    within_chain_indices = draw_ids // result.n_chains

    def _draws(site: str) -> np.ndarray:
        # One fancy-index gather per site: shape (n_posterior_samples,).
        return np.asarray(result.samples[site])[chain_indices, within_chain_indices]

    # shape: (n_posterior_samples, n_physical_params)
    batched_params = np.stack([_draws(name) for name in param_names], axis=1)

    # Handle all per-angle modes for posterior draws
    if "contrast_0" in result.samples:
        # Individual mode: per-angle contrast/offset sampled independently
        batched_contrasts = np.stack(
            [_draws(f"contrast_{j}") for j in range(n_phi)], axis=1
        )  # shape: (n_posterior_samples, n_phi)
        batched_offsets = np.stack(
            [_draws(f"offset_{j}") for j in range(n_phi)], axis=1
        )  # shape: (n_posterior_samples, n_phi)
    elif "contrast" in result.samples:
        # Auto mode: single sampled contrast/offset broadcast to all angles
        batched_contrasts = np.repeat(
            _draws("contrast")[:, None], n_phi, axis=1
        )  # shape: (n_posterior_samples, n_phi)
        batched_offsets = np.repeat(
            _draws("offset")[:, None], n_phi, axis=1
        )  # shape: (n_posterior_samples, n_phi)
    else:
        # Constant/constant_averaged mode: fixed values, no uncertainty over contrast.
        # Use the fixed arrays from the mean-computation step (already validated above).
        batched_contrasts = np.tile(contrasts, (n_posterior_samples, 1))
        batched_offsets = np.tile(offsets, (n_posterior_samples, 1))

    # vmap over the first axis (sample index); all other args are fixed.
    def _g1_single(single_params: jnp.ndarray) -> jnp.ndarray:
        return compute_g1_total(single_params, _t1_jnp, _t2_jnp, _phi_jnp, q, L, dt)

    batched_g1 = jax.vmap(_g1_single)(jnp.array(batched_params))
    # batched_g1 shape: (n_posterior_samples, n_phi, n_points)

    # Apply per-angle contrast/offset scaling for each sample.
    sample_contrasts_mapped = batched_contrasts[:, phi_indices]  # (S, N)
    sample_offsets_mapped = batched_offsets[:, phi_indices]  # (S, N)

    # Gather the right phi row per data point.
    # phi_indices is (n_points,); combine with a point index to select from dim-2.
    batched_g1_at_phi = batched_g1[:, phi_indices, point_ids]
    # shape: (n_posterior_samples, n_points)

    c2_samples_arr = np.array(
        sample_contrasts_mapped * np.array(batched_g1_at_phi) ** 2
        + sample_offsets_mapped
    )  # (n_posterior_samples, n_points)
    c2_fitted_std = np.nanstd(c2_samples_arr, axis=0)

    return c2_fitted, c2_fitted_std