"""HomodyneModel - Hybrid Architecture Wrapper
============================================
Hybrid architecture combining stateful robustness with functional JIT performance.
This module implements the long-term architectural recommendation from the
architectural comparison analysis, providing:
1. **Stateful Storage**: Configuration validated once, stored in instance
2. **Pre-computed Factors**: Physics factors computed once at initialization
3. **High-level API**: Simple methods using stored configuration
4. **JIT Performance**: Calls functional cores for optimal performance
Best of both worlds: Robustness + Performance
Usage Example
-------------
>>> from homodyne.core.homodyne_model import HomodyneModel
>>>
>>> # Create model from configuration
>>> config = load_config("config.yaml")
>>> model = HomodyneModel(config)
>>>
>>> # Compute C2 - NO dt parameter needed!
>>> params = np.array([100.0, 0.0, 10.0, 1e-4, 0.0, 0.0, 0.0])
>>> phi_angles = np.array([0, 30, 45, 60, 90])
>>> c2 = model.compute_c2(params, phi_angles)
>>>
>>> # Or use convenience method
>>> model.plot_simulated_data(params, phi_angles, output_dir="./results")
"""
from pathlib import Path
import numpy as np
from homodyne.core.jax_backend import compute_g2_scaled_with_factors, jnp
from homodyne.core.models import CombinedModel
from homodyne.core.physics_factors import create_physics_factors_from_config_dict
from homodyne.utils.logging import get_logger
# Module-level logger obtained from the project's get_logger helper, keyed by
# this module's import name.
logger = get_logger(__name__)
class HomodyneModel:
    """Hybrid architecture wrapper for homodyne XPCS analysis.

    This class combines the robustness of stateful object-oriented design
    with the performance of functional JAX programming. It:

    1. Stores configuration (dt, q, L) as instance state
    2. Pre-computes physics factors once at initialization
    3. Provides high-level methods that use stored state
    4. Calls JIT-compiled functional cores for performance

    Benefits
    --------
    - **Robustness**: Configuration checked once at initialization
    - **Performance**: Physics factors pre-computed, JIT-compiled cores
    - **Usability**: Simple API, no dt parameter passing needed
    - **Safety**: No dt estimation errors possible
    - **Efficiency**: Factors computed once, reused for all calculations

    Attributes
    ----------
    physics_factors : PhysicsFactors
        Pre-computed physics factors (q²dt/2, qLdt/2π)
    time_array : jnp.ndarray
        Time array for correlation calculations [s]; starts at 0 and has
        ``end_frame - start_frame + 1`` points spaced by ``dt``
    t1_grid, t2_grid : jnp.ndarray
        2D time grids for correlation matrices (``indexing="ij"``)
    model : homodyne.core.models.CombinedModel
        Underlying physics model (for backward compatibility)
    dt : float
        Time step [s]
    start_frame, end_frame : int
        Inclusive frame range taken from the configuration
    wavevector_q : float
        Scattering wave vector magnitude [Å⁻¹]
    stator_rotor_gap : float
        Stator-rotor gap read from ``geometry.stator_rotor_gap`` [Å]
    analysis_mode : str
        Analysis mode: "static_isotropic", "static_anisotropic" or
        "laminar_flow"

    Examples
    --------
    Basic usage:

    >>> model = HomodyneModel(config)
    >>> c2 = model.compute_c2(params, phi_angles)

    With plotting:

    >>> model.plot_simulated_data(params, phi_angles, output_dir="./results")

    Access configuration:

    >>> print(model.config_summary)
    >>> print(f"dt = {model.dt} s")
    >>> print(f"Pre-computed factors: {model.physics_factors}")
    """

    def __init__(self, config: dict):
        """Initialize HomodyneModel from configuration dictionary.

        Parameters
        ----------
        config : dict
            Homodyne configuration dictionary with structure::

                {
                    'analyzer_parameters': {
                        'temporal': {'dt': float, 'start_frame': int, 'end_frame': int},
                        'scattering': {'wavevector_q': float},
                        'geometry': {'stator_rotor_gap': float}
                    },
                    'analysis_settings': {...}  # Optional
                }

        Raises
        ------
        KeyError
            If required configuration keys are missing
        ValueError
            If ``end_frame`` is still the negative "use all frames" sentinel,
            or if the frame range contains no frames
        """
        logger.info("Initializing HomodyneModel with hybrid architecture")

        # Pull the required values out of the nested configuration.
        self._extract_config(config)

        # Pre-compute physics factors ONCE
        self.physics_factors = create_physics_factors_from_config_dict(config)
        logger.info(f"Pre-computed physics factors: {self.physics_factors}")

        # Reject an unresolved sentinel or an empty/inverted frame range before
        # it can turn into an empty or failing linspace call below.
        n_time = self._count_time_points(self.start_frame, self.end_frame)

        # Create time array
        t_max = self.dt * (n_time - 1)
        self.time_array = jnp.linspace(0, t_max, n_time, dtype=jnp.float64)

        # Create time grids for correlation calculations
        self.t1_grid, self.t2_grid = jnp.meshgrid(
            self.time_array,
            self.time_array,
            indexing="ij",
        )
        logger.debug(
            f"Time array: n={n_time}, range=[0, {self.dt * (n_time - 1):.2f}] s",
        )

        # Create underlying model (for backward compatibility)
        self.model = CombinedModel(analysis_mode=self.analysis_mode)

        logger.info("HomodyneModel initialized successfully")
        logger.info(f" Analysis mode: {self.analysis_mode}")
        logger.info(f" Time points: {n_time}")
        logger.info(f" dt: {self.dt} s")

    @staticmethod
    def _count_time_points(start_frame: int, end_frame: int) -> int:
        """Return the number of frames in the inclusive range, validating it.

        Raises
        ------
        ValueError
            If ``end_frame`` is negative (unresolved sentinel) or smaller than
            ``start_frame``.
        """
        if end_frame < 0:
            raise ValueError(
                f"end_frame={end_frame} is a sentinel value and must be resolved "
                f"to a concrete frame index before constructing HomodyneModel. "
                f"Use XPCSDataLoader to resolve this value from the HDF5 file."
            )
        n_time = end_frame - start_frame + 1
        if n_time < 1:
            raise ValueError(
                f"Invalid frame range: start_frame={start_frame}, "
                f"end_frame={end_frame}. end_frame must be greater than or "
                f"equal to start_frame."
            )
        return n_time

    @staticmethod
    def _format_value_range(values: np.ndarray) -> str:
        """Format the NaN-ignoring min/max of *values* for log messages.

        Returns ``"[empty]"`` for zero-size input, where ``np.nanmin`` would
        raise instead of returning a value.
        """
        if np.size(values) == 0:
            return "[empty]"
        low = float(np.nanmin(values))
        high = float(np.nanmax(values))
        return f"[{low:.4f}, {high:.4f}]"

    def compute_c2(
        self,
        params: np.ndarray,
        phi_angles: np.ndarray,
        contrast: float = 0.5,
        offset: float = 1.0,
    ) -> np.ndarray:
        """Compute C2 correlation function using stored configuration.

        This high-level method:

        - Uses pre-computed time grids (self.t1_grid, self.t2_grid)
        - Uses pre-computed physics factors (self.physics_factors)
        - Calls JIT-compiled functional core for performance
        - Returns C2 for all phi angles

        NO dt parameter needed - uses stored configuration!

        Parameters
        ----------
        params : np.ndarray
            Physical parameters:

            - For laminar_flow (7 params): [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
            - For static (3 params): [D0, alpha, D_offset]
        phi_angles : np.ndarray
            Scattering angles [degrees], shape (n_phi,)
        contrast : float, default=0.5
            Contrast parameter (β in literature)
        offset : float, default=1.0
            Baseline offset

        Returns
        -------
        np.ndarray
            C2 correlation matrices, shape (n_phi, n_time, n_time)

        Examples
        --------
        >>> model = HomodyneModel(config)
        >>> params = np.array([100.0, 0.0, 10.0, 1e-4, 0.0, 0.0, 0.0])
        >>> phi_angles = np.array([0, 30, 45, 60, 90])
        >>> c2 = model.compute_c2(params, phi_angles)
        >>> print(c2.shape)  # (5, 100, 100) for 5 angles, 100 time points
        """
        # Convert to JAX arrays
        params_jax = jnp.array(params)
        phi_angles_jax = jnp.array(phi_angles)

        # Extract pre-computed factors
        q_factor, sinc_factor = self.physics_factors.to_tuple()

        # Single vectorized call: pass all phi angles at once.
        # _compute_g1_shear_core handles phi arrays in matrix mode via vmap,
        # returning shape (n_phi, n_times, n_times) — no Python loop needed.
        result = compute_g2_scaled_with_factors(
            params_jax,
            self.t1_grid,
            self.t2_grid,
            phi_angles_jax,
            q_factor,  # Pre-computed at init
            sinc_factor,  # Pre-computed at init
            contrast,
            offset,
            self.dt,  # Time step from experimental configuration
        )

        # Copy to a NumPy array exactly once; both the log statistics and the
        # return value reuse this host copy instead of converting repeatedly.
        c2 = np.array(result)
        logger.debug(
            f"Computed C2 for {np.size(phi_angles)} angles, "
            f"shape: {c2.shape}, "
            f"range: {self._format_value_range(c2)}",
        )
        return c2

    def compute_c2_single_angle(
        self,
        params: np.ndarray,
        phi: float,
        contrast: float = 0.5,
        offset: float = 1.0,
    ) -> np.ndarray:
        """Compute C2 correlation function for a single angle.

        Convenience method for single-angle calculations; delegates to
        :meth:`compute_c2` with a one-element angle array.

        Parameters
        ----------
        params : np.ndarray
            Physical parameters
        phi : float
            Scattering angle [degrees]
        contrast : float, default=0.5
            Contrast parameter
        offset : float, default=1.0
            Baseline offset

        Returns
        -------
        np.ndarray
            C2 correlation matrix, shape (n_time, n_time)
        """
        c2 = self.compute_c2(params, np.array([phi]), contrast, offset)
        result: np.ndarray = c2[0]
        return result

    def plot_simulated_data(
        self,
        params: np.ndarray,
        phi_angles: np.ndarray,
        output_dir: str = "./simulated_data",
        contrast: float = 0.5,
        offset: float = 1.0,
        generate_plots: bool = True,
    ) -> tuple[np.ndarray, Path]:
        """Generate and optionally plot simulated C2 data.

        This convenience method:

        1. Computes C2 using stored configuration
        2. Saves data to a compressed NumPy ``.npz`` file
        3. Optionally generates heatmap plots for each angle
        4. Returns both data and output path

        Plotting is best-effort: a plotting failure is logged as a warning and
        does not prevent the data file from being returned.

        Parameters
        ----------
        params : np.ndarray
            Physical parameters
        phi_angles : np.ndarray
            Scattering angles [degrees]
        output_dir : str, default="./simulated_data"
            Output directory for plots and data (created if missing)
        contrast : float, default=0.5
            Contrast parameter
        offset : float, default=1.0
            Baseline offset
        generate_plots : bool, default=True
            Whether to generate heatmap plots

        Returns
        -------
        tuple of (np.ndarray, Path)
            (c2_data, output_path)

            - c2_data: Computed correlation matrices
            - output_path: Path to saved data file

        Examples
        --------
        >>> model = HomodyneModel(config)
        >>> c2_data, output_path = model.plot_simulated_data(
        ...     params, phi_angles, output_dir="./results"
        ... )
        >>> print(f"Data saved to: {output_path}")
        """
        # Create output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Compute C2
        logger.info(f"Computing C2 for {len(phi_angles)} angles...")
        c2_data = self.compute_c2(params, phi_angles, contrast, offset)

        # Save data; physics factors are stored under "pf_"-prefixed keys.
        data_file = output_path / "c2_simulated_data.npz"
        np.savez_compressed(
            data_file,
            c2_data=c2_data,
            phi_angles=phi_angles,
            time_array=np.array(self.time_array),
            params=params,
            contrast=contrast,
            offset=offset,
            dt=self.dt,
            **{f"pf_{k}": v for k, v in self.physics_factors.to_dict().items()},
        )
        logger.info(f"Saved data to: {data_file}")

        # Generate plots if requested
        if generate_plots:
            try:
                self._generate_heatmap_plots(
                    c2_data,
                    phi_angles,
                    output_path,
                    contrast,
                    offset,
                )
            except Exception as e:
                # Deliberate best-effort: the data file is already on disk.
                logger.warning(f"Failed to generate plots: {e}")
                logger.warning("Data was saved successfully, continuing...")

        return c2_data, data_file

    def _generate_heatmap_plots(
        self,
        c2_data: np.ndarray,
        phi_angles: np.ndarray,
        output_dir: Path,
        contrast: float,
        offset: float,
    ) -> None:
        """Write one PNG heatmap of C2 per angle into *output_dir*.

        Silently skips plotting (with a warning) when matplotlib cannot be
        imported. Each figure is closed even if rendering or saving fails, so
        a failure does not leave open figures behind.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning("matplotlib not available, skipping plots")
            return

        logger.info(f"Generating heatmap plots for {len(phi_angles)} angles...")

        # The axes extent depends only on the stored time array, so compute it
        # once instead of once per angle.
        t_first = float(self.time_array[0])
        t_last = float(self.time_array[-1])
        extent = (t_first, t_last, t_first, t_last)

        for i, phi in enumerate(phi_angles):
            c2_matrix = c2_data[i]

            # Create figure
            fig, ax = plt.subplots(figsize=(8, 6))
            try:
                # Create heatmap with fixed color scale [1.0, 1.5]
                im = ax.imshow(
                    c2_matrix,
                    aspect="equal",
                    origin="lower",
                    extent=extent,
                    cmap="jet",
                    vmin=1.0,
                    vmax=1.5,
                )

                # Add colorbar
                cbar = fig.colorbar(im, ax=ax)
                cbar.set_label("C₂(t₁, t₂)", fontsize=12)

                # Set labels and title
                ax.set_xlabel("t₁ (s)", fontsize=12)
                ax.set_ylabel("t₂ (s)", fontsize=12)
                ax.set_title(
                    f"Simulated C₂ Correlation Function (φ = {phi:.1f}°)\n"
                    f"contrast={contrast}, offset={offset}",
                    fontsize=14,
                )

                # Save plot.
                # NOTE: angles that format identically at one decimal place
                # share a filename, so a later plot overwrites an earlier one.
                filename = f"c2_simulated_phi_{phi:.1f}deg.png"
                filepath = output_dir / filename
                fig.tight_layout()
                fig.savefig(filepath, dpi=300, bbox_inches="tight")
            finally:
                # Always release the figure, including on errors.
                plt.close(fig)
            logger.debug(f" Saved: {filename}")

        logger.info(f"Generated {len(phi_angles)} heatmap plots")

    def _extract_config(self, config: dict) -> None:
        """Copy required configuration values onto the instance.

        Sets ``dt``, ``start_frame``, ``end_frame``, ``wavevector_q``,
        ``stator_rotor_gap`` and ``analysis_mode``. Only the presence of the
        required keys is checked here; values are stored as given.

        Raises
        ------
        KeyError
            If a required key is missing from ``config``.
        """
        try:
            analyzer_params = config["analyzer_parameters"]
            temporal = analyzer_params["temporal"]

            # Temporal parameters
            self.dt = temporal["dt"]
            self.start_frame = temporal["start_frame"]
            self.end_frame = temporal["end_frame"]

            # Physical parameters
            self.wavevector_q = analyzer_params["scattering"]["wavevector_q"]
            self.stator_rotor_gap = analyzer_params["geometry"]["stator_rotor_gap"]
        except KeyError as e:
            raise KeyError(
                f"Missing required configuration key: {e}. "
                f"Expected structure: config['analyzer_parameters'][...]",
            ) from e

        # Analysis mode (uses only optional keys, so it cannot raise KeyError)
        self.analysis_mode = self._determine_analysis_mode(config)

    def _determine_analysis_mode(self, config: dict) -> str:
        """Determine analysis mode from configuration.

        Precedence:

        1. A non-empty ``analysis_settings`` mapping decides on its own:
           ``static_mode`` selects a static mode (isotropic when
           ``isotropic_mode`` is also truthy), otherwise "laminar_flow".
           In this case a top-level ``analysis_mode`` entry is ignored.
        2. Otherwise the top-level ``analysis_mode`` string is inspected
           case-insensitively.
        3. Anything else falls back to "laminar_flow".

        Returns
        -------
        str
            "static_isotropic", "static_anisotropic" or "laminar_flow"
        """
        analysis_settings = config.get("analysis_settings", {})
        if analysis_settings:
            is_static = bool(analysis_settings.get("static_mode", False))
            is_isotropic = bool(analysis_settings.get("isotropic_mode", False))
            if is_static:
                return "static_isotropic" if is_isotropic else "static_anisotropic"
            return "laminar_flow"

        mode = config.get("analysis_mode")
        if mode:
            mode_lower = str(mode).lower()
            if "static" in mode_lower:
                return (
                    "static_isotropic"
                    if "isotropic" in mode_lower
                    else "static_anisotropic"
                )
            if mode_lower in {"laminar", "laminar_flow"}:
                return "laminar_flow"

        # Unrecognized or missing mode: default to laminar flow.
        return "laminar_flow"

    @property
    def config_summary(self) -> dict:
        """Get configuration summary for logging/debugging.

        Returns
        -------
        dict
            Configuration summary with all key parameters
        """
        return {
            "dt": self.dt,
            "time_length": len(self.time_array),
            "time_range": [0, self.dt * (len(self.time_array) - 1)],
            "wavevector_q": self.wavevector_q,
            "stator_rotor_gap": self.stator_rotor_gap,
            "analysis_mode": self.analysis_mode,
            "physics_factors": self.physics_factors.to_dict(),
            "start_frame": self.start_frame,
            "end_frame": self.end_frame,
        }

    def __repr__(self) -> str:
        """String representation of HomodyneModel."""
        return (
            f"HomodyneModel(\n"
            f" analysis_mode='{self.analysis_mode}',\n"
            f" dt={self.dt} s,\n"
            f" time_points={len(self.time_array)},\n"
            f" q={self.wavevector_q} AA^-1,\n"
            f" L={self.stator_rotor_gap} AA\n"
            f")"
        )
# Public API of this module: a star-import exposes only HomodyneModel.
__all__ = ["HomodyneModel"]