# Source code for homodyne.core.models

"""Physical Models for Homodyne
================================

Object-oriented interface to the physical models implemented in the JAX backend.
Provides structured access to diffusion, shear, and combined models with
parameter validation and configuration management.

Homodyne Model
--------------------------
The measured intensity correlation uses per-angle scaling (c1 and c2 below are
the quantities the code refers to as g1 and g2):

    c2(φ, t₁, t₂) = offset + contrast × [c1(φ, t₁, t₂)]²

with a separable field correlation function:

    c1(φ, t₁, t₂) = c1_diff(t₁, t₂) × c1_shear(φ, t₁, t₂)

Diffusion contribution:

    c1_diff(t₁, t₂) = exp[-(q² / 2) ∫|t₂ - t₁| D(t') dt']

Shear contribution:

    c1_shear(φ, t₁, t₂) = [sinc(Φ(φ, t₁, t₂))]²
    Φ(φ, t₁, t₂) = (1 / 2π) · q · L · cos(φ₀ - φ) · ∫|t₂ - t₁| γ̇(t') dt'

Time-dependent transport coefficients:

    D(t) = D₀ · t^α + D_offset
    γ̇(t) = γ̇₀ · t^β + γ̇_offset

Parameter sets:
- Static mode (3 params): D₀, α, D_offset (γ̇₀, β, γ̇_offset, φ₀ fixed/irrelevant)
- Laminar flow (7 params): D₀, α, D_offset, γ̇₀, β, γ̇_offset, φ₀

Experimental parameters:
- q: scattering wavevector magnitude [Å⁻¹]
- L: gap/characteristic length [Å]
- φ: scattering angle [degrees]
- dt: frame time step [s]
"""

from abc import ABC, abstractmethod
from typing import Any

import numpy as np

from homodyne.core.jax_backend import (
    compute_chi_squared,
    compute_g1_diffusion,
    compute_g1_shear,
    compute_g1_total,
    compute_g2_scaled,
    jax_available,
    jnp,
)
from homodyne.core.model_mixins import (
    BenchmarkingMixin,
    GradientCapabilityMixin,
    OptimizationRecommendationMixin,
)
from homodyne.core.physics import validate_parameters
from homodyne.core.physics_utils import safe_len
from homodyne.utils.logging import get_logger, log_calls

# Module-level logger, obtained from the project's logging helper and keyed by
# this module's import name.
logger = get_logger(__name__)


[docs] class PhysicsModelBase(ABC): """Abstract base class for all physical models. Defines the interface that all models must implement and provides common functionality for parameter management and validation. """
[docs] def __init__(self, name: str, parameter_names: list[str]): """Initialize base model. Args: name: Model name for identification parameter_names: List of parameter names in order """ self.name = name self.parameter_names = parameter_names self.n_params = len(parameter_names) self._bounds = None self._default_values = None
[docs] @abstractmethod def compute_g1( self, params: jnp.ndarray, t1: jnp.ndarray, t2: jnp.ndarray, phi: jnp.ndarray, q: float, L: float, dt: float | None = None, ) -> jnp.ndarray: """Compute g1 correlation function for this model."""
[docs] @abstractmethod def get_parameter_bounds(self) -> list[tuple[float, float]]: """Get parameter bounds for optimization."""
[docs] @abstractmethod def get_default_parameters(self) -> jnp.ndarray: """Get default parameter values."""
[docs] def validate_parameters(self, params: jnp.ndarray) -> bool: """Validate parameter values against bounds and constraints.""" return validate_parameters(params, self.get_parameter_bounds()) # type: ignore[arg-type]
[docs] def get_parameter_dict(self, params: jnp.ndarray) -> dict[str, float]: """Convert parameter array to named dictionary.""" # Ensure params is at least 1D to avoid 0D array indexing issues if jax_available and hasattr(params, "ndim"): # Convert JAX arrays to NumPy for safe indexing params_np = np.atleast_1d(np.asarray(params)) else: params_np = np.atleast_1d(params) params_len = safe_len(params_np) if params_len != self.n_params: raise ValueError(f"Expected {self.n_params} parameters, got {params_len}") # Convert to regular Python floats only when safe to do so try: # Try converting to float - will fail if in JIT context return { name: float(val) for name, val in zip(self.parameter_names, params_np, strict=False) } except (TypeError, ValueError, AttributeError): # In JIT context, keep as JAX arrays return dict(zip(self.parameter_names, params_np, strict=False))
def __repr__(self) -> str: return ( f"{self.__class__.__name__}(name='{self.name}', n_params={self.n_params})" )
class DiffusionModel(PhysicsModelBase):
    """Anomalous diffusion model: D(t) = D₀ t^α + D_offset

    Parameters:
    - D₀: Reference diffusion coefficient [Ų/s]
    - α: Diffusion time-dependence exponent [-]
    - D_offset: Baseline diffusion [Ų/s]

    Physical interpretation:
    - α = 0: Normal diffusion (Brownian motion)
    - α > 0: Super-diffusion (enhanced mobility)
    - α < 0: Sub-diffusion (restricted mobility)
    - D_offset: Residual diffusion at t=0
    """

    def __init__(self) -> None:
        """Register the model name and its three parameter names."""
        names = ["D0", "alpha", "D_offset"]
        super().__init__(name="anomalous_diffusion", parameter_names=names)

    @log_calls(include_args=False)
    def compute_g1(
        self,
        params: jnp.ndarray,
        t1: jnp.ndarray,
        t2: jnp.ndarray,
        phi: jnp.ndarray,
        q: float,
        L: float,
        dt: float | None = None,
    ) -> jnp.ndarray:
        """Compute diffusion contribution to g1.

        g₁_diff = exp[-q²/2 ∫|t₂-t₁| D(t')dt']

        ``phi`` and ``L`` are accepted to satisfy the base-class signature
        but are not forwarded to the backend.
        """
        # Parameter validation is deliberately not run here: boolean checks
        # on JAX tracers fail inside JIT. ``q`` is forwarded unconverted for
        # the same reason; the backend handles any conversion it needs.
        g1_diffusion = compute_g1_diffusion(params, t1, t2, q, dt)
        return g1_diffusion

    def get_parameter_bounds(self) -> list[tuple[float, float]]:
        """Standard bounds for diffusion parameters."""
        d0_range = (100.0, 1e5)  # D0: 100 to 1e5 Ų/s
        alpha_range = (-2.0, 2.0)  # alpha: -2 to 2
        d_offset_range = (-1e5, 1e5)  # D_offset: -1e5 to 1e5 Ų/s
        return [d0_range, alpha_range, d_offset_range]

    def get_default_parameters(self) -> jnp.ndarray:
        """Default values for typical XPCS measurements."""
        # Normal diffusion (alpha = 0) with a small offset.
        d0, alpha, d_offset = 100.0, 0.0, 10.0
        return jnp.array([d0, alpha, d_offset])
class ShearModel(PhysicsModelBase):
    """Time-dependent shear model: γ̇(t) = γ̇₀ t^β + γ̇_offset

    Parameters:
    - γ̇₀: Reference shear rate [s⁻¹]
    - β: Shear rate time-dependence exponent [-]
    - γ̇_offset: Baseline shear rate [s⁻¹]
    - φ₀: Angular offset parameter [degrees]

    Physical interpretation:
    - β = 0: Constant shear rate (steady shear)
    - β > 0: Increasing shear rate with time
    - β < 0: Decreasing shear rate with time
    - φ₀: Preferred flow direction angle
    """

    def __init__(self) -> None:
        """Register the model name and its four parameter names."""
        names = ["gamma_dot_t0", "beta", "gamma_dot_t_offset", "phi0"]
        super().__init__(name="time_dependent_shear", parameter_names=names)

    @log_calls(include_args=False)
    def compute_g1(
        self,
        params: jnp.ndarray,
        t1: jnp.ndarray,
        t2: jnp.ndarray,
        phi: jnp.ndarray,
        q: float,
        L: float,
        dt: float | None = None,
    ) -> jnp.ndarray:
        """Compute shear contribution to g1.

        g₁_shear = [sinc(Φ)]²
        where Φ = (qL/2π) cos(φ₀-φ) ∫|t₂-t₁| γ̇(t') dt'
        """
        # Parameter validation is deliberately not run here: boolean checks
        # on JAX tracers fail inside JIT. ``q`` is forwarded unconverted for
        # the same reason; the backend handles any conversion it needs.
        #
        # The backend is given a combined parameter vector, so the four shear
        # values are prefixed with placeholder diffusion values.
        placeholder_diffusion = jnp.array([100.0, 0.0, 10.0])
        padded_params = jnp.concatenate([placeholder_diffusion, params])
        return compute_g1_shear(padded_params, t1, t2, phi, q, L, dt)  # type: ignore[arg-type]

    def get_parameter_bounds(self) -> list[tuple[float, float]]:
        """Standard bounds for shear parameters."""
        gamma_dot_t0_range = (1e-6, 0.5)  # gamma_dot_t0: 1e-6 to 0.5 s⁻¹
        beta_range = (-2.0, 2.0)  # beta: -2 to 2
        gamma_dot_offset_range = (-0.1, 0.1)  # gamma_dot_t_offset: -0.1 to 0.1 s⁻¹
        phi0_range = (-10.0, 10.0)  # phi0: -10 to 10 degrees
        return [gamma_dot_t0_range, beta_range, gamma_dot_offset_range, phi0_range]

    def get_default_parameters(self) -> jnp.ndarray:
        """Default values for typical shear flow."""
        # Constant shear (beta = 0) with zero offsets.
        gamma_dot_t0, beta, gamma_dot_offset, phi0 = 0.01, 0.0, 0.0, 0.0
        return jnp.array([gamma_dot_t0, beta, gamma_dot_offset, phi0])
class CombinedModel(
    PhysicsModelBase,
    GradientCapabilityMixin,
    BenchmarkingMixin,
    OptimizationRecommendationMixin,
):
    """Combined diffusion + shear model for complete homodyne analysis.

    This is the full model used for laminar flow analysis with both
    anomalous diffusion and time-dependent shear.

    Parameters (7 total):
    - D₀, α, D_offset: Diffusion parameters
    - γ̇₀, β, γ̇_offset: Shear parameters
    - φ₀: Angular offset parameter

    For static analysis, only the first 3 diffusion parameters are used.

    Mixin capabilities:
    - GradientCapabilityMixin: gradient/Hessian access with backend selection
    - BenchmarkingMixin: performance benchmarking and accuracy validation
    - OptimizationRecommendationMixin: optimization guidance and model info
    """

    def __init__(self, analysis_mode: str = "laminar_flow"):
        """Initialize combined model.

        Args:
            analysis_mode: "static", "static_isotropic", "static_anisotropic"
                or "laminar_flow". Any mode whose name starts with "static"
                selects the 3-parameter diffusion-only model; every other
                value selects the 7-parameter laminar flow model.
        """
        self.analysis_mode = analysis_mode

        # Use the same static/laminar test as compute_g1, the bounds and the
        # defaults, so parameter names can never disagree with them.
        if self._uses_static_mode:
            # Static mode: only diffusion parameters
            parameter_names = ["D0", "alpha", "D_offset"]
            name = "static_diffusion"
        else:
            # Laminar flow mode: all parameters
            parameter_names = [
                "D0",
                "alpha",
                "D_offset",
                "gamma_dot_t0",
                "beta",
                "gamma_dot_t_offset",
                "phi0",
            ]
            name = "laminar_flow_complete"

        super().__init__(name=name, parameter_names=parameter_names)

        # Create component models
        self.diffusion_model = DiffusionModel()
        self.shear_model = ShearModel()

    @property
    def _uses_static_mode(self) -> bool:
        """Whether ``analysis_mode`` selects the diffusion-only model.

        Single source of truth for the static/laminar decision. Non-string
        modes are treated as laminar flow.
        """
        mode = self.analysis_mode
        return isinstance(mode, str) and mode.startswith("static")

    @log_calls(include_args=False)
    def compute_g1(
        self,
        params: jnp.ndarray,
        t1: jnp.ndarray,
        t2: jnp.ndarray,
        phi: jnp.ndarray,
        q: float,
        L: float,
        dt: float | None = None,
    ) -> jnp.ndarray:
        """Compute total g1 = g1_diffusion × g1_shear.

        In static mode only the diffusion term is evaluated (``phi`` and
        ``L`` are unused); otherwise the full backend model is called.

        Raises:
            ValueError, RuntimeError, ArithmeticError: Re-raised unchanged
                from the laminar-flow backend call after being logged.
        """
        # Parameter validation is deliberately not run here: boolean checks
        # on JAX tracers fail inside JIT. ``q`` is forwarded unconverted for
        # the same reason; the backend handles any conversion it needs.
        debug_on = logger.isEnabledFor(10)  # 10 == logging.DEBUG

        if self._uses_static_mode:
            # Static mode: only diffusion, no shear
            if debug_on:
                logger.debug(
                    "CombinedModel.compute_g1: calling compute_g1_diffusion with params.shape=%s",
                    params.shape,
                )
            return compute_g1_diffusion(params, t1, t2, q, dt)

        # Laminar flow mode: full model
        if debug_on:
            logger.debug(
                "CombinedModel.compute_g1: calling compute_g1_total with "
                "params.shape=%s, t1.shape=%s, t2.shape=%s, phi.shape=%s, "
                "q=%s, L=%s, dt=%s",
                params.shape,
                t1.shape,
                t2.shape,
                phi.shape,
                q,
                L,
                dt,
            )
        try:
            result = compute_g1_total(params, t1, t2, phi, q, L, dt)
            # Note: value logging is skipped when traced by JAX
            # (jax.vmap/jit creates tracer objects that can't be formatted).
            if debug_on:
                try:
                    # Use nanmin/nanmax: g1 result may contain NaN from failed shards.
                    min_val = float(jnp.nanmin(result))
                    max_val = float(jnp.nanmax(result))
                    logger.debug(
                        "CombinedModel.compute_g1: compute_g1_total completed, "
                        f"result.shape={result.shape}, min={min_val:.6e}, max={max_val:.6e}",
                    )
                except (TypeError, ValueError):
                    # Likely a JAX tracer object during tracing
                    logger.debug(
                        f"CombinedModel.compute_g1: compute_g1_total completed, result.shape={result.shape}",
                    )
            return result
        # P2-R6-07: Narrow broad except — realistic failures from compute_g1_total
        # are ValueError (bad params), RuntimeError (XLA), or ArithmeticError.
        # Bare raise preserves the original traceback for all exception types.
        except (ValueError, RuntimeError, ArithmeticError) as e:
            logger.error(
                f"CombinedModel.compute_g1: compute_g1_total failed with error: {e}",
            )
            logger.error("CombinedModel.compute_g1: traceback:", exc_info=True)
            raise

    def compute_g1_batch(
        self,
        params: jnp.ndarray,
        t1_batch: jnp.ndarray,
        t2_batch: jnp.ndarray,
        phi_batch: jnp.ndarray,
        q: float,
        L: float,
        dt: float | None = None,
    ) -> jnp.ndarray:
        """Compute g1 for a batch of points using vmap.

        Performance Optimization (Spec 001 - FR-006, T041):
        Vectorized computation using jax.vmap for batched point-wise g1
        calculation, replacing Python loops.

        Parameters
        ----------
        params : jnp.ndarray
            Physical parameters array
        t1_batch : jnp.ndarray
            Batch of t1 values, shape (n_points,)
        t2_batch : jnp.ndarray
            Batch of t2 values, shape (n_points,)
        phi_batch : jnp.ndarray
            Batch of phi values, shape (n_points,)
        q : float
            Scattering wave vector magnitude [Å⁻¹]
        L : float
            Gap/characteristic length (stator_rotor_gap) [Å]
        dt : float, optional
            Time step from configuration [s]

        Returns
        -------
        jnp.ndarray
            Batch of g1 values, shape (n_points,)
        """
        import jax

        # Build the vmap'd function once and keep it on the instance, so
        # repeated calls reuse the same function object.
        if not hasattr(self, "_cached_g1_vmap"):

            def _compute_g1_single(
                params_inner: Any,
                t1_val: Any,
                t2_val: Any,
                phi_val: Any,
                q_inner: Any,
                L_inner: Any,
                dt_inner: Any,
            ) -> Any:
                # Evaluate one (t1, t2, phi) point by wrapping each scalar in
                # a length-1 array and unwrapping the single result.
                g1 = self.compute_g1(
                    params_inner,
                    jnp.array([t1_val]),
                    jnp.array([t2_val]),
                    jnp.array([phi_val]),
                    q_inner,
                    L_inner,
                    dt_inner,
                )
                return g1.flatten()[0]

            # Only the three per-point arrays are mapped over; params, q, L
            # and dt are shared by every point.
            self._cached_g1_vmap = jax.vmap(
                _compute_g1_single,
                in_axes=(None, 0, 0, 0, None, None, None),
            )

        result: jnp.ndarray = self._cached_g1_vmap(
            params, t1_batch, t2_batch, phi_batch, q, L, dt
        )
        return result

    @log_calls(include_args=False)
    def compute_g2(
        self,
        params: jnp.ndarray,
        t1: jnp.ndarray,
        t2: jnp.ndarray,
        phi: jnp.ndarray,
        q: float,
        L: float,
        contrast: float,
        offset: float,
        dt: float,
    ) -> jnp.ndarray:
        """Compute g2 with scaled fitting: g₂ = offset + contrast × [g₁]²

        Parameters
        ----------
        params : jnp.ndarray
            Physical parameters array
        t1, t2 : jnp.ndarray
            Time grids for correlation calculation
        phi : jnp.ndarray
            Scattering angles in degrees
        q : float
            Scattering wave vector magnitude [Å⁻¹]
        L : float
            Gap/characteristic length (stator_rotor_gap) [Å]
        contrast : float
            Contrast parameter (β in literature)
        offset : float
            Baseline offset
        dt : float
            Time step from configuration [s] (REQUIRED).
            Fallback estimation has been removed for safety.

        Returns
        -------
        jnp.ndarray
            g2 correlation function

        Raises
        ------
        TypeError
            If dt is None (no longer accepts None)
        ValueError
            Presumably raised by the backend when dt <= 0 or is not finite;
            this method itself only checks for None. TODO confirm.
        """
        # Validate dt before passing to backend
        if dt is None:
            raise TypeError(
                "dt parameter is required and cannot be None. "
                "Pass dt explicitly from configuration.",
            )

        # Pass to functional backend; any further validation happens there.
        return compute_g2_scaled(params, t1, t2, phi, q, L, contrast, offset, dt)

    @log_calls(include_args=False)
    def compute_chi_squared(
        self,
        params: jnp.ndarray,
        data: jnp.ndarray,
        sigma: jnp.ndarray,
        t1: jnp.ndarray,
        t2: jnp.ndarray,
        phi: jnp.ndarray,
        q: float,
        L: float,
        contrast: float,
        offset: float,
    ) -> float:
        """Compute chi-squared goodness of fit.

        Thin wrapper that forwards every argument, in order, to the backend
        ``compute_chi_squared`` function.
        """
        # NOTE(review): no dt is forwarded here although compute_g2 treats dt
        # as required - confirm against the backend signature.
        result: float = compute_chi_squared(
            params,
            data,
            sigma,
            t1,
            t2,
            phi,
            q,
            L,
            contrast,
            offset,
        )
        return result

    def get_parameter_bounds(self) -> list[tuple[float, float]]:
        """Get bounds appropriate for analysis mode."""
        bounds = self.diffusion_model.get_parameter_bounds()
        if not self._uses_static_mode:
            # Add shear parameter bounds for laminar flow
            bounds.extend(self.shear_model.get_parameter_bounds())
        return bounds

    def get_default_parameters(self) -> jnp.ndarray:
        """Get default parameters appropriate for analysis mode."""
        defaults = self.diffusion_model.get_default_parameters()
        if not self._uses_static_mode:
            # Add shear parameter defaults for laminar flow
            shear_defaults = self.shear_model.get_default_parameters()
            defaults = jnp.concatenate([defaults, shear_defaults])
        return defaults

    # Mixin methods are inherited from:
    # - GradientCapabilityMixin: get_gradient_function, get_hessian_function,
    #   supports_gradients, get_best_gradient_method, get_gradient_capabilities
    # - BenchmarkingMixin: benchmark_gradient_performance, validate_gradient_accuracy
    # - OptimizationRecommendationMixin: get_optimization_recommendations, get_model_info


# Factory functions for easy model creation
def create_model(analysis_mode: str) -> CombinedModel:
    """Factory function to create appropriate model for analysis mode.

    Args:
        analysis_mode: One of the names returned by
            ``get_available_models()``: "static", "laminar_flow",
            "static_isotropic" or "static_anisotropic".

    Returns:
        Configured CombinedModel instance

    Raises:
        ValueError: If ``analysis_mode`` is not an available mode.
    """
    # Reuse the public mode list so the accepted values are defined once.
    valid_modes = get_available_models()
    if analysis_mode not in valid_modes:
        raise ValueError(
            f"Invalid analysis mode '{analysis_mode}'. Must be one of {valid_modes}",
        )

    logger.info("Creating model for analysis mode: %s", analysis_mode)
    return CombinedModel(analysis_mode=analysis_mode)
def get_available_models() -> list[str]:
    """Return the names of the available analysis modes.

    A new list is built on every call, so callers may modify the result.
    """
    static_variants = ["static_isotropic", "static_anisotropic"]
    return ["static", "laminar_flow", *static_variants]
# Export main classes and functions __all__ = [ "PhysicsModelBase", "DiffusionModel", "ShearModel", "CombinedModel", "create_model", "get_available_models", ]