"""XPCS Data Loader for Homodyne
================================
Enhanced XPCS data loader supporting both APS (old) and APS-U (new) HDF5 formats
with YAML-first configuration system, JAX compatibility, and modern architecture integration.
This module provides:
- YAML-first configuration with JSON support
- Smart NPZ caching to avoid reloading large HDF5 files
- Auto-detection of APS vs APS-U format
- Half-matrix reconstruction for correlation matrices
- Mandatory diagonal correction applied post-load
- JAX array output with numpy fallback
- Integration with v2 logging and physics validation
Key Features:
- Format Support: APS old format and APS-U new format
- Configuration: YAML primary, JSON via converter
- Caching: Intelligent NPZ caching with compression
- Output: JAX arrays when available, numpy fallback
- Validation: Optional physics-based data quality checks
"""
from __future__ import annotations
import json
import os
import re
import string
import time
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
# Handle optional dependencies with graceful fallback
if TYPE_CHECKING:
from numpy.typing import NDArray
else:
NDArray = Any
try:
import numpy as np
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
np = None # type: ignore[assignment]
try:
import h5py
HAS_H5PY = True
except ImportError:
HAS_H5PY = False
h5py = None
try:
import yaml
HAS_YAML = True
except ImportError:
HAS_YAML = False
yaml = None # type: ignore[assignment]
# JAX integration
try:
import jax.numpy as jnp
from homodyne.core.jax_backend import jax_available
HAS_JAX = True
except ImportError:
HAS_JAX = False
jax_available = False
jnp = np # type: ignore[misc]
# V2 system integration
try:
    # The v2 helpers are imported under private aliases and then re-bound to
    # the public names used throughout this module.
    from homodyne.utils.logging import (
        get_logger as _get_logger,
        log_calls as _log_calls,
        log_performance as _log_performance,
        log_phase as _log_phase,
    )

    HAS_V2_LOGGING = True
    get_logger = _get_logger
    log_performance = _log_performance
    log_calls = _log_calls
    log_phase = _log_phase
except ImportError:
    # Fallback to standard logging if v2 logging not available
    import logging
    from collections.abc import Iterator
    from contextlib import contextmanager

    HAS_V2_LOGGING = False
    F = TypeVar("F", bound=Callable[..., Any])

    def get_logger(name: str | None = None, **kwargs: Any) -> logging.Logger:
        """Return the standard-library logger for *name*; extra keywords are ignored."""
        return logging.getLogger(name)

    def log_performance(*args: Any, **kwargs: Any) -> Callable[[F], F]:
        """Return a decorator that hands the decorated function back unchanged."""

        def decorator(func: F) -> F:
            return func

        return decorator

    def log_calls(*args: Any, **kwargs: Any) -> Callable[[F], F]:
        """Return a decorator that hands the decorated function back unchanged."""

        def decorator(func: F) -> F:
            return func

        return decorator

    @contextmanager
    def log_phase(name: str, **kwargs: Any) -> Iterator[Any]:
        """Fallback log_phase for environments without v2 logging.

        Yields a placeholder object exposing ``duration`` (0.0) and
        ``memory_peak_gb`` (None) so callers can read those attributes.
        """
        phase_type = type(
            "PhaseContext",
            (),
            {"duration": 0.0, "memory_peak_gb": None},
        )
        yield phase_type()
# Physics validation integration
try:
from homodyne.core.physics import (
PhysicsConstants as _PhysicsConstants,
)
from homodyne.core.physics import (
validate_experimental_setup as _validate_experimental_setup,
)
HAS_PHYSICS_VALIDATION = True
PhysicsConstants = _PhysicsConstants
validate_experimental_setup = _validate_experimental_setup
except ImportError:
HAS_PHYSICS_VALIDATION = False
PhysicsConstants = None # type: ignore
validate_experimental_setup = None # type: ignore
# Diagonal correction from unified module
try:
from homodyne.core.diagonal_correction import (
apply_diagonal_correction_batch as _apply_diagonal_correction_batch,
)
HAS_DIAGONAL_CORRECTION = True
apply_diagonal_correction_batch = _apply_diagonal_correction_batch
except ImportError:
HAS_DIAGONAL_CORRECTION = False
apply_diagonal_correction_batch = None # type: ignore
# Performance engine integration
try:
from homodyne.data.memory_manager import (
AdvancedMemoryManager as _AdvancedMemoryManager,
)
from homodyne.data.optimization import (
AdvancedDatasetOptimizer as _AdvancedDatasetOptimizer,
)
from homodyne.data.performance_engine import PerformanceEngine as _PerformanceEngine
HAS_PERFORMANCE_ENGINE = True
PerformanceEngine = _PerformanceEngine
AdvancedMemoryManager = _AdvancedMemoryManager
AdvancedDatasetOptimizer = _AdvancedDatasetOptimizer
except ImportError:
HAS_PERFORMANCE_ENGINE = False
PerformanceEngine = None # type: ignore
AdvancedMemoryManager = None # type: ignore
AdvancedDatasetOptimizer = None # type: ignore
logger = get_logger(__name__)
# Regex to detect old str.format()-style placeholders: {var} or {var:.4f}
_OLD_FORMAT_RE = re.compile(r"\{(\w+)(?::[^}]*)?\}")
def _migrate_cache_template(template: str) -> str:
"""Auto-convert old {var} format templates to ${var} syntax.
Returns the template unchanged if it already uses $ syntax.
Logs a warning on first migration.
"""
if "$" not in template and _OLD_FORMAT_RE.search(template):
migrated = _OLD_FORMAT_RE.sub(r"${\1}", template)
logger.warning(
"Cache template uses deprecated {var} format; auto-migrated to ${var}. "
"Update your YAML config: %r -> %r",
template,
migrated,
)
return migrated
return template
[docs]
class XPCSDependencyError(Exception):
    """Signal that a package needed by the XPCS loader could not be imported."""
[docs]
class XPCSConfigurationError(Exception):
    """Signal an invalid, unreadable, or incomplete XPCS configuration."""
[docs]
def load_xpcs_config(config_path: str | Path) -> dict[str, Any]:
    """Load XPCS configuration from a YAML or JSON file.

    The format is chosen from the file suffix (case-insensitive): ``.yaml`` /
    ``.yml`` are parsed with ``yaml.safe_load``; ``.json`` is parsed with
    ``json.load`` and returned with its structure unchanged.

    Args:
        config_path: Path to YAML or JSON configuration file

    Returns:
        Configuration dictionary (the top-level mapping of the document)

    Raises:
        XPCSConfigurationError: If the file does not exist, has an
            unsupported suffix, cannot be parsed, or does not contain a
            mapping at the top level (e.g. an empty YAML file).
        XPCSDependencyError: If a YAML file is given but PyYAML is not
            installed.
    """
    config_path = Path(config_path)
    if not config_path.exists():
        raise XPCSConfigurationError(f"Configuration file not found: {config_path}")

    suffix = config_path.suffix.lower()
    config: Any

    if suffix in (".yaml", ".yml"):
        if not HAS_YAML:
            raise XPCSDependencyError(
                "PyYAML required for YAML configuration files",
            )
        # yaml.YAMLError is only referenced on this branch, where PyYAML is
        # known to be importable. A shared except clause naming it would fail
        # with AttributeError whenever `yaml` is None.
        try:
            with open(config_path, encoding="utf-8") as f:
                config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise XPCSConfigurationError(
                f"Failed to parse configuration file {config_path}: {e}",
            ) from e
        logger.info(f"Loaded YAML configuration: {config_path}")
    elif suffix == ".json":
        try:
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        except json.JSONDecodeError as e:
            raise XPCSConfigurationError(
                f"Failed to parse configuration file {config_path}: {e}",
            ) from e
        logger.info(f"Loaded JSON configuration (converted to YAML): {config_path}")
        logger.info("Consider migrating to YAML format for better readability")
        # JSON structure is currently kept identical; no conversion is applied.
    else:
        raise XPCSConfigurationError(
            f"Unsupported configuration format: {config_path.suffix}. "
            f"Supported formats: .yaml, .yml, .json",
        )

    # An empty YAML document parses to None, and either format may hold a
    # list or scalar at the top level. Reject those here with a clear message
    # instead of failing later on the first dictionary access.
    if not isinstance(config, dict):
        raise XPCSConfigurationError(
            f"Configuration file {config_path} must contain a mapping at the "
            f"top level, got {type(config).__name__}",
        )
    return config
[docs]
class XPCSDataLoader:
"""Enhanced XPCS data loader for Homodyne.
Supports both APS (old) and APS-U (new) formats with YAML-first configuration,
intelligent caching, and JAX integration.
Features:
- YAML-first configuration with JSON support
- Auto-detection of HDF5 format (APS vs APS-U)
- Smart NPZ caching with compression
- Half-matrix reconstruction for correlation matrices
- Mandatory diagonal correction applied consistently
- JAX array output when available
- Integration with v2 physics validation
"""
[docs]
@log_calls(include_args=False)
def __init__(
    self,
    config_path: str | None = None,
    config_dict: dict | None = None,
    configure_logging: bool = True,
    generate_quality_reports: bool = False,  # Only generate reports when explicitly requested
):
    """Build a loader from exactly one configuration source.

    Args:
        config_path: Path to YAML or JSON configuration file
        config_dict: Configuration dictionary (alternative to config_path).
            The dictionary is stored as-is and normalized in place.
        configure_logging: Accepted for API compatibility; not referenced in
            this method body.
        generate_quality_reports: Whether to generate quality reports (default: False)

    Raises:
        XPCSDependencyError: If required dependencies are not available
        XPCSConfigurationError: If configuration is invalid
        ValueError: If both or neither of config_path / config_dict are
            given (an empty value counts as not given).
    """
    # Dependencies are verified before anything else is touched.
    self._check_dependencies()

    # Quality reports are opt-in (only for --plot-experimental-data).
    self.generate_quality_reports = generate_quality_reports

    if config_path and config_dict:
        raise ValueError("Provide either config_path or config_dict, not both")
    if not (config_path or config_dict):
        raise ValueError("Must provide either config_path or config_dict")
    self.config = load_xpcs_config(config_path) if config_path else config_dict

    # Flat -> nested transformation for backward compatibility, followed by
    # the v2 defaults.
    self._normalize_config_structure()
    self._process_v2_config_enhancements()

    # Shortcuts to the main configuration sections.
    config = self.config
    self.exp_config = config.get("experimental_data", {})
    self.analyzer_config = config.get("analyzer_parameters", {})
    self.v2_config = config.get("v2_features", {})

    # Performance components first, then validation of the final config.
    self._init_performance_components()
    self._validate_configuration()

    logger.info(
        f"XPCS data loader initialized with {len(self.config)} config sections",
    )
def _check_dependencies(self) -> None:
    """Verify that numpy and h5py were importable; raise otherwise.

    Raises:
        XPCSDependencyError: If either package is flagged as unavailable.
            The message names the missing packages and a matching
            ``pip install`` command, and is also logged at error level.
    """
    availability = (("numpy", HAS_NUMPY), ("h5py", HAS_H5PY))
    missing_deps = [package for package, present in availability if not present]
    if not missing_deps:
        return

    error_msg = (
        f"Missing required dependencies: {', '.join(missing_deps)}. "
        + "Please install them with: pip install "
        + " ".join(missing_deps)
    )
    logger.error(error_msg)
    raise XPCSDependencyError(error_msg)
def _normalize_config_structure(self) -> None:
    """Transform flat config structure to nested structure for backward compatibility.

    Detects flat structure (config with data_file at root level) and transforms it to
    nested structure (config with experimental_data, analyzer_parameters sections).
    The configuration dictionary is modified in place.

    Flat structure example:
        {
            "data_file": "/path/to/file.h5",
            "analysis_mode": "static_isotropic",
            "dt": 0.1,
            "start_frame": 1,
            "end_frame": -1,
        }

    Nested structure example:
        {
            "analysis_mode": "static_isotropic",
            "experimental_data": {
                "data_folder_path": "/path/to",
                "data_file_name": "file.h5",
            },
            "analyzer_parameters": {
                "dt": 0.1,
                "start_frame": 1,
                "end_frame": -1,
            },
        }

    Pre-existing ``analyzer_parameters`` / ``output`` mappings in a flat
    config are kept and extended rather than replaced. For each analyzer
    parameter the precedence is: root-level value, then a value already
    present in ``analyzer_parameters``, then the built-in default.
    """
    # Already nested: nothing to do.
    if "experimental_data" in self.config:
        return

    # Neither flat nor nested: leave it for validation to report.
    if "data_file" not in self.config:
        return

    data_file = self.config.pop("data_file")
    self.config["experimental_data"] = {
        # A bare file name has an empty dirname; use the current directory.
        "data_folder_path": os.path.dirname(data_file) or ".",
        "data_file_name": os.path.basename(data_file),
    }

    analyzer_defaults = {
        "dt": 0.1,  # Default time step (seconds)
        "start_frame": 1,  # Default start frame
        "end_frame": -1,  # Default end frame (-1 means all frames)
    }

    # Reuse an existing analyzer_parameters mapping so that keys other than
    # the three handled here (and any values already set) are not discarded.
    existing_analyzer = self.config.get("analyzer_parameters")
    analyzer_section = existing_analyzer if isinstance(existing_analyzer, dict) else {}
    self.config["analyzer_parameters"] = analyzer_section
    for param, default_value in analyzer_defaults.items():
        if param in self.config:
            analyzer_section[param] = self.config.pop(param)
        else:
            # Provide default for backward compatibility
            analyzer_section.setdefault(param, default_value)

    # Move output parameters to the output section if present, again
    # preserving an existing mapping.
    output_params = ["output_directory"]
    if any(param in self.config for param in output_params):
        existing_output = self.config.get("output")
        output_section = existing_output if isinstance(existing_output, dict) else {}
        self.config["output"] = output_section
        for param in output_params:
            if param in self.config:
                output_section[param] = self.config.pop(param)

    logger.debug("Transformed flat config structure to nested structure")
def _process_v2_config_enhancements(self) -> None:
"""Process v2 configuration enhancements and set defaults."""
if "v2_features" not in self.config:
self.config["v2_features"] = {}
v2_defaults = {
"output_format": "auto", # 'numpy', 'jax', 'auto'
"validation_level": "basic", # 'none', 'basic', 'full'
"performance_optimization": True,
"physics_validation": False,
"cache_strategy": "intelligent", # 'none', 'simple', 'intelligent'
}
for key, default_value in v2_defaults.items():
if key not in self.config["v2_features"]:
self.config["v2_features"][key] = default_value
# Add performance optimization defaults
performance_defaults = {
"performance_engine_enabled": True,
"memory_mapped_io": True,
"advanced_chunking": True,
"multi_level_caching": True,
"background_prefetching": True,
"memory_pressure_monitoring": True,
}
if "performance" not in self.config:
self.config["performance"] = {}
for key, default_value in performance_defaults.items():
if key not in self.config["performance"]:
self.config["performance"][key] = default_value
def _init_performance_components(self) -> None:
    """Create the optional performance engine, memory manager and optimizer.

    All three attributes start as ``None``. They stay ``None`` when the
    engine is disabled in the ``performance`` section, when the performance
    modules could not be imported, or when construction of any component
    raises (the failure is logged and all three are reset).
    """
    self.performance_engine = None
    self.memory_manager = None
    self.advanced_optimizer = None

    perf_cfg = self.config.get("performance", {})
    if not perf_cfg.get("performance_engine_enabled", True):
        logger.info("Performance engine disabled in configuration")
        return

    if not HAS_PERFORMANCE_ENGINE:
        logger.warning(
            "Performance engine not available - falling back to basic optimization",
        )
        return

    try:
        # The engine is known to be enabled here (see the guard above).
        self.performance_engine = PerformanceEngine(self.config)
        logger.info("Performance engine initialized")

        monitor_memory = perf_cfg.get("memory_pressure_monitoring", True)
        if monitor_memory:
            self.memory_manager = AdvancedMemoryManager(self.config)
            logger.info("Advanced memory manager initialized")

        # The optimizer receives whichever of the components exist so far.
        self.advanced_optimizer = AdvancedDatasetOptimizer(
            config=self.config,
            performance_engine=self.performance_engine,
            memory_manager=self.memory_manager,
        )
        logger.info("Advanced dataset optimizer initialized")
    except Exception as e:
        # Best effort: any construction failure disables all three components.
        logger.warning(f"Performance components initialization failed: {e}")
        logger.info("Falling back to basic optimization")
        self.performance_engine = None
        self.memory_manager = None
        self.advanced_optimizer = None
def _validate_configuration(self) -> None:
"""Validate configuration parameters."""
required_exp_data = ["data_folder_path", "data_file_name"]
required_analyzer = ["dt", "start_frame", "end_frame"]
for key in required_exp_data:
if key not in self.exp_config:
raise XPCSConfigurationError(
f"Missing required experimental_data parameter: {key}",
)
for key in required_analyzer:
if key not in self.analyzer_config:
raise XPCSConfigurationError(
f"Missing required analyzer_parameters parameter: {key}",
)
# Validate file existence
data_file_path = os.path.join(
self.exp_config["data_folder_path"],
self.exp_config["data_file_name"],
)
if ".." in str(data_file_path) or "\x00" in str(data_file_path):
raise ValueError(
f"Path traversal detected in data file path: {data_file_path}"
)
if not os.path.exists(data_file_path):
logger.warning(f"Data file not found: {data_file_path}")
logger.info("File will be checked again during data loading")
def _get_output_format(self) -> str:
"""Get output array format from configuration."""
format_val: Any = self.v2_config.get("output_format", "auto")
return str(format_val)
def _should_perform_validation(self) -> dict[str, bool]:
"""Get validation settings from configuration."""
validation_level = self.v2_config.get("validation_level", "basic")
return {
"physics_checks": self.v2_config.get("physics_validation", False)
and HAS_PHYSICS_VALIDATION,
"data_quality": validation_level != "none",
"comprehensive": validation_level == "full",
}
def _convert_arrays_to_target_format(
self,
data: dict[str, NDArray],
) -> dict[str, Any]:
"""Convert arrays to target format based on configuration.
Args:
data: Dictionary with numpy arrays
Returns:
Dictionary with arrays in target format (JAX or numpy)
"""
output_format = self._get_output_format()
if output_format == "jax" and HAS_JAX and jax_available:
logger.debug("Converting arrays to JAX format")
return {
k: jnp.asarray(v, dtype=jnp.float64) if isinstance(v, np.ndarray) else v
for k, v in data.items()
}
elif output_format == "auto" and HAS_JAX and jax_available:
logger.debug("Auto-selecting JAX format (available)")
return {
k: jnp.asarray(v, dtype=jnp.float64) if isinstance(v, np.ndarray) else v
for k, v in data.items()
}
elif output_format == "auto":
logger.debug("Auto-selecting numpy format (JAX not available)")
return data # Keep numpy format
[docs]
@log_performance(threshold=0.5)
def load_experimental_data(self) -> dict[str, Any]:
    """Load experimental data with priority: NPZ override → cache NPZ → raw HDF → error.

    Steps, in order:
    1. Build the cache file name from ``cache_filename_template`` and reject
       names containing a path separator or ``..``.
    2. Load from a directly configured ``.npz`` data file if it exists, else
       from the cache file (unless ``cache_strategy`` is ``"none"``), else
       from the raw HDF5 file. A fresh HDF5 load is written to the cache
       (when caching is enabled) and to text files.
    3. Run the optional quality-control stages and the preprocessing
       pipeline.
    4. Convert arrays to the target format and apply the diagonal
       correction to ``c2_exp``.
    5. Run final quality control, or legacy validation when no quality
       controller is active.

    Returns:
        Dictionary containing:
        - wavevector_q_list: Array of q values
        - phi_angles_list: Array of phi angles
        - t1: Time array for first dimension
        - t2: Time array for second dimension
        - c2_exp: Experimental correlation data

    Raises:
        ValueError: If the rendered cache file name is unsafe.
        FileNotFoundError: If neither the cache file nor the HDF5 file exists.
    """
    # Construct file paths
    data_folder = self.exp_config.get("data_folder_path", "./")
    data_file = self.exp_config.get("data_file_name", "")
    cache_folder = self.exp_config.get("cache_file_path", data_folder)

    # Get frame parameters
    # NOTE(review): the end_frame fallback here is 8000, while the flat-config
    # normalization uses -1. Validation requires the key, so this fallback is
    # presumably never hit; confirm.
    start_frame = self.analyzer_config.get("start_frame", 1)
    end_frame = self.analyzer_config.get("end_frame", 8000)

    # Construct cache filename (using string.Template for safety)
    cache_template = _migrate_cache_template(
        self.exp_config.get(
            "cache_filename_template",
            "cached_c2_frames_${start_frame}_${end_frame}.npz",
        )
    )

    # Get wavevector_q for cache filename (selective caching support)
    scattering_config = self.analyzer_config.get("scattering", {})
    wavevector_q = scattering_config.get("wavevector_q", 0.0054)

    tmpl = string.Template(cache_template)
    cache_filename = tmpl.safe_substitute(
        start_frame=start_frame,
        end_frame=end_frame,
        wavevector_q=f"{wavevector_q:.4f}",
    )

    # Reject every separator the platform understands. os.altsep is None on
    # POSIX and "/" on Windows, where checking os.sep alone would let a
    # forward slash through.
    path_separators = [sep for sep in (os.sep, os.altsep) if sep]
    has_separator = any(sep in cache_filename for sep in path_separators)
    if has_separator or ".." in cache_filename:
        raise ValueError(f"Unsafe cache filename from template: {cache_filename!r}")
    cache_path = os.path.join(cache_folder, cache_filename)

    caching_enabled = self.v2_config.get("cache_strategy", "intelligent") != "none"

    # If user provided a direct NPZ path, prefer it
    direct_path = os.path.join(data_folder, data_file) if data_file else ""
    if direct_path.endswith(".npz") and os.path.exists(direct_path):
        logger.info(f"Loading data from NPZ override: {direct_path}")
        data = self._load_from_cache(direct_path)
    # Otherwise, try cache then raw HDF
    elif os.path.exists(cache_path) and caching_enabled:
        logger.info(f"Loading cached data from: {cache_path}")
        data = self._load_from_cache(cache_path)
    else:
        # Load from raw HDF file
        hdf_path = os.path.join(data_folder, data_file)
        if not os.path.exists(hdf_path):
            raise FileNotFoundError(
                f"Neither cache file {cache_path} nor HDF file {hdf_path} exists",
            )

        logger.info(f"Loading raw data from: {hdf_path}")
        data = self._load_from_hdf(hdf_path)

        # Save to cache if caching enabled. This happens before the diagonal
        # correction below, which is re-applied on every load.
        if caching_enabled:
            logger.info(f"Saving processed data to cache: {cache_path}")
            self._save_to_cache(data, cache_path)

        # Generate text files
        self._save_text_files(data)

    # Initialize quality control if enabled
    quality_controller = self._initialize_quality_control()
    quality_results = []

    # Stage 1: Raw data validation
    if quality_controller:
        raw_validation_result = quality_controller.validate_data_stage(
            data,
            quality_controller.QualityControlStage.RAW_DATA,
        )
        quality_results.append(raw_validation_result)

        # Apply auto-repair if data was modified
        if raw_validation_result.data_modified:
            logger.info("Raw data was modified by quality control auto-repair")

    # Apply filtering with quality control validation
    if quality_controller:
        filtered_validation_result = quality_controller.validate_data_stage(
            data,
            quality_controller.QualityControlStage.FILTERED_DATA,
            previous_result=quality_results[-1] if quality_results else None,
        )
        quality_results.append(filtered_validation_result)

    # Apply preprocessing pipeline if enabled with quality control
    data = self._apply_preprocessing_pipeline(
        data,
        quality_controller,
        quality_results,
    )

    # Convert to target array format (JAX or numpy)
    data = self._convert_arrays_to_target_format(data)

    # Apply mandatory diagonal correction (post-load for consistent behavior)
    # Uses unified diagonal_correction module (v2.14.2+)
    logger.debug("Applying mandatory diagonal correction to correlation matrices")
    if HAS_DIAGONAL_CORRECTION:
        data["c2_exp"] = apply_diagonal_correction_batch(data["c2_exp"])
    else:
        # Fallback to local implementation if unified module not available
        data["c2_exp"] = self._correct_diagonal_batch(data["c2_exp"])

    # Final quality control validation
    if quality_controller:
        final_validation_result = quality_controller.validate_data_stage(
            data,
            quality_controller.QualityControlStage.FINAL_DATA,
            previous_result=quality_results[-1] if quality_results else None,
        )
        quality_results.append(final_validation_result)

        # Generate quality report only when explicitly requested (--plot-experimental-data)
        # Do NOT generate reports during normal VI/MCMC runs to avoid cluttering
        if self.generate_quality_reports and self.v2_config.get(
            "quality_control",
            {},
        ).get("generate_reports", True):
            quality_report = quality_controller.generate_quality_report(
                quality_results,
                self._get_quality_report_path(),
            )
            logger.info(
                f"Quality report generated with overall status: {quality_report['overall_summary']['status']}",
            )

    # Perform legacy validation if enabled (skipped when quality control ran)
    validation_settings = self._should_perform_validation()
    if any(validation_settings.values()) and not quality_controller:
        self._validate_loaded_data(data, validation_settings)

    logger.info(
        f"Data loaded successfully - shapes: q{data['wavevector_q_list'].shape}, "
        f"phi{data['phi_angles_list'].shape}, c2{data['c2_exp'].shape}",
    )

    return data
@log_performance(threshold=0.2)
def _load_from_cache(self, cache_path: str) -> dict[str, Any]:
    """Read the five data arrays from an NPZ cache file.

    If the archive carries ``cache_metadata_json`` it is decoded and passed
    to ``_validate_cache_q_vector``. Archives using the legacy
    ``cache_metadata`` key are refused. ``t1`` and ``t2`` must be 1D; old
    caches that stored them as 2D meshgrids are rejected and have to be
    regenerated.

    Args:
        cache_path: Path of the ``.npz`` file to read.

    Returns:
        Dictionary with ``wavevector_q_list``, ``phi_angles_list``, ``t1``,
        ``t2`` and ``c2_exp`` as independent numpy arrays.

    Raises:
        ValueError: On malformed or legacy metadata, on an array that
            cannot be read without pickle, or on the old 2D time format.
    """
    # Read order of the data keys (kept so the first failing key is the same).
    data_keys = ("c2_exp", "t1", "t2", "wavevector_q_list", "phi_angles_list")

    with np.load(cache_path, allow_pickle=False, mmap_mode="r") as archive:
        if "cache_metadata_json" in archive:
            metadata_text = str(np.asarray(archive["cache_metadata_json"]).item())
            try:
                metadata = json.loads(metadata_text)
            except json.JSONDecodeError as exc:
                raise ValueError(
                    f"Cache {cache_path} has malformed cache_metadata_json "
                    f"(not valid JSON): {exc}"
                ) from exc
            if not isinstance(metadata, dict):
                raise ValueError(
                    f"Cache {cache_path}: cache_metadata_json must encode a "
                    f"JSON object, got {type(metadata).__name__}"
                )
            self._validate_cache_q_vector(metadata)
            logger.debug(f"Cache metadata validation passed: {metadata}")
        elif "cache_metadata" in archive.files:
            # Legacy object-array metadata is refused for safety: reading it
            # would require object deserialization from a config-controlled
            # path. The current format stores metadata as JSON under
            # 'cache_metadata_json'.
            raise ValueError(
                f"Cache {cache_path} uses the legacy 'cache_metadata' "
                "object-array format (unsafe object deserialization). "
                "Delete the cache file and regenerate with current code."
            )

        # np.array() copies each entry while the archive is still open, so
        # nothing returned refers back to the closed file. With
        # allow_pickle=False an object-dtype entry raises ValueError, which
        # is re-raised with a clearer message.
        try:
            arrays = {key: np.array(archive[key]) for key in data_keys}
        except ValueError as exc:
            raise ValueError(
                f"Cache {cache_path} contains an object-dtype array under "
                "a data key. Delete the cache file and regenerate."
            ) from exc

    # Reject old 2D meshgrid cache format
    if arrays["t1"].ndim == 2 or arrays["t2"].ndim == 2:
        raise ValueError(
            f"Old 2D meshgrid cache format detected in {cache_path}. "
            "Please delete the cache file and regenerate with current code. "
            "New cache format uses 1D time arrays."
        )

    return {
        "wavevector_q_list": arrays["wavevector_q_list"],
        "phi_angles_list": arrays["phi_angles_list"],
        "t1": arrays["t1"],  # 1D array
        "t2": arrays["t2"],  # 1D array
        "c2_exp": arrays["c2_exp"],
    }
@log_performance(threshold=1.0)
def _load_from_hdf(self, hdf_path: str) -> dict[str, Any]:
    """Detect the HDF5 layout and dispatch to the matching loader.

    Args:
        hdf_path: Path of the HDF5 file to read.

    Returns:
        The data dictionary produced by the format-specific loader.

    Raises:
        XPCSDataFormatError: If the detected format is neither ``"aps_old"``
            nor ``"aps_u"``.
    """
    # T037: Add log_phase for data loading with memory tracking
    with log_phase("hdf5_data_loading", logger=logger, track_memory=True) as phase:
        logger.debug("Starting HDF5 format detection")
        format_type = self._detect_format(hdf_path)
        logger.info(f"Detected format: {format_type}")

        if format_type not in ("aps_old", "aps_u"):
            raise XPCSDataFormatError(f"Unsupported format: {format_type}")
        if format_type == "aps_old":
            data = self._load_aps_old_format(hdf_path)
        else:
            data = self._load_aps_u_format(hdf_path)

    # Peak memory is only mentioned when the phase object reports a truthy value.
    if phase.memory_peak_gb:
        summary = (
            f"HDF5 loading completed in {phase.duration:.2f}s, "
            f"peak memory: {phase.memory_peak_gb:.2f} GB"
        )
    else:
        summary = f"HDF5 loading completed in {phase.duration:.2f}s"
    logger.info(summary)
    return data
@log_performance(threshold=0.1)
def _detect_format(self, hdf_path: str) -> str:
    """Classify an HDF5 file by the groups and datasets it contains.

    APS-U is checked first: it needs ``xpcs/qmap/dynamic_v_list_dim0`` and
    ``xpcs/twotime/correlation_map``. APS old needs ``xpcs/dqlist``,
    ``xpcs/dphilist`` and ``exchange/C2T_all``.

    Returns:
        "aps_u" for APS-U format
        "aps_old" for APS old format
        "unknown" for unrecognized or empty files (a warning listing the
        top-level keys is logged)
    """
    with h5py.File(hdf_path, "r") as f:
        if "xpcs" in f:
            xpcs = f["xpcs"]

            # APS-U layout
            if (
                "qmap" in xpcs
                and "dynamic_v_list_dim0" in f["xpcs/qmap"]
                and "twotime" in xpcs
                and "correlation_map" in f["xpcs/twotime"]
            ):
                return "aps_u"

            # APS old layout
            if (
                "dqlist" in xpcs
                and "dphilist" in xpcs
                and "exchange" in f
                and "C2T_all" in f["exchange"]
            ):
                return "aps_old"

        # Log the top-level keys for debugging unrecognized formats
        top_keys = list(f.keys())
        logger.warning(
            f"Unrecognized HDF5 format: top-level keys={top_keys}. "
            "Expected APS-U (xpcs/twotime/correlation_map) or "
            "APS old (xpcs/dqlist + exchange/C2T_all)."
        )
        return "unknown"
@log_performance(threshold=0.8)
def _load_aps_old_format(self, hdf_path: str) -> dict[str, Any]:
    """Load data from APS old format HDF5 file.

    Optimization (v2.9.1): matrices are read selectively. The (q, phi)
    entries to keep are determined from the small ``dqlist`` / ``dphilist``
    metadata first, and only the matching matrices are read from
    ``exchange/C2T_all`` and expanded with ``_reconstruct_full_matrix``.
    When quality filtering is enabled, the metadata pre-filter produces
    candidates, those are loaded, and the quality filter then runs on the
    loaded subset.

    Args:
        hdf_path: Path of the APS old-format HDF5 file.

    Returns:
        Dictionary with ``wavevector_q_list``, ``phi_angles_list``, ``t1``,
        ``t2`` (1D time arrays) and ``c2_exp`` (one frame-sliced matrix per
        selected entry).

    Raises:
        ValueError: If ``exchange/C2T_all`` contains no matrices.
    """
    with h5py.File(hdf_path, "r") as f:
        # Load q and phi lists (small metadata - always needed)
        dqlist = f["xpcs/dqlist"][0, :]  # Shape (1, N) -> (N,)
        dphilist = f["xpcs/dphilist"][0, :]  # Shape (1, N) -> (N,)

        # Load correlation data from exchange/C2T_all
        # NOTE(review): indices into dqlist are used to index c2_keys below,
        # which assumes both have the same length and ordering; confirm.
        c2t_group = f["exchange/C2T_all"]
        c2_keys = list(c2t_group.keys())
        if not c2_keys:
            raise ValueError(
                f"APS old-format HDF5 file contains no correlation matrices "
                f"in 'exchange/C2T_all': {hdf_path}"
            )

        # Check if quality-based filtering is enabled (both switches must be on)
        filtering_config = self.config.get("data_filtering", {})
        quality_filtering_enabled = filtering_config.get(
            "enabled", False
        ) and filtering_config.get("quality_filtering", {}).get("enabled", False)

        # Select optimal q-vector first (doesn't require matrices)
        logger.debug("Selecting optimal q-vector for caching")
        selected_q_idx = self._select_optimal_wavevector(dqlist)
        selected_q = dqlist[selected_q_idx]

        # Calculate q-vector tolerance as fraction of selected q-vector
        q_tolerance_fraction = self.config.get("q_tolerance_fraction", 0.1)
        q_tolerance = selected_q * q_tolerance_fraction
        q_matching_indices = np.where(np.abs(dqlist - selected_q) <= q_tolerance)[0]

        # If we still get too few phi angles, expand the search
        if len(q_matching_indices) < 5:
            # Sort by distance from selected q and take closest N entries
            q_distances = np.abs(dqlist - selected_q)
            closest_indices = np.argsort(q_distances)
            # Take up to 10 closest q-vectors to ensure good phi angle coverage
            n_desired = min(10, len(closest_indices))
            q_matching_indices_list = [int(i) for i in closest_indices[:n_desired]]
            q_matching_indices = np.array(q_matching_indices_list, dtype=int)
            logger.debug(
                f"Expanded selection to {len(q_matching_indices)} closest q-vectors for better phi coverage",
            )

        logger.debug(
            f"Selected {len(q_matching_indices)} (q,phi) pairs with q-range: "
            f"{dqlist[q_matching_indices].min():.6f} - {dqlist[q_matching_indices].max():.6f} AA^-1",
        )

        if quality_filtering_enabled:
            # Two-pass optimization: metadata filter first, then load + quality filter
            # Pass 1: phi/q filtering without loading matrices (metadata only)
            logger.debug(
                "Quality filtering enabled - running metadata-only pre-filter"
            )
            metadata_indices = self._get_selected_indices(
                dqlist,
                dphilist,
                None,  # No matrices needed for phi-only filtering
            )

            # Narrow to candidates via q + phi intersection
            if metadata_indices is not None:
                candidate_indices = np.intersect1d(
                    q_matching_indices, metadata_indices
                )
            else:
                candidate_indices = q_matching_indices

            # Report the share of matrices that will NOT be read. c2_keys is
            # non-empty here, so the division is safe.
            io_reduction_pct = (1.0 - len(candidate_indices) / len(c2_keys)) * 100
            logger.debug(
                f"Pre-filter: {len(c2_keys)} total -> {len(candidate_indices)} candidates "
                f"({io_reduction_pct:.1f}% I/O reduction)"
            )

            # Pass 2: load only candidate matrices from HDF5
            candidate_matrices = []
            for idx in candidate_indices:
                key = c2_keys[int(idx)]
                c2_half = c2t_group[key][()]
                c2_full = self._reconstruct_full_matrix(c2_half)
                candidate_matrices.append(c2_full)

            # Apply quality filtering on the loaded subset; the returned
            # indices are positions within the candidate list.
            quality_indices = self._get_selected_indices(
                dqlist[candidate_indices],
                dphilist[candidate_indices],
                candidate_matrices,
            )

            # Map quality filter results back to original indices
            if quality_indices is not None:
                final_indices = candidate_indices[quality_indices]
                selected_c2_matrices = [
                    candidate_matrices[i] for i in quality_indices
                ]
                logger.debug(
                    f"After quality filtering: {len(candidate_indices)} -> {len(final_indices)} matrices",
                )
            else:
                final_indices = candidate_indices
                selected_c2_matrices = candidate_matrices
        else:
            # OPTIMIZATION: No quality filtering - selective HDF5 reads
            # Only load the matrices we actually need
            logger.debug("Applying phi-only filtering (no quality filtering)")
            selected_indices = self._get_selected_indices(
                dqlist,
                dphilist,
                None,  # Don't pass matrices - not needed for phi-only filtering
            )

            # Apply additional phi filtering if enabled
            if selected_indices is not None:
                final_indices = np.intersect1d(q_matching_indices, selected_indices)
                logger.debug(
                    f"After phi filtering: {len(q_matching_indices)} -> {len(final_indices)} matrices",
                )
            else:
                final_indices = q_matching_indices
                logger.debug(
                    f"No phi filtering - using all {len(final_indices)} (q,phi) pairs",
                )

            # Selective load: only read the matrices we need
            logger.info(
                f"Selective HDF5 read: loading {len(final_indices)} of {len(c2_keys)} matrices "
                f"({len(final_indices) / len(c2_keys) * 100:.1f}% I/O)"
            )
            selected_c2_matrices = []
            for idx in final_indices:
                key = c2_keys[int(idx)]
                c2_half = c2t_group[key][()]
                c2_full = self._reconstruct_full_matrix(c2_half)
                selected_c2_matrices.append(c2_full)

        # Extract metadata for final indices
        filtered_dqlist = dqlist[final_indices]
        filtered_dphilist = dphilist[final_indices]
        c2_matrices_array = np.array(selected_c2_matrices)

        # Apply frame slicing to selected q-vector data
        logger.debug(
            f"Applying frame slicing to selected q-vector data: shape {c2_matrices_array.shape}",
        )
        c2_exp = self._apply_frame_slicing_to_selected_q(c2_matrices_array)

        # Calculate 1D time array, sized from the last axis of the sliced data
        time_1d = self._calculate_time_arrays(c2_exp.shape[-1])

        return {
            "wavevector_q_list": filtered_dqlist,  # Selected q-vectors (may be multiple for APS old)
            "phi_angles_list": filtered_dphilist,  # Corresponding phi angles
            "t1": time_1d,  # 1D time array
            "t2": time_1d.copy(),  # Independent copy (prevent aliasing mutation)
            "c2_exp": c2_exp,  # One frame-sliced matrix per selected (q, phi) entry
        }
@log_performance(threshold=0.8)
def _load_aps_u_format(self, hdf_path: str) -> dict[str, Any]:
    """Read an APS-U (new format) HDF5 file and return the selected dataset.

    Every entry of ``xpcs/twotime/processed_bins`` is mapped onto a
    (q, phi) position of the ``n_q x n_phi`` grid. The matching half
    matrices are expanded to full matrices, the configured data filtering
    is evaluated, the q-vector closest to the configured target is chosen
    and the surviving matrices are frame-sliced.

    Args:
        hdf_path: Path of the HDF5 file, opened read-only.

    Returns:
        Dictionary with the keys ``wavevector_q_list``, ``phi_angles_list``,
        ``t1``, ``t2`` and ``c2_exp``.

    Raises:
        XPCSDataFormatError: If no entry of ``processed_bins`` maps onto a
            valid (q, phi) grid position.
    """
    with h5py.File(hdf_path, "r") as f:
        # processed_bins names the grid bins that own a stored matrix.
        processed_bins = f["xpcs/twotime/processed_bins"][()]
        q_values = f["xpcs/qmap/dynamic_v_list_dim0"][()]
        phi_values = f["xpcs/qmap/dynamic_v_list_dim1"][()]
        n_q, n_phi = len(q_values), len(phi_values)

        logger.debug(f"APS-U format: {n_q} q-values, {n_phi} phi-values")
        logger.debug(f"Q range: {q_values.min():.6f} to {q_values.max():.6f} A^-1")
        logger.debug(f"Phi values: {phi_values}")
        logger.debug(
            f"Processed bins: {len(processed_bins)} correlation matrices available",
        )

        # Grid mapping: zero-based bin -> (bin // n_phi, bin % n_phi).
        qphi_pairs = []
        valid_bin_indices = []
        for position, processed_bin in enumerate(processed_bins):
            zero_based_bin = processed_bin - 1
            q_idx = zero_based_bin // n_phi
            phi_idx = zero_based_bin % n_phi
            if not (0 <= q_idx < n_q and 0 <= phi_idx < n_phi):
                logger.warning(
                    f"Invalid bin mapping: processed_bin={processed_bin}, q_idx={q_idx}, phi_idx={phi_idx}",
                )
                continue
            qphi_pairs.append((q_values[q_idx], phi_values[phi_idx]))
            # Position in processed_bins doubles as the matrix position.
            valid_bin_indices.append(position)

        if not qphi_pairs:
            raise XPCSDataFormatError(
                "No valid (q,phi) pairs found from processed_bins mapping",
            )

        # Column 0 holds q, column 1 holds phi for every valid pair.
        qphi_array = np.array(qphi_pairs)
        filtered_dqlist = qphi_array[:, 0]
        filtered_dphilist = qphi_array[:, 1]
        logger.debug(
            f"Extracted {len(valid_bin_indices)} valid (q,phi) pairs from processed_bins",
        )

        corr_group = f["xpcs/twotime/correlation_map"]
        # Lexicographic order; adequate for zero-padded keys such as c2_00001.
        c2_keys = sorted(corr_group.keys())
        logger.debug(
            f"Loading {len(valid_bin_indices)} correlation matrices corresponding to valid (q,phi) pairs",
        )

        c2_matrices_for_filtering = []
        for bin_idx in valid_bin_indices:
            if bin_idx >= len(c2_keys):
                logger.warning(
                    f"Matrix index {bin_idx} exceeds available matrices ({len(c2_keys)})",
                )
                continue
            stored_half = corr_group[c2_keys[bin_idx]][()]
            c2_matrices_for_filtering.append(
                self._reconstruct_full_matrix(stored_half),
            )

    # Keep matrices and (q, phi) metadata the same length.
    n_matrices = len(c2_matrices_for_filtering)
    n_pairs = len(filtered_dqlist)
    if n_matrices != n_pairs:
        min_count = min(n_matrices, n_pairs)
        n_discarded = abs(n_matrices - n_pairs)
        logger.warning(
            f"APS-U matrix/pair count mismatch: {n_matrices} matrices vs "
            f"{n_pairs} (q,phi) pairs - truncating to {min_count} entries, "
            f"discarding {n_discarded} unmatched {'matrices' if n_matrices > n_pairs else '(q,phi) pairs'}. "
            "Check HDF5 file integrity."
        )
        c2_matrices_for_filtering = c2_matrices_for_filtering[:min_count]
        filtered_dqlist = filtered_dqlist[:min_count]
        filtered_dphilist = filtered_dphilist[:min_count]

    logger.debug("Applying comprehensive data filtering")
    selected_indices = self._get_selected_indices(
        filtered_dqlist,
        filtered_dphilist,
        c2_matrices_for_filtering,
    )

    # Pick the q-vector nearest to the configured target ...
    selected_q_idx = self._select_optimal_wavevector(filtered_dqlist)
    selected_q = filtered_dqlist[selected_q_idx]
    logger.debug(
        f"Selected optimal q-vector: {selected_q:.6f} AA^-1 (index {selected_q_idx})",
    )

    # ... and gather every pair that shares it.
    q_matching_indices = np.flatnonzero(np.abs(filtered_dqlist - selected_q) < 1e-10)
    logger.debug(
        f"Found {len(q_matching_indices)} (q,phi) pairs matching selected q-vector",
    )

    if selected_indices is None:
        final_indices = q_matching_indices
        logger.debug(
            f"No phi filtering applied - using all {len(final_indices)} pairs for selected q-vector",
        )
    else:
        # Only pairs that pass both the q selection and the filtering survive.
        final_indices = np.intersect1d(q_matching_indices, selected_indices)
        logger.debug(
            f"After intersecting with phi filtering: {len(final_indices)} pairs remain",
        )

    if final_indices.size == 0:
        logger.warning(
            "No valid indices found, using first available entry as fallback",
        )
        final_indices = np.array([0], dtype=int)

    final_dqlist = filtered_dqlist[final_indices]
    final_dphilist = filtered_dphilist[final_indices]
    c2_matrices = [c2_matrices_for_filtering[i] for i in final_indices]
    logger.debug(f"Final selection: {len(c2_matrices)} correlation matrices")

    c2_exp = self._apply_frame_slicing_to_selected_q(np.array(c2_matrices))
    # 1D time axis only; callers build meshgrids themselves when needed.
    time_1d = self._calculate_time_arrays(c2_exp.shape[-1])

    return {
        "wavevector_q_list": final_dqlist,
        "phi_angles_list": final_dphilist,
        "t1": time_1d,  # 1D time array starting from 0: [0, dt, 2*dt, ...]
        "t2": time_1d.copy(),  # Independent copy (prevent aliasing mutation)
        "c2_exp": c2_exp,
    }
def _reconstruct_full_matrix(self, c2_half: NDArray) -> NDArray:
    """Reconstruct full correlation matrix from half matrix (APS storage format).

    Based on pyXPCSViewer's approach:
        c2 = c2_half + c2_half.T
        c2[diag] /= 2

    Note: Diagonal correction is now applied post-load for consistent behavior.

    Args:
        c2_half: Square 2D array holding one triangle (plus the diagonal)
            of the correlation matrix.

    Returns:
        New full matrix ``c2_half + c2_half.T`` with the diagonal halved.
        Integer input is promoted to float64 so the halving is exact.

    Raises:
        RuntimeError: If NumPy is not available.
        ValueError: If ``c2_half`` is not a square 2D array.
    """
    if not HAS_NUMPY:
        raise RuntimeError("NumPy is required for matrix reconstruction")

    half = np.asarray(c2_half)
    if half.ndim != 2 or half.shape[0] != half.shape[1]:
        # Without this check a (1, N) input would silently broadcast to (N, N).
        raise ValueError(
            f"Half matrix must be a square 2D array, got shape {half.shape}",
        )
    if not np.issubdtype(half.dtype, np.inexact):
        # In-place true division cannot write float results into an
        # integer array, so promote before adding the transpose.
        half = half.astype(np.float64)

    c2_full = half + half.T
    # Correct diagonal (was doubled in addition)
    diag_indices = np.diag_indices(half.shape[0])
    c2_full[diag_indices] /= 2
    return c2_full  # type: ignore[no-any-return]
def _correct_diagonal(self, c2_mat: NDArray) -> NDArray:
    """Apply diagonal correction to correlation matrix.

    .. deprecated:: 2.16.0
        Use :func:`homodyne.core.diagonal_correction.apply_diagonal_correction`
        instead. This method is kept for backward compatibility only.

    Based on pyXPCSViewer's correct_diagonal_c2 function.
    Handles both JAX and NumPy arrays.

    Each diagonal element is replaced by the mean of its neighbours on the
    first off-diagonal (the two corner elements have a single neighbour).

    Args:
        c2_mat: Square 2D correlation matrix.

    Returns:
        The corrected matrix. A writeable NumPy input is modified in place
        and returned; a read-only NumPy input is copied first; a JAX input
        yields a new array.

    Raises:
        RuntimeError: If NumPy is not available.
    """
    if not HAS_NUMPY:
        raise RuntimeError("NumPy is required for diagonal correction")

    size = c2_mat.shape[0]
    side_band = c2_mat[(np.arange(size - 1), np.arange(1, size))]

    # NumPy >= 2.0 ndarrays also expose ``.device``, so that attribute alone
    # cannot identify a JAX array; exclude real ndarrays explicitly.
    is_jax_array = (
        HAS_JAX
        and not isinstance(c2_mat, np.ndarray)
        and hasattr(c2_mat, "device")
    )

    if is_jax_array:
        diag_val = jnp.zeros(size, dtype=c2_mat.dtype)
        diag_val = diag_val.at[:-1].add(side_band)
        diag_val = diag_val.at[1:].add(side_band)
        norm = jnp.ones(size, dtype=c2_mat.dtype)
        norm = norm.at[1:-1].set(2)
        # Update diagonal using JAX immutable operations
        diag_indices = np.diag_indices(size)
        c2_corrected = c2_mat.at[diag_indices].set(diag_val / norm)  # type: ignore
        return c2_corrected  # type: ignore

    # NumPy array
    diag_val = np.zeros(size)
    diag_val[:-1] += side_band
    diag_val[1:] += side_band
    norm = np.ones(size)
    norm[1:-1] = 2
    # Only copy if array is read-only (e.g., from mmap)
    c2_corrected = c2_mat if c2_mat.flags.writeable else c2_mat.copy()
    # Use fill_diagonal for efficient in-place update
    np.fill_diagonal(c2_corrected, diag_val / norm)
    return c2_corrected
# Performance Optimization (Spec 006 - FR-006, FR-006a): Batch diagonal correction
def _correct_diagonal_batch(self, c2_matrices: NDArray) -> NDArray:
    """Apply diagonal correction to all matrices in batch.

    .. deprecated:: 2.16.0
        Use :func:`homodyne.core.diagonal_correction.apply_diagonal_correction_batch`
        instead. This method is kept for backward compatibility only.

    Performance Optimization (Spec 006 - FR-006):
    The output array is allocated once and the diagonals of all matrices
    are rewritten with a single vectorized assignment.

    Args:
        c2_matrices: Correlation matrices, shape (n_phi, n_t1, n_t2)

    Returns:
        Corrected matrices with same shape. The input is never modified.

    Raises:
        RuntimeError: If NumPy is not available.
    """
    if not HAS_NUMPY:
        raise RuntimeError("NumPy is required for diagonal correction")

    # NumPy >= 2.0 ndarrays also expose ``.device``; only non-ndarray inputs
    # carrying that attribute are routed to the JAX implementation.
    if (
        HAS_JAX
        and not isinstance(c2_matrices, np.ndarray)
        and hasattr(c2_matrices, "device")
    ):
        # JAX path: use vmap for vectorized correction (FR-006a)
        return self._correct_diagonal_batch_jax(c2_matrices)  # type: ignore

    n_phi = c2_matrices.shape[0]
    size = c2_matrices.shape[1]

    # Interior diagonal elements have two off-diagonal neighbours, the two
    # corner elements only one.
    norm = np.ones(size)
    norm[1:-1] = 2

    # First off-diagonal of every matrix, shape (n_phi, size - 1).
    idx_upper = np.arange(size - 1)
    idx_lower = np.arange(1, size)
    side_band = c2_matrices[:, idx_upper, idx_lower]

    diag_val = np.zeros((n_phi, size))
    diag_val[:, :-1] += side_band
    diag_val[:, 1:] += side_band

    # FR-006: pre-allocate the output, copy once, then overwrite diagonals.
    c2_corrected = np.empty_like(c2_matrices)
    c2_corrected[...] = c2_matrices
    diag_rows, diag_cols = np.diag_indices(size)
    c2_corrected[:, diag_rows, diag_cols] = diag_val / norm
    return c2_corrected
def _correct_diagonal_batch_jax(self, c2_matrices: Any) -> Any:
    """Diagonal correction for a stack of matrices, vectorized with jax.vmap.

    Performance Optimization (Spec 006 - FR-006a):
    A single-matrix correction is mapped over the leading axis with
    ``jax.vmap`` instead of looping in Python.

    Args:
        c2_matrices: JAX array of shape (n_phi, n_t1, n_t2)

    Returns:
        Corrected matrices with same shape

    Raises:
        RuntimeError: If JAX is not available.
    """
    if not HAS_JAX:
        raise RuntimeError("JAX is required for JAX diagonal correction")
    import jax

    n_frames = c2_matrices.shape[1]

    # Quantities shared by every matrix are built once, outside the map.
    weights = jnp.ones(n_frames).at[1:-1].set(2)
    band_rows = jnp.arange(n_frames - 1)
    band_cols = jnp.arange(1, n_frames)

    def _fix_one(matrix: Any) -> Any:
        """Replace the diagonal of one matrix by its off-diagonal average."""
        neighbours = matrix[band_rows, band_cols]
        summed = (
            jnp.zeros(n_frames, dtype=matrix.dtype)
            .at[:-1]
            .add(neighbours)
            .at[1:]
            .add(neighbours)
        )
        return matrix.at[jnp.diag_indices(n_frames)].set(summed / weights)

    return jax.vmap(_fix_one)(c2_matrices)
def _get_selected_indices(
    self,
    dqlist: NDArray,
    dphilist: NDArray,
    correlation_matrices: list[NDArray] | None = None,
) -> NDArray | None:
    """Get indices for comprehensive data filtering based on configuration.

    Delegates to ``XPCSDataFilter.apply_filtering`` when the
    ``data_filtering`` section of the configuration is enabled, logs the
    reported statistics, warnings and errors, and finally passes the
    selection through ``_integrate_with_phi_filtering``.

    Args:
        dqlist: Array of q-values (wavevector magnitudes)
        dphilist: Array of phi angles in degrees
        correlation_matrices: Optional list of correlation matrices for quality filtering

    Returns:
        Array of selected indices, or None if no filtering is applied
        (filtering disabled, utilities not importable, no selection
        returned, or an error with ``fallback_on_empty`` enabled).

    Raises:
        DataFilteringError: If the filter reports errors without having
            used its fallback.
        XPCSDataFormatError: If filtering raises ValueError, TypeError or
            KeyError and ``fallback_on_empty`` is disabled.
    """
    # Read the configuration before the try block so that the error handler
    # below can always consult it, whichever statement raised.
    filtering_config = self.config.get("data_filtering", {})
    if not filtering_config.get("enabled", False):
        # Checked before the import so a disabled filter never triggers an
        # import (or an "unavailable" warning).
        logger.debug("Data filtering disabled in configuration")
        return None

    try:
        # Import filtering utilities
        from homodyne.data.filtering_utils import DataFilteringError, XPCSDataFilter

        logger.info(
            f"Applying comprehensive data filtering to {len(dqlist)} data points",
        )

        # Initialize data filter and apply comprehensive filtering
        data_filter = XPCSDataFilter(self.config)
        filtering_result = data_filter.apply_filtering(
            dqlist,
            dphilist,
            correlation_matrices,
        )

        # Log filtering statistics
        if filtering_result.filter_statistics:
            logger.info("Filtering statistics:")
            for filter_name, stats in filtering_result.filter_statistics.items():
                if isinstance(stats, dict) and "selected_count" in stats:
                    logger.info(
                        f"  {filter_name}: {stats['selected_count']} selected "
                        f"({stats.get('selection_fraction', 0.0):.2%})",
                    )

        # Handle warnings and errors
        if filtering_result.warnings:
            for warning in filtering_result.warnings:
                logger.warning(f"Data filtering warning: {warning}")

        if filtering_result.errors:
            for error in filtering_result.errors:
                logger.error(f"Data filtering error: {error}")
            if not filtering_result.fallback_used:
                raise DataFilteringError(
                    f"Data filtering failed: {filtering_result.errors}",
                )

        if filtering_result.selected_indices is None:
            logger.warning(
                "No data filtering criteria matched - returning all angles. "
                "Check filter configuration if this is unexpected."
            )
            return None

        # Log final result
        selected_count = len(filtering_result.selected_indices)
        total_count = len(dqlist)
        selection_fraction = (
            selected_count / total_count if total_count > 0 else 0.0
        )
        logger.info(
            f"Data filtering completed: {selected_count}/{total_count} "
            f"data points selected ({selection_fraction:.2%})",
        )
        if filtering_result.fallback_used:
            logger.warning("Filtering used fallback - all data points included")

        # Additional integration with phi filtering for compatibility
        return self._integrate_with_phi_filtering(
            filtering_result.selected_indices,
            dphilist,
            filtering_result,
        )
    except ImportError as e:
        logger.warning(
            f"Filtering utilities not available: {e}. Skipping data filtering.",
        )
        return None
    except (ValueError, TypeError, KeyError) as e:
        logger.error(f"Data filtering failed: {e}")
        # Check if we should fallback or raise
        if filtering_config.get("fallback_on_empty", True):
            logger.warning("Falling back to no filtering due to error")
            return None
        raise XPCSDataFormatError(f"Data filtering failed: {e}") from e
def _integrate_with_phi_filtering(
    self,
    selected_indices: NDArray,
    dphilist: NDArray,
    filtering_result: Any,
) -> NDArray:
    """Optionally narrow a selection with the legacy phi-angle filter.

    The legacy filter is skipped when the main filtering already applied a
    ``phi_range`` filter or when ``optimization_config.angle_filtering`` is
    not enabled. On ImportError, TypeError, IndexError or KeyError the
    incoming selection is returned unchanged.

    Args:
        selected_indices: Indices chosen by the main filtering step.
        dphilist: Phi angles of all data points.
        filtering_result: Result object of the main filtering; only its
            ``filters_applied`` attribute is read here.

    Returns:
        Either ``selected_indices`` unchanged or the subset kept by the
        legacy phi filter.
    """
    try:
        from homodyne.data.phi_filtering import PhiAngleFilter

        # Nothing to do if the new system already filtered on phi.
        if "phi_range" in filtering_result.filters_applied:
            logger.debug("Phi filtering already applied in main filtering system")
            return selected_indices

        legacy_settings = self.config.get("optimization_config", {}).get(
            "angle_filtering",
            {},
        )
        if not legacy_settings.get("enabled", False):
            logger.debug("Legacy phi filtering not enabled")
            return selected_indices

        # Run the legacy filter on the angles that survived so far.
        surviving_angles = dphilist[selected_indices]
        legacy_filter = PhiAngleFilter(self.config)
        kept_positions, _ = legacy_filter.filter_angles_for_optimization(
            surviving_angles,
        )

        # Positions are relative to the surviving subset; map them back.
        final_selected_indices = selected_indices[kept_positions]
        logger.info(
            f"Legacy phi filtering applied: {len(final_selected_indices)} "
            f"out of {len(selected_indices)} filtered indices selected",
        )
        return final_selected_indices
    except ImportError:
        logger.debug(
            "Phi filtering system not available - using original selection",
        )
    except (TypeError, IndexError, KeyError) as e:
        logger.warning(
            f"Phi filtering integration failed: {e} - using original selection",
        )
    return selected_indices
def _select_optimal_wavevector(self, dqlist: NDArray) -> int:
    """Select q-vector index closest to config value (no tolerance).

    NaN entries in ``dqlist`` are ignored, so a NaN can never be reported
    as the closest match.

    Args:
        dqlist: Array of available q-vector values

    Returns:
        Index of selected q-vector in dqlist

    Raises:
        RuntimeError: If NumPy is not available.
        ValueError: If ``dqlist`` is empty or contains only NaN.
    """
    if not HAS_NUMPY:
        raise RuntimeError("NumPy is required for wavevector selection")

    # Get target q-vector from configuration
    scattering_config = self.analyzer_config.get("scattering", {})
    config_q = scattering_config.get("wavevector_q", 0.0054)
    logger.debug(f"Target q-vector: {config_q:.6f} A^-1")

    q_array = np.asarray(dqlist)
    deviations = np.abs(q_array - config_q)
    if deviations.size == 0 or bool(np.all(np.isnan(deviations))):
        raise ValueError(
            "Cannot select a q-vector: no non-NaN q-values available",
        )

    # nanargmin skips NaN; plain argmin would return the first NaN position.
    closest_idx = int(np.nanargmin(deviations))
    selected_q = q_array[closest_idx]
    deviation = abs(selected_q - config_q)
    logger.info(
        f"Selected closest q-vector: {selected_q:.6f} AA^-1 (target: {config_q:.6f} AA^-1, index: {closest_idx}, deviation: {deviation:.6f} AA^-1)",
    )
    return closest_idx
def _apply_frame_slicing_to_selected_q(self, c2_matrices: NDArray) -> NDArray:
    """Apply frame slicing to already q-filtered correlation matrices.

    ``start_frame`` is read as a 1-based index (values below 1 are clamped
    to 1) and ``end_frame`` as the exclusive 0-based stop; a negative
    ``end_frame`` selects everything up to the last frame.

    Args:
        c2_matrices: Correlation matrices for selected q-vector, shape (n_phi, full_frames, full_frames)

    Returns:
        Frame-sliced correlation matrices, shape (n_phi, sliced_frames, sliced_frames).
        The input itself is returned when no slicing is needed.
    """
    max_frames = c2_matrices.shape[-1]

    raw_start_frame = self.analyzer_config.get("start_frame", 1)
    if raw_start_frame < 1:
        logger.warning(f"start_frame={raw_start_frame} < 1, clamping to 1")
        raw_start_frame = 1
    # Convert to 0-based indexing; cannot be negative after the clamp above.
    start_frame = raw_start_frame - 1

    end_frame = self.analyzer_config.get("end_frame", -1)
    if end_frame < 0:
        end_frame = max_frames
    elif end_frame > max_frames:
        logger.warning(
            f"end_frame adjusted to {max_frames} (was {end_frame})"
        )
        end_frame = max_frames

    if start_frame >= end_frame:
        # The slice below is still applied, but make the empty result visible.
        logger.warning(
            f"Frame range [{start_frame}:{end_frame}] is empty - "
            "check start_frame/end_frame in the configuration",
        )

    # Apply frame slicing if needed
    if start_frame > 0 or end_frame < max_frames:
        c2_exp = c2_matrices[:, start_frame:end_frame, start_frame:end_frame]
        sliced_frames = end_frame - start_frame
        logger.debug(
            f"Applied frame slicing: [{start_frame}:{end_frame}] -> shape {c2_exp.shape}",
        )
        logger.debug(
            f"Frame reduction: {max_frames}x{max_frames} -> {sliced_frames}x{sliced_frames}",
        )
    else:
        c2_exp = c2_matrices
        logger.debug("No frame slicing needed - using full range")
    return c2_exp
def _calculate_time_arrays(self, matrix_size: int) -> NDArray:
"""Calculate 1D time array for correlation analysis.
Returns 1D array that can be used directly by CMC (element-wise analysis)
or converted to 2D meshgrids by NLSQ wrapper as needed.
Time starts from 0 (frame 0 corresponds to t=0). The t=0 exclusion
for D(t) singularity prevention is handled during analysis, not caching.
Args:
matrix_size: Number of time points (frames after slicing)
Returns:
1D time array: [0, dt, 2*dt, ..., (N-1)*dt]
"""
dt = self.analyzer_config.get("dt", 1.0)
# Create 1D time array starting from 0
# Last point at index (N-1), not N
time_max = dt * (matrix_size - 1)
time_1d = np.linspace(0, time_max, matrix_size)
return time_1d
@log_performance(threshold=0.3)
def _save_to_cache(self, data: dict[str, Any], cache_path: str) -> None:
    """Write the processed data plus q-vector metadata to an NPZ cache file.

    Args:
        data: Loaded data dictionary; must contain ``wavevector_q_list``
            and ``phi_angles_list`` (and ``c2_exp`` when ``end_frame`` is
            left at the -1 sentinel).
        cache_path: Destination file; its parent directory is created when
            missing.

    Raises:
        RuntimeError: If NumPy is not available.
    """
    if not HAS_NUMPY:
        raise RuntimeError("NumPy is required for cache saving")

    parent_dir = os.path.dirname(cache_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Anything exposing ``.device`` is materialised through NumPy before saving.
    cache_data: dict[str, Any] = {
        key: np.array(value) if HAS_JAX and hasattr(value, "device") else value
        for key, value in data.items()
    }

    config_q = self.analyzer_config.get("scattering", {}).get("wavevector_q", 0.0054)

    # Use nan-safe variants: q_values from HDF5 may contain NaN for bad pixels.
    q_values = cache_data["wavevector_q_list"]
    n_q_values = len(q_values)
    actual_q = float(np.nanmean(q_values)) if n_q_values > 0 else config_q
    q_variance = float(np.nanstd(q_values)) if n_q_values > 1 else 0.0

    first_frame = self.analyzer_config.get("start_frame", 1)
    last_frame = self.analyzer_config.get("end_frame", -1)
    if last_frame == -1:
        # Resolve the -1 sentinel to the actual last frame.
        last_frame = cache_data["c2_exp"].shape[-1] + first_frame - 1

    cache_metadata = {
        "config_wavevector_q": float(config_q),
        "actual_wavevector_q": actual_q,
        "q_variance": q_variance,
        "q_count": n_q_values,
        "start_frame": first_frame,
        "end_frame": last_frame,
        "phi_count": len(cache_data["phi_angles_list"]),
        "cache_version": "2.0",
        "selective_q_caching": True,
    }
    # Stored as a JSON string scalar so the file can be read back with
    # allow_pickle=False.
    cache_data["cache_metadata_json"] = np.asarray(json.dumps(cache_metadata))

    writer = (
        np.savez_compressed
        if self.exp_config.get("cache_compression", True)
        else np.savez
    )
    writer(cache_path, **cache_data)

    file_size_mb = os.path.getsize(cache_path) / (1024 * 1024)
    logger.info(f"Cache saved: {cache_path}")
    logger.info(
        f"Cache size: {file_size_mb:.2f} MB, Q-vectors: {cache_metadata['q_count']}, Phi angles: {cache_metadata['phi_count']}",
    )
    logger.debug(f"Q-vector: {actual_q:.6f} +/- {q_variance:.6f} A^-1")
def _validate_cache_q_vector(self, cache_metadata: dict[str, Any]) -> None:
    """Compare the cached q-vector metadata with the current configuration.

    The check is advisory: mismatches and legacy caches are only logged,
    nothing is raised or returned.

    Args:
        cache_metadata: Metadata dictionary stored alongside the cache.
    """
    current_config_q = self.analyzer_config.get("scattering", {}).get(
        "wavevector_q",
        0.0054,
    )
    # A cache without the key is treated as matching the current value.
    cached_config_q = cache_metadata.get("config_wavevector_q", current_config_q)

    if abs(current_config_q - cached_config_q) > 1e-8:
        logger.warning(
            f"Cache q-vector mismatch: current={current_config_q:.6f}, cached={cached_config_q:.6f} AA^-1",
        )

    # Selective q-caching is flagged in the metadata of v2.0 caches.
    if cache_metadata.get("selective_q_caching", False):
        actual_q = cache_metadata.get("actual_wavevector_q", cached_config_q)
        q_variance = cache_metadata.get("q_variance", 0.0)
        logger.debug(
            f"Validated selective cache: q={actual_q:.6f} +/- {q_variance:.6f} AA^-1",
        )
        return

    logger.warning(
        "Loading legacy cache without selective q-vector optimization",
    )
def _generate_cache_path(self) -> Path:
    """Generate cache file path based on current configuration.

    The filename comes from ``cache_filename_template`` (a
    ``string.Template``) with ``start_frame``, ``end_frame`` and
    ``wavevector_q`` substituted; unknown placeholders are left as-is by
    ``safe_substitute``.

    Returns:
        ``cache_file_path`` (falling back to ``data_folder_path``) joined
        with the generated filename.

    Raises:
        ValueError: If the generated filename contains a path separator
            or ``..``.
    """
    # Get data folder and cache configuration
    data_folder = self.exp_config.get("data_folder_path", "./data/")
    cache_folder = self.exp_config.get("cache_file_path", data_folder)

    # Get frame parameters
    start_frame = self.analyzer_config.get("start_frame", 1)
    # NOTE(review): the fallback here is 8000 while frame slicing and cache
    # metadata fall back to -1 - confirm which default is intended.
    end_frame = self.analyzer_config.get("end_frame", 8000)

    # Get wavevector_q for cache filename
    scattering_config = self.analyzer_config.get("scattering", {})
    wavevector_q = scattering_config.get("wavevector_q", 0.0054)

    # Construct cache filename (using string.Template for safety)
    cache_template = _migrate_cache_template(
        self.exp_config.get(
            "cache_filename_template",
            "cached_c2_frames_${start_frame}_${end_frame}.npz",
        )
    )
    cache_filename = string.Template(cache_template).safe_substitute(
        start_frame=start_frame,
        end_frame=end_frame,
        wavevector_q=f"{wavevector_q:.4f}",
    )

    # Reject anything that could leave the cache folder. os.altsep covers
    # platforms with a second separator (e.g. "/" on Windows); it is None
    # elsewhere and then skipped.
    separators = [sep for sep in (os.sep, os.altsep) if sep]
    if ".." in cache_filename or any(sep in cache_filename for sep in separators):
        raise ValueError(f"Unsafe cache filename from template: {cache_filename!r}")

    return Path(str(cache_folder)) / cache_filename  # type: ignore[no-any-return]
@log_performance(threshold=0.1)
def _save_text_files(self, data: dict[str, Any]) -> None:
    """Save phi_angles and wavevector_q lists to text files.

    Writes ``phi_angles_list.txt`` into ``phi_angles_path`` and
    ``wavevector_q_list.txt`` into ``data_folder_path``. The export is
    best-effort: any OSError, including a failure to create a target
    directory, is logged as a warning and not re-raised.

    Args:
        data: Loaded data dictionary containing ``phi_angles_list`` and
            ``wavevector_q_list``.
    """
    # Get output directories
    phi_folder = self.exp_config.get("phi_angles_path", "./")
    data_folder = self.exp_config.get("data_folder_path", "./")

    # Convert JAX arrays to numpy for text file saving
    phi_angles = (
        np.array(data["phi_angles_list"]) if HAS_JAX else data["phi_angles_list"]
    )
    q_values = (
        np.array(data["wavevector_q_list"])
        if HAS_JAX
        else data["wavevector_q_list"]
    )

    phi_file = os.path.join(phi_folder, "phi_angles_list.txt")
    q_file = os.path.join(data_folder, "wavevector_q_list.txt")

    exports = (
        (phi_file, phi_angles, "%.6f", "Phi angles (degrees)"),
        (q_file, q_values, "%.8e", "Wavevector q (1/Angstrom)"),
    )
    try:
        for target, values, number_format, header in exports:
            # Directory creation sits inside the try so that it is covered
            # by the same non-fatal handling as the write itself.
            target_dir = os.path.dirname(target)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            np.savetxt(
                target,
                values,
                fmt=number_format,
                header=header,
                comments="# ",
            )
        logger.debug(f"Text files saved: {phi_file}, {q_file}")
    except OSError as e:
        logger.warning(f"Could not save text files (non-fatal): {e}")
def _validate_loaded_data(
self,
data: dict[str, Any],
validation_settings: dict[str, bool],
) -> None:
"""Perform validation on loaded data.
Args:
data: Loaded data dictionary
validation_settings: Validation configuration
"""
if validation_settings.get("physics_checks", False):
self._perform_physics_validation(data)
if validation_settings.get("data_quality", False):
self._perform_data_quality_checks(
data,
validation_settings.get("comprehensive", False),
)
def _perform_physics_validation(self, data: dict[str, Any]) -> None:
    """Warn about q-values or a time step outside the PhysicsConstants limits.

    Nothing is raised: out-of-range values are only logged. When the
    physics module is unavailable a warning is logged and the method
    returns immediately.

    Args:
        data: Loaded data dictionary containing ``wavevector_q_list``.
    """
    if not HAS_PHYSICS_VALIDATION:
        logger.warning(
            "Physics validation requested but v2 physics module not available",
        )
        return

    # q-range check
    q_values = data["wavevector_q_list"]
    if HAS_JAX:
        q_values = np.array(q_values)

    too_small = np.any(q_values < PhysicsConstants.Q_MIN_TYPICAL)
    if too_small:
        logger.warning(
            f"Some q-values below typical range: {PhysicsConstants.Q_MIN_TYPICAL}",
        )
    too_large = np.any(q_values > PhysicsConstants.Q_MAX_TYPICAL)
    if too_large:
        logger.warning(
            f"Some q-values above typical range: {PhysicsConstants.Q_MAX_TYPICAL}",
        )

    # time-step check
    dt = self.analyzer_config.get("dt", 1.0)
    if dt < PhysicsConstants.TIME_MIN_XPCS:
        logger.warning(
            f"Time step dt={dt}s below typical XPCS minimum: {PhysicsConstants.TIME_MIN_XPCS}s",
        )

    logger.info("Physics validation completed")
def _perform_data_quality_checks(
    self,
    data: dict[str, Any],
    comprehensive: bool = False,
) -> None:
    """Perform data quality validation.

    Findings are logged only; nothing is raised for bad data. The checks
    cover non-finite values, negative values and the mean of the first
    diagonal element (expected near 1.0). With ``comprehensive`` the
    relative drop between diagonal elements 0 and 10 is reported as well.

    Args:
        data: Loaded data dictionary containing ``c2_exp``, expected as a
            stack of square matrices with shape (n, frames, frames).
        comprehensive: Also report the mean correlation decay.
    """
    c2_exp = np.asarray(data["c2_exp"])

    if c2_exp.ndim != 3 or c2_exp.size == 0:
        # Indexing the first diagonal element below would fail on such input.
        logger.warning(
            "Data quality checks skipped: c2_exp is empty or not a stack of "
            f"2D matrices (shape {c2_exp.shape})",
        )
        return

    # Basic checks
    if not np.all(np.isfinite(c2_exp)):
        logger.error("Correlation data contains non-finite values (NaN or Inf)")
    if np.any(c2_exp < 0):
        logger.warning("Correlation data contains negative values")

    # Check for reasonable correlation values (should be around 1.0 at t=0)
    # Diagonals of all matrices at once, shape (n, frames).
    diagonal_values = np.diagonal(c2_exp, axis1=1, axis2=2)
    mean_diagonal = np.nanmean(diagonal_values[:, 0])  # t=0 correlation
    if not (0.5 < mean_diagonal < 2.0):
        logger.warning(
            f"Unusual t=0 correlation value: {mean_diagonal:.3f} (expected ~1.0)",
        )

    if comprehensive:
        # Additional comprehensive checks
        logger.info("Performing comprehensive data quality analysis...")
        # Check correlation decay (needs a diagonal element at index 10)
        if diagonal_values.shape[1] > 10:
            first = diagonal_values[:, 0]
            decay_rates = (first - diagonal_values[:, 10]) / first
            mean_decay = np.nanmean(decay_rates)
            logger.info(
                f"Mean correlation decay over 10 time steps: {mean_decay:.3f}",
            )

    logger.info("Data quality validation completed")
def _initialize_quality_control(self) -> Any | None:
    """Create the data quality controller when the configuration enables it.

    Returns:
        A ``DataQualityController`` carrying a ``QualityControlStage``
        attribute, or None when quality control is disabled, cannot be
        imported, or fails to initialise with a ValueError, KeyError,
        AttributeError or TypeError.
    """
    try:
        enabled = self.config.get("quality_control", {}).get("enabled", False)
        if not enabled:
            logger.debug("Quality control disabled in configuration")
            return None

        # Imported lazily, only once quality control is known to be enabled.
        from homodyne.data.quality_controller import (
            DataQualityController,
            QualityControlStage,
        )

        logger.info("Initializing data quality control system")
        controller = DataQualityController(self.config)
        # Expose the stage enum on the controller for later lookups.
        controller.QualityControlStage = QualityControlStage  # type: ignore
    except ImportError as e:
        logger.warning(f"Quality control system not available: {e}")
        return None
    except (ValueError, KeyError, AttributeError, TypeError) as e:
        # Deliberately narrow: MemoryError, SystemExit and
        # KeyboardInterrupt are left to propagate.
        logger.error(f"Failed to initialize quality control: {e}")
        return None
    else:
        return controller
def _get_quality_report_path(self) -> str:
"""Generate path for quality control report."""
data_folder = self.exp_config.get("data_folder_path", "./")
data_file = self.exp_config.get("data_file_name", "unknown")
data_file_base = os.path.splitext(data_file)[0]
# Create quality reports subdirectory
quality_dir = os.path.join(data_folder, "quality_reports")
os.makedirs(quality_dir, exist_ok=True)
# Generate filename with timestamp
timestamp = int(time.time())
quality_filename = f"{data_file_base}_quality_report_{timestamp}.json"
return os.path.join(quality_dir, quality_filename)
@log_performance(threshold=0.5)
def _apply_preprocessing_pipeline(
    self,
    data: dict[str, Any],
    quality_controller: Any | None = None,
    quality_results: list | None = None,
) -> dict[str, Any]:
    """Apply preprocessing pipeline to loaded data if enabled.

    Args:
        data: Raw data loaded from HDF5 files
        quality_controller: Optional quality controller; when given together
            with ``quality_results`` the preprocessed data is validated.
        quality_results: Optional list collecting validation results. The
            new result is appended in place; its last entry (if any) is
            passed on as ``previous_result``. An empty list is accepted.

    Returns:
        Processed data after applying preprocessing pipeline, or the
        original ``data`` when preprocessing is disabled, unavailable, or
        failed with ``fallback_on_failure`` enabled.

    Raises:
        XPCSDataFormatError: If preprocessing fails and
            ``fallback_on_failure`` is disabled.
    """
    # Read once, before the try block, so both the normal path and the
    # error handler use the same settings.
    preprocessing_config = self.config.get("preprocessing", {})
    if not preprocessing_config.get("enabled", False):
        logger.debug("Preprocessing pipeline disabled in configuration")
        return data

    try:
        logger.info("Applying preprocessing pipeline to loaded data")

        # Import preprocessing pipeline
        from homodyne.data.preprocessing import PreprocessingPipeline

        # Create and execute preprocessing pipeline
        pipeline = PreprocessingPipeline(self.config)
        result = pipeline.process(data)

        if not result.success:
            logger.error("Preprocessing pipeline failed")
            # Log errors
            for error in result.provenance.errors:
                logger.error(f"Preprocessing error: {error}")
            # Return original data if fallback is enabled
            if preprocessing_config.get("fallback_on_failure", True):
                logger.warning(
                    "Falling back to original data after preprocessing failure",
                )
                return data
            raise XPCSDataFormatError(
                "Preprocessing pipeline failed and fallback disabled",
            )

        logger.info("Preprocessing pipeline completed successfully")
        logger.info(f"Pipeline stages executed: {len(result.stage_results)}")

        # Log stage results
        successful_stages = sum(result.stage_results.values())
        total_stages = len(result.stage_results)
        logger.info(f"Successful stages: {successful_stages}/{total_stages}")

        # Quality control validation after preprocessing. Test for None
        # rather than truthiness: an empty results list must still trigger
        # validation (with no previous result).
        if quality_controller and quality_results is not None:
            preprocessing_validation_result = (
                quality_controller.validate_data_stage(
                    result.data,
                    quality_controller.QualityControlStage.PREPROCESSED_DATA,
                    previous_result=(
                        quality_results[-1] if quality_results else None
                    ),
                )
            )
            quality_results.append(preprocessing_validation_result)
            if not preprocessing_validation_result.passed:
                logger.warning(
                    f"Preprocessing quality validation failed: score={preprocessing_validation_result.metrics.overall_score:.1f}",
                )

        # Save provenance if requested
        if preprocessing_config.get("save_provenance", False):
            provenance_path = self._get_provenance_path()
            pipeline.save_provenance(result.provenance, provenance_path)

        # Log warnings if any
        if result.provenance.warnings:
            for warning in result.provenance.warnings:
                logger.warning(f"Preprocessing warning: {warning}")

        return result.data
    except ImportError as e:
        logger.warning(
            f"Preprocessing pipeline not available: {e}. Using original data.",
        )
        return data
    except (ValueError, KeyError, IndexError, RuntimeError) as e:
        # Narrowed from broad Exception: only catch expected processing errors.
        # Programming bugs (AttributeError, TypeError) and system errors
        # (MemoryError, KeyboardInterrupt) must propagate without swallowing.
        logger.error(f"Unexpected error in preprocessing pipeline: {e}")
        if preprocessing_config.get("fallback_on_failure", True):
            logger.warning(
                "Falling back to original data after preprocessing error",
            )
            return data
        raise XPCSDataFormatError(f"Preprocessing pipeline failed: {e}") from e
def _get_provenance_path(self) -> str:
    """Build the output path for a preprocessing provenance JSON file.

    The file is placed in a ``preprocessing_provenance`` subdirectory of
    ``exp_config["data_folder_path"]`` (default ``"./"``). That subdirectory
    is created if it does not already exist; the provenance file itself is
    not created here.

    The filename is ``<stem>_preprocessing_provenance_<timestamp>.json``,
    where ``<stem>`` is the extension-less final path component of
    ``exp_config["data_file_name"]`` and ``<timestamp>`` is the current Unix
    time truncated to whole seconds.

    Returns:
        The provenance file path as a string.
    """
    data_folder = self.exp_config.get("data_folder_path", "./")

    provenance_dir = os.path.join(data_folder, "preprocessing_provenance")
    os.makedirs(provenance_dir, exist_ok=True)

    # A missing, None or empty file name falls back to "unknown" instead of
    # crashing in os.path.splitext or producing a name with an empty stem.
    data_file = self.exp_config.get("data_file_name") or "unknown"

    # Keep only the final path component. Without this, an absolute
    # data_file_name makes os.path.join discard provenance_dir entirely, and
    # a relative one with subdirectories points into a directory that was
    # never created.
    data_file_base = os.path.splitext(os.path.basename(data_file))[0] or "unknown"

    # Whole-second resolution: two calls within the same second for the same
    # data file yield the same path.
    timestamp = int(time.time())
    provenance_filename = (
        f"{data_file_base}_preprocessing_provenance_{timestamp}.json"
    )
    return os.path.join(provenance_dir, provenance_filename)
# Convenience function for simple usage
@log_performance(threshold=1.0)
def load_xpcs_data(
    config_path: str | dict | None = None,
    config_dict: dict | None = None,
) -> dict[str, Any]:
    """Load XPCS experimental data in a single call.

    Builds an ``XPCSDataLoader`` from the supplied configuration and returns
    the result of its ``load_experimental_data()`` method. The configuration
    may be given as a file path or as an in-memory dictionary; for backward
    compatibility a dictionary is also accepted in the ``config_path``
    position.

    Args:
        config_path: Path to a configuration file, OR a configuration dict
            (legacy positional usage). A dict here is forwarded to the loader
            as ``config_dict``.
        config_dict: Configuration dictionary (alternative to config_path).

    Returns:
        Whatever ``XPCSDataLoader.load_experimental_data()`` returns for the
        given configuration.

    Raises:
        ValueError: If ``config_path`` is a dict and ``config_dict`` is also
            supplied.

    Example:
        >>> # From config file
        >>> data = load_xpcs_data(config_path="xpcs_config.yaml")
        >>> print(data.keys())
        dict_keys(['wavevector_q_list', 'phi_angles_list', 't1', 't2', 'c2_exp'])
        >>> # From dict (backward compatible - positional)
        >>> config = {"data_file": "experiment.h5", "analysis_mode": "static_isotropic"}
        >>> data = load_xpcs_data(config)
        >>> # From dict (keyword argument)
        >>> data = load_xpcs_data(config_dict=config)
    """
    if not isinstance(config_path, dict):
        # Normal usage: pass both arguments through untouched.
        path_arg, dict_arg = config_path, config_dict
    elif config_dict is None:
        # Legacy usage: the dict arrived in the config_path slot.
        path_arg, dict_arg = None, config_path
    else:
        # Two dictionaries were supplied; there is no way to pick one.
        raise ValueError(
            "Cannot provide both config_path as dict and config_dict parameter"
        )

    return XPCSDataLoader(
        config_path=path_arg,
        config_dict=dict_arg,
    ).load_experimental_data()
# Export main classes and functions.
# __all__ declares the public names bound by `from <module> import *`.
# NOTE(review): the exception classes and "load_xpcs_config" are defined
# outside this section of the file — confirm every listed name exists at
# module level, since a missing one makes a star-import raise AttributeError.
__all__ = [
"XPCSDataLoader",
"load_xpcs_data",
"XPCSDataFormatError",
"XPCSDependencyError",
"XPCSConfigurationError",
"load_xpcs_config",
]