homodyne.core.jax_backend¶
The jax_backend module provides the JIT-compiled computational core for
homodyne scattering analysis. All physics functions in this module are
compiled with jax.jit and support automatic differentiation for NLSQ
Jacobian evaluation and NUTS leapfrog integration.
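The physical model computed here combines a diffusion contribution and a shear contribution (sinc² values) into a total g₁ correlation function; g₂ is then obtained from g₁ using a contrast (β) and a baseline offset.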
Warning
This module uses jnp.where(x > eps, x, eps) instead of
jnp.maximum(x, eps) for positivity floors. jnp.maximum zeros
the gradient below the floor, stalling NLSQ Jacobian and NUTS leapfrog.
See Anti-Degeneracy Defense System for details.
Backend Availability¶
| Flag | Description |
|---|---|
|  |  |
|  |  |
g₁ Correlation Functions¶
compute_g1_diffusion¶
- homodyne.core.jax_backend.compute_g1_diffusion(params, t1, t2, q, dt=None)
Compute the diffusion contribution to g₁, using the time step dt supplied by the configuration.
IMPORTANT: dt must come from the configuration; it should not be estimated from the time grid.
- Parameters:
params (Array) – Physical parameters [D0, alpha, D_offset, …]
t1 (Array) – Time grid (t1 and t2 should be identical: t1 = t2 = t)
t2 (Array) – Time grid (t1 and t2 should be identical: t1 = t2 = t)
q (float) – Scattering wave vector magnitude
dt (float | None) – Time step from the configuration; required for correct physics even though the signature defaults to None
- Return type:
Array
- Returns:
Diffusion contribution to the g₁ correlation function
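To show why dt has to match the time grid, here is a minimal sketch of a diffusion term. It assumes D(t) = D0·t^alpha + D_offset (inferred from the parameter names), a uniform grid t_k = k·dt, and a rectangle-rule cumulative integral; `g1_diffusion_sketch` is a hypothetical name and this is not the module's implementation.

```python
import jax.numpy as jnp


def g1_diffusion_sketch(params, t, q, dt):
    """Hypothetical sketch: exp(-q^2/2 * |int_{t1}^{t2} D(t') dt'|) on a t x t grid.

    Assumes D(t) = D0 * t**alpha + D_offset and a uniform grid with spacing dt.
    """
    D0, alpha, D_offset = params[0], params[1], params[2]
    D_t = D0 * t**alpha + D_offset
    # Positivity floor written in the jnp.where style described in the module warning.
    eps = 1e-10
    D_t = jnp.where(D_t > eps, D_t, eps)
    # Cumulative rectangle-rule integral: I[k] approximates int_0^{t_k} D dt'.
    I = jnp.cumsum(D_t) * dt
    # Pairwise |I[j] - I[i]| approximates |int_{t_i}^{t_j} D dt'|.
    delta = jnp.abs(I[None, :] - I[:, None])
    return jnp.exp(-0.5 * q**2 * delta)
```

Because the grid spacing enters only through dt, passing a dt that differs from the real spacing of t rescales every integral; in this sketch that would silently bias a fitted D0.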
compute_g1_shear¶
- homodyne.core.jax_backend.compute_g1_shear(params, t1, t2, phi, q, L, dt)
Compute the shear contribution to g₁, using the time step dt supplied by the configuration.
IMPORTANT: dt MUST come from the configuration. There is no fallback estimation; an explicit dt is required for correct physics.
- Parameters:
params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
t1 (Array) – Time grid (t1 and t2 should be identical: t1 = t2 = t)
t2 (Array) – Time grid (t1 and t2 should be identical: t1 = t2 = t)
phi (Array) – Scattering angles
q (float) – Scattering wave vector magnitude
L (float) – Stator-rotor gap (stator_rotor_gap in the configuration)
dt (float) – Time step from the configuration [s] (required)
- Return type:
Array
- Returns:
Shear contribution to the g₁ correlation function (sinc² values)
- Raises:
TypeError – If dt is None (None is no longer accepted)
ValueError – If dt <= 0 or is not finite
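A sketch of a sinc² shear term on an (n_phi, n_times, n_times) grid, under several assumptions not stated on this page: γ̇(t) = gamma_dot_t0·t^beta + gamma_dot_t_offset, a phase proportional to (q·L/2π)·cos(phi0 − phi)·∫γ̇ dt', angles in degrees, and JAX's normalized `jnp.sinc` (sin(πx)/(πx)). The function name is hypothetical and the sinc convention may differ from the module's.

```python
import jax.numpy as jnp


def g1_shear_sketch(params, t, phi_deg, q, L, dt):
    """Hypothetical sketch of a sinc^2 shear term with shape (n_phi, n, n)."""
    gamma0, beta, gamma_offset, phi0 = params[3], params[4], params[5], params[6]
    gamma_t = gamma0 * t**beta + gamma_offset
    # Cumulative rectangle-rule integral of the shear rate on a uniform grid.
    G = jnp.cumsum(gamma_t) * dt
    delta = jnp.abs(G[None, :] - G[:, None])             # (n, n)
    # Angular factor; phi and phi0 are assumed to be in degrees.
    cos_term = jnp.cos(jnp.deg2rad(phi0 - phi_deg))      # (n_phi,)
    phase = (q * L / (2 * jnp.pi)) * cos_term[:, None, None] * delta[None, :, :]
    # jnp.sinc is the normalized sinc, sin(pi x) / (pi x).
    return jnp.sinc(phase) ** 2
```

Broadcasting the (n_phi,) angular factor against the (n, n) integral grid yields one sinc² map per angle; at an angle where cos(phi0 − phi) = 0 the term is 1, i.e. shear causes no decorrelation there.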
compute_g1_total¶
- homodyne.core.jax_backend.compute_g1_total(params, t1, t2, phi, q, L, dt)
Compute the total g₁ (diffusion and shear contributions), using the time step dt supplied by the configuration.
IMPORTANT: dt MUST come from the configuration. There is no fallback estimation; an explicit dt is required for correct physics.
- Parameters:
params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
t1 (Array) – Time grid (t1 and t2 should be identical: t1 = t2 = t)
t2 (Array) – Time grid (t1 and t2 should be identical: t1 = t2 = t)
phi (Array) – Scattering angles
q (float) – Scattering wave vector magnitude
L (float) – Stator-rotor gap (stator_rotor_gap in the configuration)
dt (float | None) – Time step from the configuration [s]; required, and None raises TypeError
- Return type:
Array
- Returns:
Total g₁ correlation function with shape (n_phi, n_times, n_times)
- Raises:
TypeError – If dt is None (None is no longer accepted)
ValueError – If dt <= 0 or is not finite
g₂ Correlation Functions¶
compute_g2_scaled¶
Primary (non-batched) entry point for NLSQ optimization.
- homodyne.core.jax_backend.compute_g2_scaled(params, t1, t2, phi, q, L, contrast, offset, dt)
Compute g₂ using the time step dt supplied by the configuration.
IMPORTANT: dt MUST come from the configuration. There is no fallback estimation; an explicit dt is required for correct physics.
- Parameters:
params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
t1 (Array) – Time points for the correlation calculation
t2 (Array) – Time points for the correlation calculation
phi (Array) – Scattering angles
q (float) – Scattering wave vector magnitude
L (float) – Stator-rotor gap (stator_rotor_gap in the configuration)
contrast (float) – Contrast parameter (β in the literature)
offset (float) – Baseline offset
dt (float | None) – Time step from the configuration [s]; required, and None raises TypeError
- Return type:
Array
- Returns:
g₂ correlation function with contrast/offset scaling and physical bounds applied
- Raises:
TypeError – If dt is None (None is no longer accepted)
ValueError – If dt <= 0 or is not finite
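The documented dt contract (TypeError for None, ValueError for non-positive or non-finite values) can be expressed as a small pre-check, shown here together with a Siegert-type scaling g₂ = offset + contrast·g₁². Both helpers are hypothetical, and the scaling form is an assumption based on the contrast/offset parameters; the physical bounds the module applies are not reproduced.

```python
import math

import jax.numpy as jnp


def validate_dt(dt):
    """Hypothetical check mirroring the documented Raises section."""
    if dt is None:
        raise TypeError("dt is required; pass the time step from the configuration")
    dt = float(dt)
    if not math.isfinite(dt) or dt <= 0.0:
        raise ValueError(f"dt must be positive and finite, got {dt!r}")
    return dt


def g2_from_g1(g1, contrast, offset):
    """Assumed Siegert-type scaling: g2 = offset + contrast * g1**2."""
    return offset + contrast * g1**2


dt = validate_dt(0.001)
print(g2_from_g1(jnp.array([1.0, 0.0]), contrast=0.5, offset=1.0))  # [1.5 1. ]
```

A check like this belongs outside any jitted function (or on a static argument): if dt were a traced value inside jax.jit, the Python `if` statements would fail.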
compute_g2_scaled_with_factors¶
JIT-compiled variant accepting pre-computed physics factors. Used by HomodyneModel to avoid redundant factor computation in tight loops.
- homodyne.core.jax_backend.compute_g2_scaled_with_factors(params, t1, t2, phi, wavevector_q_squared_half_dt, sinc_prefactor, contrast, offset, dt)
JIT-optimized g₂ computation using pre-computed physics factors.
This is the functional core of the hybrid architecture: it accepts the pre-computed factors directly instead of deriving them at run time, which suits HomodyneModel, where the factors are computed once at initialization.
- Parameters:
params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
t1 (Array) – Time grid for the correlation calculation
t2 (Array) – Time grid for the correlation calculation
phi (Array) – Scattering angles [degrees]
wavevector_q_squared_half_dt (float) – Pre-computed factor (0.5 * q² * dt)
sinc_prefactor (float) – Pre-computed factor (q * L * dt / 2π)
contrast (float) – Contrast parameter (β in the literature)
offset (float) – Baseline offset
dt (float) – Time step from the experimental configuration (time per frame) [s]
- Return type:
Array
- Returns:
g₂ correlation function with contrast/offset scaling
Note
This function is JIT-compiled. Use it with HomodyneModel, which supplies the pre-computed factors.
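The two pre-computed factors follow directly from their stated definitions. A sketch with the values used in the usage examples on this page (q = 0.01, L = 2 000 000, dt = 0.001); the helper name is hypothetical:

```python
import math


def precompute_factors(q, L, dt):
    """Hypothetical helper; names follow compute_g2_scaled_with_factors."""
    wavevector_q_squared_half_dt = 0.5 * q**2 * dt
    sinc_prefactor = q * L * dt / (2 * math.pi)
    return wavevector_q_squared_half_dt, sinc_prefactor


diff_factor, sinc_factor = precompute_factors(q=0.01, L=2_000_000.0, dt=0.001)
print(f"{diff_factor:.3e}")  # 5.000e-08
print(f"{sinc_factor:.4f}")  # 3.1831
```

Both factors depend only on q, L, and dt, so computing them once at initialization leaves only the parameter-dependent time integrals inside the optimization loop.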
Chi-Squared¶
compute_chi_squared¶
- homodyne.core.jax_backend.compute_chi_squared(params, data, sigma, t1, t2, phi, q, L, contrast, offset, dt)
Compute the chi-squared goodness of fit.
χ² = Σᵢ [(dataᵢ − theoryᵢ) / σᵢ]²
- Parameters:
params (Array) – Physical parameters
data (Array) – Experimental correlation data
sigma (Array) – Measurement uncertainties
t1 (Array) – Time grid
t2 (Array) – Time grid
phi (Array) – Scattering angles
q (float) – Wave vector magnitude
L (float) – Stator-rotor gap (stator_rotor_gap in the configuration)
contrast (float) – Contrast parameter (β)
offset (float) – Baseline offset
dt (float) – Time step from the configuration
- Returns:
Chi-squared value
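The χ² definition above maps onto a single JAX reduction over all data points. A minimal sketch (the name `chi_squared_sketch` is hypothetical, and any masking or zero-σ handling the module may perform is not reproduced):

```python
import jax
import jax.numpy as jnp


@jax.jit
def chi_squared_sketch(data, theory, sigma):
    """chi^2 = sum_i ((data_i - theory_i) / sigma_i)^2, summed over all axes."""
    residuals = (data - theory) / sigma
    return jnp.sum(residuals**2)


data = jnp.array([1.2, 0.9, 1.05])
theory = jnp.array([1.0, 1.0, 1.0])
sigma = jnp.array([0.1, 0.1, 0.05])
print(chi_squared_sketch(data, theory, sigma))  # approximately 6.0 (4 + 1 + 1)
```

Because the result is a scalar, jax.grad can differentiate it directly with respect to any array that feeds `theory`.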
Batched Computations¶
vectorized_g2_computation¶
- homodyne.core.jax_backend.vectorized_g2_computation(params_batch, t1, t2, phi, q, L, contrast, offset, dt=None)
Vectorized g₂ computation for multiple parameter sets.
Uses jax.vmap to evaluate all parameter sets in a single batched call.
- Parameters:
params_batch (Array) – Batch of parameter arrays, shape (n_batch, n_params)
t1 (Array) – Time array for the correlation calculation
t2 (Array) – Time array for the correlation calculation
phi (Array) – Scattering angles
q (float) – Wave vector magnitude [Å⁻¹]
L (float) – Stator-rotor gap (stator_rotor_gap) [Å]
contrast (float) – Contrast parameter
offset (float) – Baseline offset
dt (float | None) – Time step from the configuration [s]. Must be provided for correct physics despite the None default.
- Return type:
Array
batch_chi_squared¶
- homodyne.core.jax_backend.batch_chi_squared(params_batch, data, sigma, t1, t2, phi, q, L, contrast, offset, dt=None)
Compute chi-squared for multiple parameter sets in a single batched call.
- Parameters:
params_batch (Array) – Batch of parameter arrays, shape (n_batch, n_params)
data (Array) – Experimental g₂ data
sigma (Array) – Uncertainty in the data
t1 (Array) – Time array for the correlation calculation
t2 (Array) – Time array for the correlation calculation
phi (Array) – Scattering angles
q (float) – Wave vector magnitude [Å⁻¹]
L (float) – Stator-rotor gap (stator_rotor_gap) [Å]
contrast (float) – Contrast parameter
offset (float) – Baseline offset
dt (float | None) – Time step from the configuration [s]. Must be provided for correct physics despite the None default.
- Return type:
Array
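Both batched functions map over a parameter batch of shape (n_batch, n_params). The pattern can be sketched with jax.vmap, mapping only over the first argument via `in_axes`; the single-exponential toy model below is an assumption standing in for the full g₂, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_g2(params, tau, contrast, offset):
    """Stand-in model (assumption): g2 = offset + contrast * exp(-2 * Gamma * tau)."""
    gamma = params[0]
    return offset + contrast * jnp.exp(-2.0 * gamma * tau)


def toy_chi2(params, data, sigma, tau, contrast, offset):
    theory = toy_g2(params, tau, contrast, offset)
    return jnp.sum(((data - theory) / sigma) ** 2)


# Map over axis 0 of params_batch only; every other argument is shared.
batch_chi2 = jax.jit(jax.vmap(toy_chi2, in_axes=(0, None, None, None, None, None)))

tau = jnp.linspace(0.0, 1.0, 50)
sigma = jnp.full_like(tau, 0.01)
data = toy_g2(jnp.array([3.0]), tau, 0.5, 1.0)          # synthetic, noise-free
params_batch = jnp.array([[1.0], [2.0], [3.0], [4.0]])  # (n_batch, n_params)

chi2 = batch_chi2(params_batch, data, sigma, tau, 0.5, 1.0)
print(chi2.shape)             # (4,)
print(int(jnp.argmin(chi2)))  # 2, the row with Gamma = 3.0
```

One χ² per row comes back as a (n_batch,) array, which makes grid scans or population-based initial guesses a single argmin.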
Automatic Differentiation¶
These are pre-built, JIT-compiled derivative functions:
| Function | Derivative | Description |
|---|---|---|
|  |  | Gradient of g₂ w.r.t. params (argnums=0) |
|  |  | Hessian of g₂ w.r.t. params |
|  |  | Gradient of χ² w.r.t. params |
|  |  | Hessian of χ² w.r.t. params |
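The derivatives in the table correspond to standard JAX transforms. One caveat: jax.grad requires a scalar output, so it applies directly to χ², whereas an array-valued g₂ needs jax.jacfwd or jax.jacrev (or a reduction first). The sketch uses a toy two-parameter model as an assumption; it does not reproduce the module's pre-built functions.

```python
import jax
import jax.numpy as jnp

tau = jnp.linspace(0.0, 1.0, 20)


def toy_g2(params, tau):
    """Stand-in model (assumption): params = [contrast, Gamma]."""
    contrast, gamma = params
    return 1.0 + contrast * jnp.exp(-2.0 * gamma * tau)


def toy_chi2(params, data, sigma):
    return jnp.sum(((data - toy_g2(params, tau)) / sigma) ** 2)


params = jnp.array([0.5, 3.0])
data = toy_g2(jnp.array([0.4, 2.5]), tau)
sigma = jnp.full_like(tau, 0.01)

grad_chi2 = jax.jit(jax.grad(toy_chi2, argnums=0))     # scalar output: grad works
hess_chi2 = jax.jit(jax.hessian(toy_chi2, argnums=0))  # (n_params, n_params)
jac_g2 = jax.jit(jax.jacfwd(toy_g2, argnums=0))        # one row per tau point

print(grad_chi2(params, data, sigma).shape)  # (2,)
print(hess_chi2(params, data, sigma).shape)  # (2, 2)
print(jac_g2(params, tau).shape)             # (20, 2)
```

Calling jax.grad on `toy_g2` itself raises a TypeError because its output has shape (20,).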
Meshgrid Cache¶
The module maintains an LRU cache for time-grid meshgrids to avoid
recomputing (t1_grid, t2_grid) on every function call.
- homodyne.core.jax_backend.get_cached_meshgrid(t1, t2)
Get or create a cached meshgrid for the given time arrays.
Optimization loops call this repeatedly with the same time arrays; caching avoids rebuilding the same meshgrid once per phi angle on every iteration (about 23 times per iteration for a 23-angle dataset).
When called inside a JIT context (traced arrays), caching is skipped and the meshgrid is created directly; the traced computation is compiled once, so no Python-level cache is needed.
Performance Optimization (Spec 006, FR-010, T041): increments hit/miss counters for cache monitoring.
- homodyne.core.jax_backend.clear_meshgrid_cache()
Clear the meshgrid cache.
Call this when switching between datasets or when memory is constrained.
- Return type:
None
- homodyne.core.jax_backend.get_cache_stats()
Get meshgrid cache statistics.
Performance Optimization (Spec 006, FR-010, T042): returns cache hit/miss statistics for monitoring and optimization.
- Returns:
Dictionary with the following entries:
hits: Number of cache hits
misses: Number of cache misses
evictions: Number of cache evictions
skipped_large: Number of lookups skipped because the arrays were too large to cache
skipped_traced: Number of lookups skipped because of JIT tracing
hit_rate: Cache hit rate (hits / total lookups)
cache_size: Current number of cached entries
- Return type:
dict
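The caching behavior described above (LRU eviction, hit/miss/eviction counters, skipping very large arrays) can be sketched with `collections.OrderedDict` keyed on array contents. Everything here, including the size limits and the `MeshgridCache` name, is a hypothetical illustration; the skip for traced arrays is omitted because this sketch works on NumPy arrays only.

```python
from collections import OrderedDict

import numpy as np


class MeshgridCache:
    """Hypothetical LRU cache for (t1_grid, t2_grid) pairs."""

    def __init__(self, max_entries=8, max_elements=10_000_000):
        self.max_entries = max_entries
        self.max_elements = max_elements
        self._store = OrderedDict()
        self.stats = {"hits": 0, "misses": 0, "evictions": 0, "skipped_large": 0}

    def get(self, t1, t2):
        t1, t2 = np.asarray(t1), np.asarray(t2)
        if t1.size * t2.size > self.max_elements:
            # Too large to keep around: build the grid but do not cache it.
            self.stats["skipped_large"] += 1
            return np.meshgrid(t1, t2, indexing="ij")
        # Key on dtype, shape, and raw bytes so equal arrays share one entry.
        key = (t1.dtype.str, t1.shape, t1.tobytes(), t2.dtype.str, t2.shape, t2.tobytes())
        if key in self._store:
            self.stats["hits"] += 1
            self._store.move_to_end(key)     # mark as most recently used
            return self._store[key]
        self.stats["misses"] += 1
        grids = np.meshgrid(t1, t2, indexing="ij")
        self._store[key] = grids
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # evict the least recently used entry
            self.stats["evictions"] += 1
        return grids

    def clear(self):
        self._store.clear()

    def hit_rate(self):
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total else 0.0


cache = MeshgridCache(max_entries=2)
t = np.arange(100) * 0.001
for _ in range(23):  # e.g. once per phi angle within one iteration
    t1_grid, t2_grid = cache.get(t, t)
print(cache.stats["misses"], cache.stats["hits"])  # 1 22
```

Keying on raw bytes means a cache hit requires bit-identical arrays, which matches the repeated-call pattern of an optimization loop; hashing large arrays on every lookup is the price of that exactness.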
Diagnostics¶
validate_backend¶
get_device_info¶
get_performance_summary¶
Usage Examples¶
Computing g₂ for a single angle¶
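from homodyne.core.jax_backend import compute_g2_scaled
import jax.numpy as jnp
# [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]
params = jnp.array([19231.0, 1.5, 0.1, 0.003, 0.8, 0.0, 0.0])
dt = 0.001                  # time step from the configuration [s]
t = jnp.arange(100) * dt    # uniform grid whose spacing matches dt
t1 = t2 = t                 # the g1/g2 functions expect identical time grids
g2 = compute_g2_scaled(
    params, t1, t2, phi=jnp.array([0.0]),
    q=0.01, L=2_000_000.0,
    contrast=0.5, offset=1.0, dt=dt,
)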
Computing gradients for optimization¶
from homodyne.core.jax_backend import gradient_g2
grad_params = gradient_g2(
params, t1, t2, phi=jnp.array([0.0]),
q=0.01, L=2_000_000.0,
contrast=0.5, offset=1.0, dt=0.001,
)
print(f"Gradient shape: {grad_params.shape}") # (7,)
Cache management¶
from homodyne.core.jax_backend import get_cache_stats, clear_meshgrid_cache
stats = get_cache_stats()
print(f"Hit rate: {stats['hit_rate']:.1%}")
# Clear when switching datasets
clear_meshgrid_cache()