"""Fit Computation Utilities for NLSQ Results.
This module provides functions for computing theoretical fits from NLSQ
optimization results. It was extracted from cli/commands.py during a
refactoring (Dec 2025) for better organization.
"""
from __future__ import annotations
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
from homodyne.core.jax_backend import compute_g2_scaled
from homodyne.utils.logging import get_logger
logger = get_logger(__name__)
# Performance optimization (Spec 006 - FR-007, FR-007a): vectorized g2 computation across all phi angles.
[docs]
def compute_g2_batch(
    physical_params: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_angles: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    contrast: float = 1.0,
    offset: float = 1.0,
) -> jnp.ndarray:
    """Evaluate g2 for every phi angle in one ``jax.vmap`` call.

    Performance Optimization (Spec 006 - FR-007):
    A single vectorized map over the angle axis replaces a sequential
    per-angle Python loop (expected 10-20x faster for post-fitting).

    Parameters
    ----------
    physical_params : jnp.ndarray
        Physical parameters, forwarded unchanged to ``compute_g2_scaled``.
    t1 : jnp.ndarray
        First time axis, shape (n_t1,).
    t2 : jnp.ndarray
        Second time axis, shape (n_t2,).
    phi_angles : jnp.ndarray
        Phi angles in radians, shape (n_phi,).
    q : float
        Wave vector magnitude.
    L : float
        Length parameter forwarded to ``compute_g2_scaled`` (previously
        documented as the sample-to-detector distance; NOTE(review): confirm
        the physical meaning).
    dt : float
        Time step forwarded to ``compute_g2_scaled``.
    contrast : float
        Contrast shared by all angles (default 1.0, used for the raw theory).
    offset : float
        Offset shared by all angles (default 1.0, used for the raw theory).

    Returns
    -------
    jnp.ndarray
        g2 values, shape (n_phi, n_t1, n_t2).
    """
    grid_shape = (len(t1), len(t2))

    def _g2_at(angle):
        # compute_g2_scaled expects an array of angles, so the traced scalar
        # is wrapped in a length-1 array. Its output shape may vary, so it is
        # coerced to a single (n_t1, n_t2) grid.
        surface = compute_g2_scaled(
            params=physical_params,
            t1=t1,
            t2=t2,
            phi=jnp.array([angle]),
            q=q,
            L=L,
            contrast=contrast,
            offset=offset,
            dt=dt,
        )
        return surface.reshape(grid_shape)

    # The vmap wrapper is rebuilt per call because _g2_at closes over this
    # call's arguments. That is acceptable here: this is post-processing, not
    # the optimization hot path.
    return jax.vmap(_g2_at)(phi_angles)
[docs]
def compute_g2_batch_with_per_angle_scaling(
    physical_params: jnp.ndarray,
    t1: jnp.ndarray,
    t2: jnp.ndarray,
    phi_angles: jnp.ndarray,
    q: float,
    L: float,
    dt: float,
    contrasts: jnp.ndarray,
    offsets: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate g2 for every phi angle with its own contrast/offset.

    Performance Optimization (Spec 006 - FR-007a):
    Per-angle-scaling counterpart of ``compute_g2_batch``. Angle, contrast
    and offset are mapped together with ``jax.vmap``.

    Parameters
    ----------
    physical_params : jnp.ndarray
        Physical parameters, forwarded unchanged to ``compute_g2_scaled``.
    t1, t2 : jnp.ndarray
        Time axes, shapes (n_t1,) and (n_t2,).
    phi_angles : jnp.ndarray
        Phi angles in radians, shape (n_phi,).
    q, L, dt : float
        Experimental parameters forwarded to ``compute_g2_scaled``.
    contrasts : jnp.ndarray
        One contrast per angle, shape (n_phi,).
    offsets : jnp.ndarray
        One offset per angle, shape (n_phi,).

    Returns
    -------
    jnp.ndarray
        Scaled g2 values, shape (n_phi, n_t1, n_t2).
    """
    grid_shape = (len(t1), len(t2))

    def _scaled_g2_at(angle, angle_contrast, angle_offset):
        # Wrap the traced scalar angle in a length-1 array, as
        # compute_g2_scaled expects, and force a (n_t1, n_t2) result.
        surface = compute_g2_scaled(
            params=physical_params,
            t1=t1,
            t2=t2,
            phi=jnp.array([angle]),
            q=q,
            L=L,
            contrast=angle_contrast,
            offset=angle_offset,
            dt=dt,
        )
        return surface.reshape(grid_shape)

    # in_axes=0 maps all three positional inputs along their leading axis
    # together. The wrapper is rebuilt per call because the closure captures
    # this call's arguments; this path runs only after fitting.
    batched = jax.vmap(_scaled_g2_at, in_axes=0)
    return batched(phi_angles, contrasts, offsets)
[docs]
def solve_lstsq_batch(
    theory_batch: jnp.ndarray,
    exp_batch: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Fit ``exp ~ contrast * theory + offset`` independently per angle.

    Performance Optimization (Spec 006 - FR-008):
    Each row is solved with ``jnp.linalg.lstsq``; rows are vectorized
    with ``jax.vmap``.

    Parameters
    ----------
    theory_batch : jnp.ndarray
        Flattened theory values, shape (n_phi, n_t1 * n_t2).
    exp_batch : jnp.ndarray
        Flattened experimental values, shape (n_phi, n_t1 * n_t2).

    Returns
    -------
    tuple[jnp.ndarray, jnp.ndarray]
        ``(contrasts, offsets)``, each of shape (n_phi,).
    """

    def _fit_row(model_row, data_row):
        # Two-column design matrix [model, 1]: the first coefficient is the
        # contrast and the second is the offset.
        design = jnp.stack([model_row, jnp.ones_like(model_row)], axis=1)
        coeffs = jnp.linalg.lstsq(design, data_row, rcond=None)[0]
        return coeffs[0], coeffs[1]

    return jax.vmap(_fit_row, in_axes=(0, 0))(theory_batch, exp_batch)
[docs]
def normalize_analysis_mode(
mode: str | None,
n_params: int,
n_angles: int,
) -> str:
"""Resolve analysis mode, inferring from parameter counts if needed.
Args:
mode: Explicit mode or None
n_params: Number of parameters
n_angles: Number of angles
Returns:
Normalized mode: 'static' or 'laminar_flow'
"""
if mode:
mode_lower = mode.lower()
if mode_lower in {"static", "static_isotropic"}:
return "static"
if mode_lower == "laminar_flow":
return "laminar_flow"
# Infer from parameter counts (legacy scalar vs per-angle layout)
candidates = {
"static": 3,
"laminar_flow": 7,
}
for candidate_mode, n_phys in candidates.items():
if n_params in {n_phys + 2, 2 * n_angles + n_phys}:
return candidate_mode
# Default to static for backward compatibility
logger.debug(
"Unable to infer analysis_mode from params=%s angles=%s; defaulting to static",
n_params,
n_angles,
)
return "static"
[docs]
def get_physical_param_count(analysis_mode: str) -> int:
    """Return how many physical (non-scaling) parameters a mode uses.

    Args:
        analysis_mode: Normalized mode name, 'static' or 'laminar_flow'.

    Returns:
        3 for 'static' (D0, alpha, D_offset).
        7 for 'laminar_flow' (D0, alpha, D_offset, gamma_dot_t0, beta,
        gamma_dot_t_offset, phi0).

    Raises:
        ValueError: If ``analysis_mode`` is any other value.
    """
    known_modes = (
        ("static", 3),  # D0, alpha, D_offset
        ("laminar_flow", 7),  # + gamma_dot_t0, beta, gamma_dot_t_offset, phi0
    )
    for mode_name, physical_count in known_modes:
        if analysis_mode == mode_name:
            return physical_count
    raise ValueError(
        f"Unknown analysis_mode: '{analysis_mode}'. Expected 'static' or 'laminar_flow'"
    )
[docs]
def compute_theoretical_fits(
    result: Any,
    data: dict[str, Any],
    metadata: dict[str, Any],
    *,
    analysis_mode: str | None = None,
    include_solver_surface: bool = True,
) -> dict[str, Any]:
    """Compute theoretical fits with per-angle least squares scaling.

    Generates theoretical correlation functions from the optimized physical
    parameters. It then applies per-angle scaling (contrast, offset), fitted
    post hoc by least squares against the experimental data.

    Args:
        result: NLSQ optimization result. ``result.parameters`` is read, and
            ``result.analysis_mode`` is read if present.
        data: Experimental data with 'phi_angles_list', 'c2_exp', 't1', 't2'.
            't1'/'t2' may be 1-D axes or 2-D meshgrids (assumed 'ij'
            indexing).
        metadata: Metadata with 'L', 'dt' and 'q' for the theory computation.
        analysis_mode: Optional analysis mode override.
        include_solver_surface: Whether to compute the surface scaled with the
            solver's own contrast/offset values.

    Returns:
        Dictionary with keys:
        - 'c2_theoretical_raw': Raw theory (contrast=1, offset=1),
          shape (n_angles, n_t1, n_t2).
        - 'c2_theoretical_scaled': Theory scaled by the post-hoc lstsq fit,
          shape (n_angles, n_t1, n_t2).
        - 'c2_solver_scaled': Theory scaled by the solver's values, or None
          when ``include_solver_surface`` is False.
        - 'per_angle_scaling': Post-hoc lstsq [contrast, offset] per angle,
          shape (n_angles, 2).
        - 'per_angle_scaling_solver': Solver [contrast, offset] per angle,
          column-stacked.
        - 'residuals': c2_exp minus the lstsq-scaled theory.
        - 'scalar_per_angle_expansion': Flag returned by
          ``extract_parameters_from_result``.

    Raises:
        ValueError: If 'dt' or 'q' is missing/None in ``metadata``, or if
            ``c2_exp`` does not hold n_angles * n_t1 * n_t2 values. Per the
            original contract, also if the parameter count is invalid
            (NOTE(review): presumably raised by
            ``extract_parameters_from_result`` — confirm).
        KeyError: If ``data`` lacks a required key or ``metadata`` lacks 'L'.
    """
    phi_angles = np.asarray(data["phi_angles_list"])
    c2_exp = np.asarray(data["c2_exp"])
    t1 = np.asarray(data["t1"])
    t2 = np.asarray(data["t2"])

    # Convert 2D meshgrids to 1D axes. This assumes 'ij' indexing: t1 varies
    # along axis 0 and t2 along axis 1.
    if t1.ndim == 2:
        t1 = t1[:, 0]
    if t2.ndim == 2:
        t2 = t2[0, :]

    n_params = len(result.parameters)
    n_angles = len(phi_angles)

    # Normalize analysis mode (explicit override wins over the result's own).
    normalized_mode = normalize_analysis_mode(
        analysis_mode or getattr(result, "analysis_mode", None),
        n_params,
        n_angles,
    )

    # Split the flat parameter vector into scaling and physical parameters.
    fitted_contrasts, fitted_offsets, physical_params, scalar_expansion = (
        extract_parameters_from_result(result.parameters, n_angles, normalized_mode)
    )
    logger.info(
        f"Per-angle scaling: {n_angles} angles, using FITTED scaling parameters from NLSQ optimization"
    )
    logger.debug(
        f"Extracted fitted parameters - "
        f"contrasts: mean={np.nanmean(fitted_contrasts):.4f}, "
        f"offsets: mean={np.nanmean(fitted_offsets):.4f}"
    )

    L, dt, q = _resolve_fit_constants(metadata)
    logger.info(
        f"Computing theoretical fits for {n_angles} angles using L={L:.1f} AA, q={q:.6f} AA^-1"
    )

    # Validate before the expensive JAX work. A mismatch would otherwise fail
    # later inside reshape/vmap/broadcasting with an obscure error.
    expected_size = n_angles * t1.size * t2.size
    if c2_exp.size != expected_size:
        raise ValueError(
            f"c2_exp has {c2_exp.size} values but {n_angles} angles x "
            f"{t1.size} t1 x {t2.size} t2 = {expected_size} were expected"
        )

    # Performance Optimization (Spec 006 - FR-007, FR-008): vectorized
    # computation replaces the sequential per-angle loop.
    t1_jax = jnp.array(t1)
    t2_jax = jnp.array(t2)
    phi_jax = jnp.array(phi_angles)
    params_jax = jnp.array(physical_params)

    # RAW theory for all angles at once (FR-007).
    c2_theoretical_raw = np.asarray(
        compute_g2_batch(
            physical_params=params_jax,
            t1=t1_jax,
            t2=t2_jax,
            phi_angles=phi_jax,
            q=q,
            L=L,
            dt=dt,
            contrast=1.0,
            offset=1.0,
        )
    )  # Shape: (n_angles, n_t1, n_t2)

    # Solver-scaled surface for all angles at once (FR-007a), if requested.
    c2_solver_surface = None
    if include_solver_surface:
        c2_solver_surface = np.asarray(
            compute_g2_batch_with_per_angle_scaling(
                physical_params=params_jax,
                t1=t1_jax,
                t2=t2_jax,
                phi_angles=phi_jax,
                q=q,
                L=L,
                dt=dt,
                contrasts=jnp.array(fitted_contrasts),
                offsets=jnp.array(fitted_offsets),
            )
        )

    # Batch least-squares scaling (FR-008) on (n_angles, n_t1 * n_t2) rows.
    theory_batch_flat = jnp.array(c2_theoretical_raw.reshape(n_angles, -1))
    exp_batch_flat = jnp.array(c2_exp.reshape(n_angles, -1))
    contrasts_lstsq, offsets_lstsq = solve_lstsq_batch(
        theory_batch_flat, exp_batch_flat
    )
    contrasts_lstsq = np.asarray(contrasts_lstsq)
    offsets_lstsq = np.asarray(offsets_lstsq)

    # c2_scaled = contrast * c2_raw + offset, broadcast per angle:
    # (n_angles, 1, 1) * (n_angles, n_t1, n_t2) + (n_angles, 1, 1)
    c2_theoretical_fitted = (
        contrasts_lstsq[:, None, None] * c2_theoretical_raw
        + offsets_lstsq[:, None, None]
    )

    per_angle_scaling = np.column_stack((contrasts_lstsq, offsets_lstsq))
    solver_scaling = np.column_stack((fitted_contrasts, fitted_offsets))

    logger.debug(
        f"Batch lstsq - contrasts: mean={np.nanmean(contrasts_lstsq):.4f}, "
        f"offsets: mean={np.nanmean(offsets_lstsq):.4f}"
    )
    logger.info(
        "Note: lstsq contrast/offset values may differ from NLSQ-optimized values. "
        "lstsq re-fits scaling to raw theory (contrast=1, offset=1) post-hoc; "
        "NLSQ values are authoritative as they are jointly optimized with physical parameters."
    )

    residuals = c2_exp - c2_theoretical_fitted
    logger.info(f"Computed theoretical fits for {n_angles} angles")
    return {
        "c2_theoretical_raw": c2_theoretical_raw,
        "c2_theoretical_scaled": c2_theoretical_fitted,
        "c2_solver_scaled": c2_solver_surface,
        "per_angle_scaling": per_angle_scaling,
        "per_angle_scaling_solver": solver_scaling,
        "residuals": residuals,
        "scalar_per_angle_expansion": scalar_expansion,
    }


def _resolve_fit_constants(metadata: dict[str, Any]) -> tuple[float, float, float]:
    """Return ``(L, dt, q)`` from ``metadata`` as floats.

    Raises:
        KeyError: If 'L' is missing.
        ValueError: If 'dt' or 'q' is missing or None.
    """
    L = float(metadata["L"])

    dt_value = metadata.get("dt")
    if dt_value is None:
        # dt is required for the J(t1,t2) numerical integration used by
        # compute_g2_scaled(). A wrong dt produces incorrect theory curves
        # and misleading post-fit visualisations, so raise rather than
        # silently fall back to an arbitrary default.
        raise ValueError(
            "dt (frame exposure time) is required for compute_theoretical_fits() "
            "but was not found in metadata. Pass metadata with a valid 'dt' key."
        )

    # .get() so a missing key raises the documented ValueError, consistent
    # with dt, instead of a bare KeyError.
    q_value = metadata.get("q")
    if q_value is None:
        raise ValueError("q (wavevector) is required but was not found")

    return L, float(dt_value), float(q_value)