from __future__ import annotations
import logging
from typing import Any, cast
import jax
import jax.numpy as jnp
import numpy as np
from homodyne.core.physics_nlsq import compute_g2_scaled
from homodyne.utils.logging import get_logger
class StratifiedResidualFunction:
    """
    Residual function that respects angle-stratified chunk structure.

    The constructor consumes a set of angle-stratified chunks, validates that
    every chunk carries the same set of phi angles, pre-computes the mapping of
    each data point onto the global (phi x t1 x t2) grid, and concatenates all
    chunk data into flat JAX arrays. Each call then evaluates the theory grid
    once and gathers the per-point values through the pre-computed indices.

    The instance is callable (NumPy in, NumPy out) so it can be handed to a
    least-squares driver; ``jax_residual`` exposes the same computation with
    JAX arrays.

    Attributes:
        per_angle_scaling: Whether per-angle scaling is enabled.
        physical_param_names: Names of the physical parameters.
        logger: Logger instance for diagnostics.
        n_chunks: Number of stratified chunks.
        n_total_points: Total number of data points across all chunks.
        n_phi: Number of unique angles found in the first chunk.
        n_scaling_params: ``2 * n_phi`` (per-angle) or ``2`` (legacy).
        n_physical_params: ``len(physical_param_names)``.
        n_total_params: ``n_scaling_params + n_physical_params``.
        chunk_metadata: One dict per chunk holding the shared global unique
            ``phi``/``t1``/``t2`` arrays.
        compute_chunk_jit: Always ``None`` (the per-chunk JIT path is retired).
        g2_all, flat_indices_all, t1_indices_all, t2_indices_all,
        t1_values_all, t2_values_all, chunk_boundaries: Concatenated
            device-side arrays used by the vectorized residual path.

    Note:
        ``self.chunks`` and the intermediate per-chunk lists are deleted at
        the end of construction to release memory; diagnostics are served
        from values cached before that point.
    """

    # Points whose sigma is not above this floor contribute a zero residual.
    _SIGMA_FLOOR = 1e-10
    # |t1 - t2| at or below this tolerance is treated as a diagonal point.
    _DIAGONAL_TOL = 1e-15
    # Time step used (with a warning) when the chunks do not provide dt.
    _FALLBACK_DT = 0.001

    def __init__(
        self,
        stratified_data: Any,
        per_angle_scaling: bool,
        physical_param_names: list[str],
        logger: logging.Logger | None = None,
    ):
        """
        Initialize the stratified residual function.

        Args:
            stratified_data: Object with a ``.chunks`` sequence and a ``.sigma``
                array. Each chunk must expose ``phi``, ``t1``, ``t2``, ``g2``,
                ``q``, ``L`` and ``dt``. ``sigma`` is indexed through the flat
                (phi x t1 x t2) grid indices, so it is expected to hold one
                value per grid cell.
            per_angle_scaling: Whether per-angle scaling parameters are used.
            physical_param_names: List of physical parameter names
                (e.g., ['D0', 'alpha', 'D_offset']).
            logger: Optional logger for diagnostics.

        Raises:
            ValueError: If ``stratified_data.chunks`` is empty, or a chunk has
                inconsistent angles, no data, or mismatched array lengths.
            OverflowError: If a flat grid index does not fit the integer dtype
                JAX provides under the current configuration.
        """
        self.chunks = stratified_data.chunks
        self.per_angle_scaling = per_angle_scaling
        self.physical_param_names = physical_param_names
        self.logger = logger or get_logger(__name__)

        # Reject empty input before doing any array work.
        if not self.chunks:
            raise ValueError("stratified_data.chunks is empty")

        # Only the JAX copy of sigma is retained; the float64 NumPy
        # intermediate is a temporary and is released immediately.
        self._sigma_jax = jnp.asarray(
            np.asarray(stratified_data.sigma, dtype=np.float64)
        )

        self.n_chunks = len(self.chunks)
        self.n_total_points = sum(len(chunk.g2) for chunk in self.chunks)
        # Number of unique angles is taken from the first chunk.
        self.n_phi = len(np.unique(self.chunks[0].phi))

        # Parameter layout:
        #   per-angle: [contrast_0..contrast_{n-1}, offset_0..offset_{n-1}, *physical]
        #   legacy:    [contrast, offset, *physical]
        self.n_scaling_params = 2 * self.n_phi if per_angle_scaling else 2
        self.n_physical_params = len(physical_param_names)
        self.n_total_params = self.n_scaling_params + self.n_physical_params

        # Validate first so malformed chunks produce the documented ValueError
        # instead of an opaque shape error from the pre-computation below.
        self._validate_chunk_structure_inline()

        # Pre-compute unique values and flat indices (avoids jnp.unique in JIT).
        self._precompute_chunk_metadata()
        self._warn_on_grid_mismatch()
        self._setup_jax_functions()
        # Converts chunks to JAX, concatenates them, builds the vmap functions
        # and finally frees the per-chunk data.
        self._preconvert_chunk_arrays()

        self.logger.info(
            f"StratifiedResidualFunction initialized: "
            f"{self.n_chunks} chunks, {self.n_total_points:,} total points, "
            f"n_phi={self.n_phi}, per_angle_scaling={self.per_angle_scaling}, "
            f"n_scaling_params={self.n_scaling_params}, n_physical_params={self.n_physical_params}"
        )

    def _precompute_chunk_metadata(self) -> None:
        """
        Pre-compute GLOBAL unique values and per-chunk flat grid indices.

        Unique ``phi``, ``t1`` and ``t2`` values are extracted from all chunks
        combined, and every chunk receives the SAME global arrays so that flat
        indices address one common (phi x t1 x t2) grid. Doing this once at
        construction keeps ``jnp.unique``/``jnp.searchsorted`` out of the
        residual evaluation.

        Populates ``chunk_metadata``, ``_n_phi_global``, ``_n_t1_global``,
        ``_n_t2_global`` and the ``_precomputed_*`` per-chunk lists.
        """
        all_phi = np.concatenate([chunk.phi for chunk in self.chunks])
        all_t1 = np.concatenate([chunk.t1 for chunk in self.chunks])
        all_t2 = np.concatenate([chunk.t2 for chunk in self.chunks])

        # jnp.unique already returns its values sorted; no extra sort needed.
        global_phi_unique = jnp.unique(jnp.asarray(all_phi))
        global_t1_unique = jnp.unique(jnp.asarray(all_t1))
        global_t2_unique = jnp.unique(jnp.asarray(all_t2))

        # Global grid dimensions used for flat index computation and checks.
        self._n_phi_global = len(global_phi_unique)
        self._n_t1_global = len(global_t1_unique)
        self._n_t2_global = len(global_t2_unique)

        self.logger.debug(
            f"Global unique values extracted from all chunks: "
            f"{len(global_phi_unique)} phi, "
            f"{self._n_t1_global} t1, "
            f"{self._n_t2_global} t2"
        )

        self.chunk_metadata = []
        self._precomputed_flat_indices = []
        self._precomputed_t1_indices = []
        self._precomputed_t2_indices = []
        self._precomputed_t1_values = []
        self._precomputed_t2_values = []

        for chunk in self.chunks:
            # Every chunk shares the same global unique arrays.
            self.chunk_metadata.append(
                {
                    "phi_unique": global_phi_unique,
                    "t1_unique": global_t1_unique,
                    "t2_unique": global_t2_unique,
                }
            )
            flat_indices, t1_indices, t2_indices = self._compute_flat_indices(
                phi=chunk.phi,
                t1=chunk.t1,
                t2=chunk.t2,
                phi_unique=global_phi_unique,
                t1_unique=global_t1_unique,
                t2_unique=global_t2_unique,
            )
            self._precomputed_flat_indices.append(flat_indices)
            self._precomputed_t1_indices.append(t1_indices)
            self._precomputed_t2_indices.append(t2_indices)
            # Actual float time values, used for value-based diagonal masking.
            self._precomputed_t1_values.append(jnp.asarray(chunk.t1))
            self._precomputed_t2_values.append(jnp.asarray(chunk.t2))

        self.logger.debug(
            f"Pre-computed flat indices for {len(self._precomputed_flat_indices)} chunks"
        )

    def _warn_on_grid_mismatch(self) -> None:
        """Log a warning when sigma or the angle count disagrees with the grid.

        Sigma is gathered with the flat (phi x t1 x t2) indices. JAX clamps
        out-of-range gather indices instead of raising, so a sigma array of a
        different size would silently yield wrong weights. Only warnings are
        emitted so existing callers keep working.
        """
        expected_cells = self._n_phi_global * self._n_t1_global * self._n_t2_global
        if self._sigma_jax.size != expected_cells:
            self.logger.warning(
                f"sigma has {self._sigma_jax.size:,} elements but the global grid "
                f"has {expected_cells:,} cells "
                f"({self._n_phi_global} phi x {self._n_t1_global} t1 x "
                f"{self._n_t2_global} t2); sigma lookups may be misaligned."
            )
        if self._n_phi_global != self.n_phi:
            self.logger.warning(
                f"First chunk has {self.n_phi} unique phi values but all chunks "
                f"combined have {self._n_phi_global}; per-angle scaling "
                f"parameters may not line up with the theory grid."
            )

    def _compute_flat_indices(
        self,
        phi: np.ndarray,
        t1: np.ndarray,
        t2: np.ndarray,
        phi_unique: jnp.ndarray,
        t1_unique: jnp.ndarray,
        t2_unique: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Compute flat indices mapping chunk points to global grid positions.

        Each point is located in the sorted unique arrays and mapped to
        ``phi_idx * (n_t1 * n_t2) + t1_idx * n_t2 + t2_idx``.

        The index arithmetic is performed on the host in NumPy int64. Casting
        to ``jnp.int64`` is not sufficient on its own: when JAX 64-bit mode is
        disabled that dtype is truncated to int32 and a large grid would wrap
        to negative indices without any error.

        Parameters
        ----------
        phi, t1, t2 : np.ndarray
            Per-point values for this chunk.
        phi_unique, t1_unique, t2_unique : jnp.ndarray
            Global unique values (sorted).

        Returns
        -------
        tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
            ``(flat_indices, t1_indices, t2_indices)`` as JAX integer arrays.

        Raises
        ------
        OverflowError
            If the largest flat index cannot be represented by the integer
            dtype JAX uses under the current configuration.
        """
        # The lookup uses the same JAX arrays (hence the same float precision)
        # that define the theory grid; results are pulled back to the host.
        phi_bins = np.asarray(
            jnp.searchsorted(phi_unique, jnp.asarray(phi))
        ).astype(np.int64)
        t1_bins = np.asarray(
            jnp.searchsorted(t1_unique, jnp.asarray(t1))
        ).astype(np.int64)
        t2_bins = np.asarray(
            jnp.searchsorted(t2_unique, jnp.asarray(t2))
        ).astype(np.int64)

        n_t1 = len(t1_unique)
        n_t2 = len(t2_unique)
        flat_host = phi_bins * (n_t1 * n_t2) + t1_bins * n_t2 + t2_bins

        t1_indices = jnp.asarray(t1_bins)
        t2_indices = jnp.asarray(t2_bins)

        # The dtype JAX actually gave us is int32 unless 64-bit mode is on.
        index_limit = int(np.iinfo(t1_indices.dtype).max)
        if flat_host.size and int(flat_host.max()) > index_limit:
            raise OverflowError(
                f"Flat grid index {int(flat_host.max()):,} exceeds the maximum "
                f"of {t1_indices.dtype} ({index_limit:,}); enable JAX 64-bit "
                f"mode (jax_enable_x64) or reduce the grid size."
            )

        return jnp.asarray(flat_host), t1_indices, t2_indices

    def _setup_jax_functions(self) -> None:
        """
        Placeholder for the retired per-chunk JIT setup.

        The per-chunk residual kernel is no longer used, so nothing is
        compiled here; ``compute_chunk_jit`` is kept as ``None`` for callers
        that still look the attribute up.
        """
        self.compute_chunk_jit = None

    def _preconvert_chunk_arrays(self) -> None:
        """
        Convert chunk arrays to JAX arrays once, then concatenate them.

        Builds ``chunks_jax`` (one dict per chunk with ``phi``, ``t1``, ``t2``,
        ``g2`` as JAX arrays and ``q``, ``L``, ``dt`` as floats, ``dt`` possibly
        ``None``) and hands over to ``_concatenate_chunk_data``, which also
        frees the per-chunk data afterwards.
        """
        self.chunks_jax = [
            {
                "phi": jnp.asarray(chunk.phi),
                "t1": jnp.asarray(chunk.t1),
                "t2": jnp.asarray(chunk.t2),
                "g2": jnp.asarray(chunk.g2),
                "q": float(chunk.q),
                "L": float(chunk.L),
                "dt": float(chunk.dt) if chunk.dt is not None else None,
            }
            for chunk in self.chunks
        ]
        self.logger.debug(f"Pre-converted {len(self.chunks_jax)} chunks to JAX arrays")
        self._concatenate_chunk_data()

    def _concatenate_chunk_data(self) -> None:
        """
        Concatenate all chunk data into single arrays and free per-chunk data.

        Attributes Created:
            g2_all: Concatenated g2 observations from all chunks
            flat_indices_all: Concatenated pre-computed flat indices
            t1_indices_all / t2_indices_all: Concatenated t1/t2 grid indices
            t1_values_all / t2_values_all: Concatenated float time values
            chunk_boundaries: Cumulative offsets [0, len(chunk0), ...]
            _chunk_q, _chunk_L, _chunk_dt: Taken from the first chunk
                (assumed identical across chunks)
            _phi_unique, _t1_unique, _t2_unique: Global unique arrays
            _cached_*: Diagnostics captured before the chunks are deleted

        After this method returns, ``self.chunks``, ``self.chunks_jax`` and the
        ``_precomputed_*`` lists no longer exist.
        """
        g2_list = [cast(jnp.ndarray, entry["g2"]) for entry in self.chunks_jax]
        self.g2_all = jnp.concatenate(g2_list, axis=0)
        self.flat_indices_all = jnp.concatenate(self._precomputed_flat_indices, axis=0)
        self.t1_indices_all = jnp.concatenate(self._precomputed_t1_indices, axis=0)
        self.t2_indices_all = jnp.concatenate(self._precomputed_t2_indices, axis=0)
        self.t1_values_all = jnp.concatenate(self._precomputed_t1_values, axis=0)
        self.t2_values_all = jnp.concatenate(self._precomputed_t2_values, axis=0)

        # Cumulative chunk offsets, computed on the host as Python ints.
        chunk_sizes = [len(g2) for g2 in g2_list]
        boundaries = [0, *np.cumsum(chunk_sizes).tolist()]
        self.chunk_boundaries = jnp.array(boundaries, dtype=jnp.int64)

        # Common chunk parameters (assumed same for all chunks).
        first = self.chunks_jax[0]
        self._chunk_q = cast(float, first["q"])
        self._chunk_L = cast(float, first["L"])
        self._chunk_dt = cast("float | None", first["dt"])

        # Global unique arrays (identical in every metadata entry).
        self._phi_unique = self.chunk_metadata[0]["phi_unique"]
        self._t1_unique = self.chunk_metadata[0]["t1_unique"]
        self._t2_unique = self.chunk_metadata[0]["t2_unique"]

        # Log from the host-side list to avoid device reads and array reprs.
        self.logger.debug(
            f"Concatenated chunk data: {len(self.g2_all):,} total points, "
            f"{len(self.chunk_boundaries) - 1} chunks, "
            f"boundaries={boundaries[:5]}..."
        )

        # Build the vmap functions now that the grid arrays are available.
        self._setup_vmap_functions()

        # The vectorized path only needs the concatenated arrays, so the
        # per-chunk intermediates can be released.
        del self._precomputed_flat_indices
        del self._precomputed_t1_indices
        del self._precomputed_t2_indices
        del self._precomputed_t1_values
        del self._precomputed_t2_values
        del self.chunks_jax

        # Cache diagnostics before the original chunks are released.
        self._cached_n_chunks = self.n_chunks
        self._cached_chunk_sizes = [len(c.g2) for c in self.chunks]
        self._cached_chunk_angle_counts = [len(np.unique(c.phi)) for c in self.chunks]
        self._cached_n_angles = self.n_phi

        # Normally already validated in __init__; validate here only if this
        # method is reached without that having happened.
        if not getattr(self, "_chunk_structure_valid", False):
            self._validate_chunk_structure_inline()
        del self.chunks

    @staticmethod
    def _rounded_angle_set(phi: Any) -> set:
        """Return the set of unique phi values after rounding to 6 decimals."""
        return set(np.unique(np.round(phi, decimals=6)))

    def _check_chunk_structure(self, expected_angles: set) -> None:
        """Raise ValueError if any chunk deviates from the expected structure.

        Per chunk, in order: the rounded angle set must equal
        ``expected_angles``; ``g2`` must be non-empty; ``phi``, ``t1``, ``t2``
        and ``g2`` must have equal lengths.
        """
        for i, chunk in enumerate(self.chunks):
            chunk_angles = self._rounded_angle_set(chunk.phi)
            if chunk_angles != expected_angles:
                missing = expected_angles - chunk_angles
                extra = chunk_angles - expected_angles
                error_msg = f"Chunk {i} has inconsistent angles:\n"
                if missing:
                    error_msg += f"  Missing: {missing}\n"
                if extra:
                    error_msg += f"  Extra: {extra}\n"
                raise ValueError(error_msg)

            n_points = len(chunk.g2)
            if n_points == 0:
                raise ValueError(f"Chunk {i} has no data points")

            if not (len(chunk.phi) == len(chunk.t1) == len(chunk.t2) == n_points):
                raise ValueError(
                    f"Chunk {i} has inconsistent array shapes: "
                    f"phi={len(chunk.phi)}, t1={len(chunk.t1)}, "
                    f"t2={len(chunk.t2)}, g2={len(chunk.g2)}"
                )

    def _validate_chunk_structure_inline(self) -> None:
        """Validate chunk structure while ``self.chunks`` is still available.

        Raises ValueError on failure so construction fails fast. On success,
        records the outcome in ``self._chunk_structure_valid`` so that
        ``validate_chunk_structure()`` can answer after the chunks are freed.
        """
        expected_angles = self._rounded_angle_set(self.chunks[0].phi)
        self.logger.debug(
            f"Inline chunk structure validation: {self.n_chunks} chunks, "
            f"{len(expected_angles)} expected angles per chunk"
        )
        self._check_chunk_structure(expected_angles)
        self._chunk_structure_valid = True
        self.logger.debug("Inline chunk structure validation passed")

    def _setup_vmap_functions(self) -> None:
        """Create the vmap-wrapped g2 grid functions once during init.

        Both wrappers evaluate ``compute_g2_scaled`` on the global t1/t2 grid
        for one angle and drop the leading axis of its result:

        * ``_vmap_g2_per_angle(physical_params, phi, contrast, offset)`` maps
          over phi, contrast and offset together.
        * ``_vmap_g2_scalar(physical_params, contrast, offset, phi)`` maps over
          phi only, with scalar contrast/offset.

        If the chunks carry no ``dt``, a warning is logged and 0.001 is used.
        """
        if self._chunk_dt is None:
            self.logger.warning(
                "StratifiedResidualFunction: dt not set (chunk_dt is None); "
                "using dt=0.001 s as fallback. Physics factors may be incorrect."
            )
            dt_value = self._FALLBACK_DT
        else:
            dt_value = self._chunk_dt

        def _g2_single_angle(
            physical_params: jnp.ndarray,
            phi_val: float,
            contrast_val: float,
            offset_val: float,
        ) -> jnp.ndarray:
            # Shared core: theory grid for one angle on the global time grid.
            return jnp.squeeze(
                compute_g2_scaled(
                    params=physical_params,
                    t1=self._t1_unique,
                    t2=self._t2_unique,
                    phi=phi_val,
                    q=self._chunk_q,
                    L=self._chunk_L,
                    contrast=contrast_val,
                    offset=offset_val,
                    dt=dt_value,
                ),
                axis=0,
            )

        def _g2_scalar(
            physical_params: jnp.ndarray,
            contrast_val: float,
            offset_val: float,
            phi_val: float,
        ) -> jnp.ndarray:
            # Same computation; argument order matches the scalar call site.
            return _g2_single_angle(physical_params, phi_val, contrast_val, offset_val)

        self._vmap_g2_per_angle = jax.vmap(_g2_single_angle, in_axes=(None, 0, 0, 0))
        self._vmap_g2_scalar = jax.vmap(_g2_scalar, in_axes=(None, None, None, 0))

    def _compute_chunk_residuals_raw(
        self,
        g2_obs: jnp.ndarray,
        sigma_full: jnp.ndarray,
        params_all: jnp.ndarray,
        phi_unique: jnp.ndarray,
        t1_unique: jnp.ndarray,
        t2_unique: jnp.ndarray,
        flat_indices: jnp.ndarray,
        t1_indices: jnp.ndarray,
        t2_indices: jnp.ndarray,
        q: float,
        L: float,
        dt: float | None,
    ) -> jnp.ndarray:
        """DEPRECATED: always raises RuntimeError.

        Kept so existing references resolve; the live path is
        ``_call_jax_vectorized``.
        """
        raise RuntimeError(
            "_compute_chunk_residuals_raw is deprecated. Use _call_jax_vectorized."
        )

    def _call_jax(self, params: jnp.ndarray) -> jnp.ndarray:
        """JAX-native residuals; delegates to ``_call_jax_vectorized``."""
        return self._call_jax_vectorized(params)

    def _call_jax_vectorized(self, params: jnp.ndarray) -> jnp.ndarray:
        """Vectorized residual computation using the concatenated arrays.

        The theory grid is evaluated once for all angles, flattened, and
        gathered with ``flat_indices_all``; sigma is gathered the same way.

        Args:
            params: Parameter array ``[scaling_params, physical_params]``.

        Returns:
            Weighted residuals ``(g2_obs - g2_theory) / sigma`` for ALL data
            points, with zeros where sigma is not above 1e-10 and where
            ``|t1 - t2|`` is not above 1e-15 (diagonal points).
        """
        params_jax = jnp.asarray(params)
        n_phi = self.n_phi

        if self.per_angle_scaling:
            contrast = params_jax[:n_phi]
            offset = params_jax[n_phi : 2 * n_phi]
            physical_params = params_jax[2 * n_phi :]
            g2_theory_grid = self._vmap_g2_per_angle(
                physical_params, self._phi_unique, contrast, offset
            )
        else:
            contrast = params_jax[0]
            offset = params_jax[1]
            physical_params = params_jax[2:]
            g2_theory_grid = self._vmap_g2_scalar(
                physical_params, contrast, offset, self._phi_unique
            )

        # One gather each for theory and sigma across all points.
        g2_theory_all = g2_theory_grid.reshape(-1)[self.flat_indices_all]
        sigma_all = self._sigma_jax.reshape(-1)[self.flat_indices_all]

        # Zero-sigma points are masked out; the divisor is replaced by 1.0
        # there so the unused branch cannot produce inf/nan.
        valid_sigma = sigma_all > self._SIGMA_FLOOR
        safe_sigma = jnp.where(valid_sigma, sigma_all, 1.0)
        residuals = jnp.where(
            valid_sigma, (self.g2_all - g2_theory_all) / safe_sigma, 0.0
        )

        # Diagonal points (t1 == t2) are masked by comparing the actual float
        # time values, so no diagonal correction of the theory grid is needed.
        off_diagonal = (
            jnp.abs(self.t1_values_all - self.t2_values_all) > self._DIAGONAL_TOL
        )
        return jnp.where(off_diagonal, residuals, 0.0)

    def _call_jax_chunked(self, params: jnp.ndarray) -> jnp.ndarray:
        """Original chunk-based residual computation -- REMOVED.

        Always raises RuntimeError: the per-chunk data is freed after
        concatenation, so only the vectorized path is available.
        """
        raise RuntimeError(
            "_call_jax_chunked is unavailable: per-chunk data was freed "
            "after concatenation. Use _call_jax_vectorized instead."
        )

    def jax_residual(self, params: jnp.ndarray) -> jnp.ndarray:
        """Return the weighted residuals as a JAX array."""
        return self._call_jax(params)

    def __call__(self, params: np.ndarray) -> np.ndarray:
        """Return the weighted residuals as a NumPy array."""
        return np.asarray(self._call_jax(jnp.asarray(params)))

    def validate_chunk_structure(self) -> bool:
        """
        Validate that all chunks contain all phi angles.

        After construction the chunks have been freed, so this returns the
        outcome recorded by the validation that ran during ``__init__``. If
        ``self.chunks`` still exists, the chunks are validated directly.

        Returns:
            True if validation passes

        Raises:
            ValueError: If any chunk is missing angles or has inconsistent
                structure (only possible while chunks are still present)
        """
        if not hasattr(self, "chunks"):
            self.logger.info(
                "Chunk structure validation passed (cached -- validated during build)"
            )
            return getattr(self, "_chunk_structure_valid", True)

        # Chunks still live (unusual path): validate now.
        expected_angles = self._rounded_angle_set(self.chunks[0].phi)
        self.logger.info(
            f"Validating chunk structure: {self.n_chunks} chunks, "
            f"{len(expected_angles)} expected angles per chunk"
        )
        self._check_chunk_structure(expected_angles)
        self.logger.info("Chunk structure validation passed")
        return True

    def get_diagnostics(self) -> dict[str, Any]:
        """
        Get diagnostic information about the residual function.

        Returns:
            Dictionary containing:
                - n_chunks: Number of chunks
                - n_total_points: Total data points
                - n_angles: Number of unique phi angles
                - per_angle_scaling: Whether per-angle scaling is enabled
                - chunk_sizes: List of points per chunk
                - chunk_angle_counts: List of angles per chunk
                - min_chunk_size: Minimum chunk size
                - max_chunk_size: Maximum chunk size
                - mean_chunk_size: Mean chunk size
        """
        if hasattr(self, "_cached_chunk_sizes"):
            # Normal post-init path: chunks are gone, use the cached values.
            chunk_sizes = self._cached_chunk_sizes
            chunk_angle_counts = self._cached_chunk_angle_counts
            n_angles = self._cached_n_angles
        else:
            # Chunks still live: compute directly (unusual path).
            chunk_sizes = [len(chunk.g2) for chunk in self.chunks]
            chunk_angle_counts = [len(np.unique(chunk.phi)) for chunk in self.chunks]
            n_angles = len(np.unique(self.chunks[0].phi))

        return {
            "n_chunks": self.n_chunks,
            "n_total_points": self.n_total_points,
            "n_angles": n_angles,
            "per_angle_scaling": self.per_angle_scaling,
            "chunk_sizes": chunk_sizes,
            "chunk_angle_counts": chunk_angle_counts,
            "min_chunk_size": min(chunk_sizes),
            "max_chunk_size": max(chunk_sizes),
            "mean_chunk_size": np.mean(chunk_sizes),
        }

    def log_diagnostics(self) -> None:
        """Log diagnostic information for monitoring."""
        diag = self.get_diagnostics()
        self.logger.info(
            f"StratifiedResidualFunction diagnostics:\n"
            f"  Chunks: {diag['n_chunks']}\n"
            f"  Total points: {diag['n_total_points']:,}\n"
            f"  Angles: {diag['n_angles']}\n"
            f"  Per-angle scaling: {diag['per_angle_scaling']}\n"
            f"  Chunk sizes: min={diag['min_chunk_size']:,}, "
            f"max={diag['max_chunk_size']:,}, mean={diag['mean_chunk_size']:.0f}\n"
            f"  Angle counts per chunk: {set(diag['chunk_angle_counts'])}"
        )
def create_stratified_residual_function(
    stratified_data: Any,
    per_angle_scaling: bool,
    physical_param_names: list[str],
    logger: logging.Logger | None = None,
    validate: bool = True,
) -> StratifiedResidualFunction:
    """
    Factory function to create and validate a stratified residual function.

    Builds a StratifiedResidualFunction from the given arguments and, when
    ``validate`` is true, calls ``validate_chunk_structure()`` followed by
    ``log_diagnostics()`` on it before returning.

    Args:
        stratified_data: Object with .chunks attribute containing angle-stratified chunks
        per_angle_scaling: Whether per-angle scaling parameters are used
        physical_param_names: List of physical parameter names (e.g., ['D0', 'alpha', 'D_offset'])
        logger: Optional logger for diagnostics
        validate: Whether to validate chunk structure (recommended)

    Returns:
        The constructed StratifiedResidualFunction instance

    Raises:
        ValueError: If construction or validation fails

    Example:
        >>> residual_fn = create_stratified_residual_function(
        ...     stratified_data=stratified_data,
        ...     per_angle_scaling=True,
        ...     physical_param_names=['D0', 'alpha', 'D_offset'],
        ...     validate=True
        ... )
        >>> residual_fn.log_diagnostics()
    """
    residual_fn = StratifiedResidualFunction(
        stratified_data=stratified_data,
        per_angle_scaling=per_angle_scaling,
        physical_param_names=physical_param_names,
        logger=logger,
    )

    # Validation and diagnostics are opt-out; both are skipped together.
    if validate:
        residual_fn.validate_chunk_structure()
        residual_fn.log_diagnostics()

    return residual_fn