Source code for homodyne.core.theory

"""Theory Computation Engine for Homodyne
==========================================

High-level interface to theoretical calculations for homodyne scattering analysis.
This module provides user-friendly wrappers around the JAX backend functions
with proper error handling, validation, and computational management.

The theory engine handles:
- Model selection and parameter management
- Efficient computation orchestration
- Memory management for large datasets
- Error handling and validation
- Performance monitoring and optimization hints
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import numpy as np

from homodyne.utils.logging import get_logger, log_performance

# Import with fallback handling
try:
    from homodyne.core.jax_backend import (
        batch_chi_squared,
        jax_available,
        jnp,
    )
    from homodyne.core.physics_utils import safe_len
except ImportError:
    jax_available = False
    safe_len: Callable[..., int] = len  # type: ignore[no-redef]
    logger = get_logger(__name__)
    logger.error("Could not import JAX backend - theory computations disabled")

from homodyne.core.models import create_model
from homodyne.core.physics import PhysicsConstants

logger = get_logger(__name__)


class TheoryEngine:
    """High-level interface for theoretical homodyne calculations.

    Manages model selection, parameter validation, and efficient
    computation orchestration for homodyne scattering analysis.

    The engine wraps a model object obtained from ``create_model`` and
    forwards g1 / g2 / chi-squared requests to it after validating the
    inputs. When the JAX backend is available the array inputs are converted
    to float64 JAX arrays first; otherwise they are passed through unchanged.

    Validation that needs concrete values is skipped only for inputs that
    are JAX tracers, so the same checks run with and without JAX installed.
    """

    # Rough operation counts per (time-pair, angle) point, keyed by analysis
    # mode. Modes not listed here are costed like "laminar_flow" (50).
    _OPS_PER_POINT = {
        "static": 10,  # Diffusion only
        "static_isotropic": 10,  # Diffusion only
        "static_anisotropic": 15,  # Diffusion with anisotropy
        "laminar_flow": 50,  # Full model with shear
    }

    def __init__(self, analysis_mode: str = "laminar_flow"):
        """Initialize theory engine with specified analysis mode.

        Args:
            analysis_mode: "static" or "laminar_flow"
        """
        self.analysis_mode = analysis_mode
        self.model = create_model(analysis_mode)
        self._validate_backend()
        logger.info(f"Theory engine initialized for {analysis_mode}")

    def _validate_backend(self) -> None:
        """Log a warning when the JAX backend is not available."""
        if not jax_available:
            logger.warning("JAX backend not available - computations will be slower")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_traced(*values: Any) -> bool:
        """Return True if any of ``values`` is a JAX tracer.

        Always False when the JAX backend is unavailable. Used to skip
        checks that require concrete values while a function is being traced.
        """
        if not jax_available:
            return False
        import jax.core as jax_core

        return any(isinstance(value, jax_core.Tracer) for value in values)

    @staticmethod
    def _to_backend(*arrays: Any) -> tuple[Any, ...]:
        """Convert ``arrays`` to float64 JAX arrays when JAX is available.

        Without JAX the inputs are returned unchanged, in the same order.
        """
        if not jax_available:
            return arrays
        return tuple(jnp.asarray(array, dtype=jnp.float64) for array in arrays)

    # ------------------------------------------------------------------
    # Public computation API
    # ------------------------------------------------------------------

    @log_performance(threshold=0.01)
    def compute_g1(
        self,
        params: np.ndarray,
        t1: np.ndarray,
        t2: np.ndarray,
        phi: np.ndarray,
        q: float,
        L: float,
        dt: float | None = None,
    ) -> Any:
        """Compute g1 correlation function.

        Args:
            params: Physical parameters
            t1, t2: Time grids
            phi: Angle grid
            q: Wave vector magnitude
            L: Sample-detector distance
            dt: Time step (if None, will be estimated from t1)

        Returns:
            g1 correlation function

        Raises:
            ValueError: If q or L is not positive.
        """
        self._validate_computation_inputs(params, q, L)
        params_b, t1_b, t2_b, phi_b = self._to_backend(params, t1, t2, phi)
        return self.model.compute_g1(params_b, t1_b, t2_b, phi_b, q, L, dt)

    @log_performance(threshold=0.01)
    def compute_g2(
        self,
        params: np.ndarray,
        t1: np.ndarray,
        t2: np.ndarray,
        phi: np.ndarray,
        q: float,
        L: float,
        contrast: float,
        offset: float,
        dt: float | None = None,
    ) -> Any:
        """Compute g2 with scaled fitting: g₂ = offset + contrast × [g₁]²

        This is the core equation for homodyne analysis.

        Args:
            params: Physical parameters
            t1, t2: Time grids
            phi: Angle grid
            q: Wave vector magnitude
            L: Sample-detector distance
            contrast: Contrast parameter
            offset: Baseline offset
            dt: Time step in seconds. Required — unlike compute_g1, there is
                no dt estimation fallback for g2. Pass dt explicitly from
                configuration.

        Returns:
            g2 correlation function

        Raises:
            ValueError: If dt is None, if q or L is not positive, or if
                contrast is not positive.
        """
        # Fail fast at the API boundary with a clear message rather than
        # letting an inner layer fail on a missing dt.
        if dt is None:
            raise ValueError(
                "TheoryEngine.compute_g2 requires an explicit dt (time step in seconds). "
                "Pass dt from configuration, e.g. dt=config.dt. "
                "Unlike compute_g1, there is no estimation fallback for compute_g2."
            )

        self._validate_computation_inputs(params, q, L)
        self._validate_scaling_parameters(contrast, offset)

        params_b, t1_b, t2_b, phi_b = self._to_backend(params, t1, t2, phi)
        return self.model.compute_g2(
            params_b, t1_b, t2_b, phi_b, q, L, contrast, offset, dt
        )

    @log_performance(threshold=0.05)
    def compute_chi_squared(
        self,
        params: np.ndarray,
        data: np.ndarray,
        sigma: np.ndarray,
        t1: np.ndarray,
        t2: np.ndarray,
        phi: np.ndarray,
        q: float,
        L: float,
        contrast: float,
        offset: float,
    ) -> float:
        """Compute chi-squared goodness of fit.

        Args:
            params: Physical parameters
            data: Experimental correlation data
            sigma: Measurement uncertainties
            t1, t2: Time grids
            phi: Angle grid
            q: Wave vector magnitude
            L: Sample-detector distance
            contrast, offset: Scaling parameters

        Returns:
            Chi-squared value

        Raises:
            ValueError: If q, L or contrast is not positive, or if data/sigma
                fail the shape or quality checks in ``_validate_data_inputs``.
        """
        self._validate_computation_inputs(params, q, L)
        self._validate_scaling_parameters(contrast, offset)
        self._validate_data_inputs(data, sigma, t1, t2, phi)

        params_b, data_b, sigma_b, t1_b, t2_b, phi_b = self._to_backend(
            params, data, sigma, t1, t2, phi
        )
        return self.model.compute_chi_squared(
            params_b,
            data_b,
            sigma_b,
            t1_b,
            t2_b,
            phi_b,
            q,
            L,
            contrast,
            offset,
        )

    @log_performance(threshold=0.1)
    def batch_computation(
        self,
        params_batch: np.ndarray,
        data: np.ndarray,
        sigma: np.ndarray,
        t1: np.ndarray,
        t2: np.ndarray,
        phi: np.ndarray,
        q: float,
        L: float,
        contrast: float,
        offset: float,
    ) -> Any:
        """Compute chi-squared for multiple parameter sets efficiently.

        With JAX available the whole batch is handed to ``batch_chi_squared``;
        otherwise each parameter set is evaluated in turn through
        ``compute_chi_squared`` and the results are collected in a NumPy array.

        Args:
            params_batch: Array of parameter sets (n_sets, n_params)
            data: Experimental correlation data
            sigma: Measurement uncertainties
            t1, t2: Time grids
            phi: Angle grid
            q: Wave vector magnitude
            L: Sample-detector distance
            contrast, offset: Scaling parameters

        Returns:
            Chi-squared values for each parameter set

        Raises:
            ValueError: If params_batch is not 2D or its second dimension
                does not equal the model's parameter count.
        """
        if params_batch.ndim != 2:
            raise ValueError("params_batch must be 2D array (n_sets, n_params)")

        n_sets, n_params = params_batch.shape
        if n_params != self.model.n_params:
            raise ValueError(
                f"Expected {self.model.n_params} parameters, got {n_params}",
            )

        logger.debug(f"Batch computation for {n_sets} parameter sets")

        if not jax_available:
            # Fallback: evaluate one parameter set at a time.
            return np.array(
                [
                    self.compute_chi_squared(
                        single_params,
                        data,
                        sigma,
                        t1,
                        t2,
                        phi,
                        q,
                        L,
                        contrast,
                        offset,
                    )
                    for single_params in params_batch
                ]
            )

        batch_b, data_b, sigma_b, t1_b, t2_b, phi_b = self._to_backend(
            params_batch, data, sigma, t1, t2, phi
        )
        return batch_chi_squared(
            batch_b,
            data_b,
            sigma_b,
            t1_b,
            t2_b,
            phi_b,
            q,
            L,
            contrast,
            offset,
        )

    def estimate_computation_cost(
        self,
        t1: np.ndarray,
        t2: np.ndarray,
        phi: np.ndarray,
    ) -> dict[str, Any]:
        """Estimate computational cost for given data dimensions.

        Helps with performance planning and memory management. All figures
        are rough estimates derived from the input lengths only.

        Args:
            t1, t2: Time grids
            phi: Angle grid

        Returns:
            Cost estimation dictionary with point counts, estimated
            operations, estimated memory in MB, the analysis mode, the
            backend name and a performance tier label.
        """
        n_time_pairs = safe_len(t1) * safe_len(t2)
        n_angles = safe_len(np.atleast_1d(phi))
        n_total_points = n_time_pairs * n_angles

        # Unknown modes are costed like the full laminar-flow model.
        total_ops = n_total_points * self._OPS_PER_POINT.get(self.analysis_mode, 50)

        # Memory estimate: ~4 float64 values (8 bytes each) per point.
        bytes_per_point = 8 * 4
        total_memory_mb = (n_total_points * bytes_per_point) / (1024**2)

        return {
            "n_time_pairs": n_time_pairs,
            "n_angles": n_angles,
            "n_total_points": n_total_points,
            "estimated_operations": total_ops,
            "estimated_memory_mb": total_memory_mb,
            "analysis_mode": self.analysis_mode,
            "backend": "JAX" if jax_available else "NumPy",
            "performance_tier": self._classify_performance_tier(total_ops),
        }

    def _classify_performance_tier(self, operations: int) -> str:
        """Classify computation as light (<1e6 ops), medium (<1e8) or heavy."""
        if operations < 1e6:
            return "light"
        if operations < 1e8:
            return "medium"
        return "heavy"

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def _validate_computation_inputs(
        self, params: np.ndarray, q: float, L: float
    ) -> None:
        """Validate core computation inputs.

        Args:
            params: Physical parameters; checked against the model's bounds
                unless they are a JAX tracer.
            q: Wave vector magnitude; must be positive.
            L: Sample-detector distance; must be positive.

        Raises:
            ValueError: If q or L is not positive.
        """
        # Bounds checking needs concrete values, so it is skipped only while
        # params is being traced (e.g. inside jit) - not for every call made
        # with JAX installed. This matches the scaling/data validators.
        if not self._is_traced(params):
            if not self.model.validate_parameters(params):
                logger.warning(
                    "Parameters outside recommended bounds - results may be unreliable",
                )

        # Experimental setup validation (q, L are Python floats, safe in all modes)
        if q <= 0:
            raise ValueError(f"Wave vector q must be positive, got {q}")
        if L <= 0:
            raise ValueError(f"Sample-detector distance L must be positive, got {L}")

        # Physical reasonableness checks (warnings only)
        if not (PhysicsConstants.Q_MIN_TYPICAL <= q <= PhysicsConstants.Q_MAX_TYPICAL):
            logger.warning(
                f"q = {q:.2e} outside typical range - check experimental setup",
            )

        # L is in Angstroms - check reasonable range
        # Typical range: 100,000 Å (10 μm) to 100,000,000 Å (10 mm)
        # Note: 1 Å = 1e-10 m, so 1e5 Å = 10 μm, 1e8 Å = 10 mm.
        if not (1e5 <= L <= 1e8):
            logger.warning(
                f"L = {L:.1f} AA outside typical range [1e5, 1e8] AA (10 um to 10 mm) - check experimental setup",
            )

    def _validate_scaling_parameters(self, contrast: float, offset: float) -> None:
        """Validate scaling parameters.

        Raises:
            ValueError: If contrast is not positive.
        """
        # Skip validation only for JAX tracers (inside @jit), not all JAX mode
        if self._is_traced(contrast, offset):
            return

        if contrast <= 0:
            raise ValueError(f"Contrast must be positive, got {contrast}")
        if offset < 0:
            logger.warning(f"Negative offset {offset} - check baseline correction")

    def _validate_data_inputs(
        self,
        data: np.ndarray,
        sigma: np.ndarray,
        t1: np.ndarray,
        t2: np.ndarray,
        phi: np.ndarray,
    ) -> None:
        """Validate experimental data inputs.

        Checks that data and sigma both have shape
        ``(len(phi), len(t1), len(t2))``, that every uncertainty is positive,
        and that data and sigma contain only finite values.

        Raises:
            ValueError: If any of the checks above fails.
        """
        # The value checks below need concrete arrays; skip everything when
        # either data or sigma is a JAX tracer.
        if self._is_traced(data, sigma):
            return

        # Shape consistency
        expected_shape = (
            safe_len(np.atleast_1d(phi)),
            safe_len(t1),
            safe_len(t2),
        )
        if data.shape != expected_shape:
            raise ValueError(
                f"Data shape {data.shape} doesn't match expected {expected_shape}",
            )
        if sigma.shape != expected_shape:
            raise ValueError(
                f"Sigma shape {sigma.shape} doesn't match expected {expected_shape}",
            )

        # Data quality checks
        if np.any(sigma <= 0):
            raise ValueError("All uncertainties must be positive")
        if np.any(~np.isfinite(data)):
            raise ValueError("Data contains non-finite values")
        if np.any(~np.isfinite(sigma)):
            raise ValueError("Uncertainties contain non-finite values")

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_model_info(self) -> dict[str, Any]:
        """Get comprehensive model and engine information.

        Returns:
            The model's own info dictionary extended with the engine version
            and backend availability flags.
        """
        info = self.model.get_model_info()
        info.update(
            {
                "theory_engine_version": "2.0",
                "backend_available": jax_available,
                "supports_batch_computation": jax_available,
            },
        )
        return info

    def __repr__(self) -> str:
        """Return a short description with the analysis mode and backend."""
        backend = "JAX" if jax_available else "NumPy"
        return f"TheoryEngine(mode='{self.analysis_mode}', backend={backend})"
# Convenience functions for direct computation
def compute_g2_theory(
    params: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    phi: np.ndarray,
    q: float,
    L: float,
    contrast: float,
    offset: float,
    analysis_mode: str = "laminar_flow",
    dt: float | None = None,
) -> Any:
    """Compute the g2 theory in a single call.

    One-shot convenience wrapper: builds a ``TheoryEngine`` for
    ``analysis_mode`` and delegates to its ``compute_g2`` method.

    Note:
        A new TheoryEngine is constructed on every call, which includes
        model creation and a log message. For repeated evaluations (e.g.
        parameter sweeps), create one TheoryEngine and call
        ``engine.compute_g2()`` on it directly.

    Args:
        params: Physical parameters
        t1, t2: Time grids
        phi: Angle grid
        q: Wave vector magnitude
        L: Sample-detector distance
        contrast, offset: Scaling parameters
        analysis_mode: Analysis mode
        dt: Time step in seconds. Required for g2 (no estimation fallback),
            even though the signature defaults it to None.

    Returns:
        g2 correlation function

    Raises:
        ValueError: If dt is None.
    """
    return TheoryEngine(analysis_mode).compute_g2(
        params, t1, t2, phi, q, L, contrast, offset, dt
    )
def compute_chi2_theory(
    params: np.ndarray,
    data: np.ndarray,
    sigma: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    phi: np.ndarray,
    q: float,
    L: float,
    contrast: float,
    offset: float,
    analysis_mode: str = "laminar_flow",
) -> float:
    """Compute chi-squared in a single call.

    One-shot convenience wrapper: builds a ``TheoryEngine`` for
    ``analysis_mode`` and delegates to its ``compute_chi_squared`` method.
    A new engine is constructed on every call, so for repeated evaluations
    prefer creating one TheoryEngine and reusing it.

    Args:
        params: Physical parameters
        data: Experimental data
        sigma: Uncertainties
        t1, t2: Time grids
        phi: Angle grid
        q: Wave vector magnitude
        L: Sample-detector distance
        contrast, offset: Scaling parameters
        analysis_mode: Analysis mode

    Returns:
        Chi-squared value
    """
    return TheoryEngine(analysis_mode).compute_chi_squared(
        params,
        data,
        sigma,
        t1,
        t2,
        phi,
        q,
        L,
        contrast,
        offset,
    )
# Public API: the engine class and the one-shot convenience wrappers.
__all__ = ["TheoryEngine", "compute_g2_theory", "compute_chi2_theory"]