homodyne.core.jax_backend

The jax_backend module provides the JIT-compiled computational core for homodyne scattering analysis. All physics functions in this module are compiled with jax.jit and support automatic differentiation for NLSQ Jacobian evaluation and NUTS leapfrog integration.

The physical model computed here is:

\[g_2(\phi, t_1, t_2) = \text{offset} + \text{contrast} \times [g_1(\phi, t_1, t_2)]^2\]
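You can check the expression numerically. The NumPy sketch below evaluates it for a few g₁ values. g2_from_g1 is a hypothetical helper, not part of the module, and unlike compute_g2_scaled it applies no physical bounds.

```python
import numpy as np


def g2_from_g1(g1: np.ndarray, contrast: float, offset: float) -> np.ndarray:
    """Hypothetical helper: g2 = offset + contrast * g1**2 (no bounds applied)."""
    return offset + contrast * np.square(g1)


# Fully correlated (g1 = 1), partially decorrelated, and fully decorrelated (g1 = 0).
g1 = np.array([1.0, 0.5, 0.0])
g2 = g2_from_g1(g1, contrast=0.5, offset=1.0)
print(g2.tolist())  # [1.5, 1.125, 1.0]
```

At g₁ = 1 the value is offset + contrast. As g₁ decays to 0, it relaxes to the baseline offset.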

Warning

This module uses jnp.where(x > eps, x, eps) instead of jnp.maximum(x, eps) for positivity floors. jnp.maximum zeros the gradient below the floor, which stalls NLSQ Jacobian evaluation and NUTS leapfrog integration. See Anti-Degeneracy Defense System for details.


Backend Availability

  • jax_available (bool) – True if JAX imported successfully; all JIT functions are active

  • numpy_gradients_available (bool) – True only when JAX is absent and the NumPy gradient fallback is importable
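By construction, the two flags are never both True. The sketch below shows that logic. It assumes availability is decided by whether each module can be found. detect_backends and its fallback_module argument are hypothetical, and the real module may use try/except imports instead.

```python
import importlib.util


def detect_backends(fallback_module: str = "numpy") -> dict[str, bool]:
    """Hypothetical sketch of the two availability flags (not the module's code).

    The fallback module name is an assumption used only for illustration.
    """
    jax_available = importlib.util.find_spec("jax") is not None
    fallback_importable = importlib.util.find_spec(fallback_module) is not None
    return {
        "jax_available": jax_available,
        # The NumPy gradient path only counts as available when JAX is absent.
        "numpy_gradients_available": (not jax_available) and fallback_importable,
    }


flags = detect_backends()
print(flags)
```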


g₁ Correlation Functions

compute_g1_diffusion

homodyne.core.jax_backend.compute_g1_diffusion(params, t1, t2, q, dt=None)[source]

Wrapper function that computes g1 diffusion using configuration dt.

IMPORTANT: The dt parameter should come from the configuration rather than being computed from the time grid.

Parameters:
  • params (Array) – Physical parameters [D0, alpha, D_offset, …]

  • t1 (Array) – Time grids (should be identical: t1 = t2 = t)

  • t2 (Array) – Time grids (should be identical: t1 = t2 = t)

  • q (float) – Scattering wave vector magnitude

  • dt (float | None) – Time step from configuration (REQUIRED for correct physics)

Return type:

Array

Returns:

Diffusion contribution to g1 correlation function
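A discretised version of the diffusion term shows why dt must come from the configuration. The NumPy sketch below is not the library code. It assumes three things not stated in this reference:

  • a transport coefficient D(t) = D0·t^alpha + D_offset
  • the form g₁ = exp(-½ q² ∫ D dt)
  • an integral approximated by a cumulative sum scaled by dt

The function name g1_diffusion_sketch and the eps floor value are hypothetical.

```python
import numpy as np


def g1_diffusion_sketch(params, t, q, dt, eps=1e-10):
    """Hypothetical NumPy sketch of the diffusion term; not the library code.

    Assumes D(t) = D0 * t**alpha + D_offset and
    g1_diff(t1, t2) = exp(-0.5 * q**2 * |integral of D from t1 to t2|),
    with the integral approximated by a cumulative sum times dt.
    """
    D0, alpha, D_offset = params[0], params[1], params[2]
    # Positivity floor in the jnp.where style described in the Warning above.
    t_safe = np.where(t > eps, t, eps)
    D_t = D0 * t_safe**alpha + D_offset
    cum = np.cumsum(D_t)
    # |cum_i - cum_j| * dt approximates the integral of D between t_i and t_j.
    integral = np.abs(cum[:, None] - cum[None, :]) * dt
    return np.exp(-0.5 * q**2 * integral)


dt = 0.01                     # taken from configuration, not inferred from t
t = dt * np.arange(11)
params = np.array([100.0, 0.5, 1.0])
g1 = g1_diffusion_sketch(params, t, q=0.01, dt=dt)
g1_coarse = g1_diffusion_sketch(params, t, q=0.01, dt=2 * dt)  # wrong dt, same grid
```

Passing 2·dt with the same grid doubles every integral, so g1_coarse decays faster than g1 even though the time values are unchanged. The product 0.5 * q**2 * dt is the same grouping as the wavevector_q_squared_half_dt factor documented for compute_g2_scaled_with_factors.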

compute_g1_shear

homodyne.core.jax_backend.compute_g1_shear(params, t1, t2, phi, q, L, dt)[source]

Wrapper function that computes g1 shear using configuration dt.

IMPORTANT: The dt parameter MUST come from the configuration. There is no fallback estimation; an explicit dt is required for correct physics.

Parameters:
  • params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]

  • t1 (Array) – Time grids (should be identical: t1 = t2 = t)

  • t2 (Array) – Time grids (should be identical: t1 = t2 = t)

  • phi (Array) – Scattering angles

  • q (float) – Scattering wave vector magnitude

  • L (float) – Stator-rotor gap (stator_rotor_gap) [Å]

  • dt (float) – Time step from configuration [s] (REQUIRED)

Return type:

Array

Returns:

Shear contribution to g1 correlation function (sinc² values)
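The shear term can be sketched the same way. This sketch relies on further assumptions not stated in this reference:

  • a shear rate γ̇(t) = gamma_dot_t0·t^beta + gamma_dot_t_offset
  • phi and phi0 given in degrees
  • NumPy's normalised sinc
  • a sinc argument of sinc_prefactor · cos(phi0 - phi) · ∫γ̇ dt, with sinc_prefactor = q·L·dt/(2π) as listed for compute_g2_scaled_with_factors

The function g1_shear_sketch is hypothetical.

```python
import numpy as np


def g1_shear_sketch(params, t, phi_deg, q, L, dt, eps=1e-10):
    """Hypothetical NumPy sketch of the shear term; not the library code.

    Assumes gamma_dot(t) = gamma_dot_t0 * t**beta + gamma_dot_t_offset,
    phi and phi0 in degrees, NumPy's normalised sinc (sin(pi x) / (pi x)), and
    g1_shear = sinc(sinc_prefactor * cos(phi0 - phi) * integral)**2.
    """
    gamma_dot_t0, beta, gamma_dot_t_offset, phi0 = params[3], params[4], params[5], params[6]
    t_safe = np.where(t > eps, t, eps)
    gamma_dot = gamma_dot_t0 * t_safe**beta + gamma_dot_t_offset
    cum = np.cumsum(gamma_dot)
    # Integral of gamma_dot between t_i and t_j, still in units of dt.
    integral = np.abs(cum[:, None] - cum[None, :])
    sinc_prefactor = q * L * dt / (2.0 * np.pi)  # dt enters here
    angle = np.cos(np.deg2rad(phi0 - np.asarray(phi_deg, dtype=float)))  # (n_phi,)
    phase = sinc_prefactor * angle[:, None, None] * integral[None, :, :]
    return np.sinc(phase) ** 2  # shape (n_phi, n_t, n_t)


dt = 0.01
t = dt * np.arange(11)
params = np.array([100.0, 0.5, 1.0, 0.003, 0.8, 0.0, 0.0])
g1_shear = g1_shear_sketch(params, t, phi_deg=[0.0, 90.0], q=0.01, L=2_000_000.0, dt=dt)
```

The phi = 90° slice stays at 1 because cos(phi0 - phi) vanishes there (phi0 = 0 in this example). In this sketch, shear therefore adds no decorrelation at 90° from phi0.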

Raises:

compute_g1_total

homodyne.core.jax_backend.compute_g1_total(params, t1, t2, phi, q, L, dt)[source]

Wrapper function that computes total g1 using configuration dt.

IMPORTANT: The dt parameter MUST come from the configuration. There is no fallback estimation; an explicit dt is required for correct physics.

Parameters:
  • params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]

  • t1 (Array) – Time grids (should be identical: t1 = t2 = t)

  • t2 (Array) – Time grids (should be identical: t1 = t2 = t)

  • phi (Array) – Scattering angles

  • q (float) – Scattering wave vector magnitude

  • L (float) – Stator-rotor gap (stator_rotor_gap) [Å]

  • dt (float | None) – Time step from configuration [s] (REQUIRED)

Return type:

Array

Returns:

Total g1 correlation function with shape (n_phi, n_times, n_times)
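The (n_phi, n_times, n_times) shape comes from broadcasting an angle-independent diffusion term against the angle-dependent shear term. This sketch assumes the total is their product, a common factorisation that is not stated explicitly above. combine_g1 and the constant stand-in arrays are hypothetical.

```python
import numpy as np


def combine_g1(g1_diff: np.ndarray, g1_shear: np.ndarray) -> np.ndarray:
    """Hypothetical helper, assuming g1_total = g1_diff * g1_shear.

    g1_diff has no angular dependence, shape (n_t, n_t); g1_shear has
    shape (n_phi, n_t, n_t). Broadcasting yields (n_phi, n_t, n_t).
    """
    return g1_diff[np.newaxis, :, :] * g1_shear


n_phi, n_t = 3, 4
g1_diff = np.full((n_t, n_t), 0.9)           # stand-in values for illustration
g1_shear = np.full((n_phi, n_t, n_t), 0.5)   # stand-in values for illustration
g1_total = combine_g1(g1_diff, g1_shear)
print(g1_total.shape)  # (3, 4, 4)
```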

Raises:

g₂ Correlation Functions

compute_g2_scaled

Primary scalar entry point for NLSQ optimization.

homodyne.core.jax_backend.compute_g2_scaled(params, t1, t2, phi, q, L, contrast, offset, dt)[source]

Wrapper function that computes g2 using configuration dt.

IMPORTANT: The dt parameter MUST come from the configuration. There is no fallback estimation; an explicit dt is required for correct physics.

Parameters:
  • params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]

  • t1 (Array) – Time points for correlation calculation

  • t2 (Array) – Time points for correlation calculation

  • phi (Array) – Scattering angles

  • q (float) – Scattering wave vector magnitude

  • L (float) – Stator-rotor gap (stator_rotor_gap) [Å]

  • contrast (float) – Contrast parameter (β in literature)

  • offset (float) – Baseline offset

  • dt (float | None) – Time step from configuration [s] (REQUIRED)

Return type:

Array

Returns:

g2 correlation function with scaled fitting and physical bounds applied

Raises:

compute_g2_scaled_with_factors

JIT-compiled variant accepting pre-computed physics factors. Used by HomodyneModel to avoid redundant factor computation in tight loops.

homodyne.core.jax_backend.compute_g2_scaled_with_factors(params, t1, t2, phi, wavevector_q_squared_half_dt, sinc_prefactor, contrast, offset, dt)[source]

JIT-optimized g2 computation using pre-computed physics factors.

This is the functional core of the hybrid architecture: it accepts pre-computed factors directly instead of computing them at runtime. This suits HomodyneModel, which computes the factors once at initialization.

Parameters:
  • params (Array) – Physical parameters [D0, alpha, D_offset, gamma_dot_t0, beta, gamma_dot_t_offset, phi0]

  • t1 (Array) – Time grids for correlation calculation

  • t2 (Array) – Time grids for correlation calculation

  • phi (Array) – Scattering angles [degrees]

  • wavevector_q_squared_half_dt (float) – Pre-computed factor (0.5 * q² * dt)

  • sinc_prefactor (float) – Pre-computed factor (q * L * dt / 2π)

  • contrast (float) – Contrast parameter (β in literature)

  • offset (float) – Baseline offset

  • dt (float) – Time step from experimental configuration (time per frame) [seconds]

Return type:

Array

Returns:

g2 correlation function with scaled fitting

Note

This function is JIT-compiled. For best performance, call it through HomodyneModel, which computes the physics factors once at initialization.
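Both factors depend only on q, L and dt, so you can compute them once per dataset. The sketch below follows the two formulas in the parameter list. precompute_factors is a hypothetical helper, not a module function.

```python
import math


def precompute_factors(q: float, L: float, dt: float) -> tuple[float, float]:
    """Hypothetical helper mirroring the two factor definitions listed above."""
    wavevector_q_squared_half_dt = 0.5 * q**2 * dt
    sinc_prefactor = q * L * dt / (2.0 * math.pi)
    return wavevector_q_squared_half_dt, sinc_prefactor


# Same q, L and dt as the usage examples at the end of this page.
w, s = precompute_factors(q=0.01, L=2_000_000.0, dt=0.001)
print(w, s)  # w ≈ 5e-08, s ≈ 3.183
```

Caching these two scalars lets the JIT-compiled core skip the arithmetic inside tight optimization loops.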


Chi-Squared

compute_chi_squared

homodyne.core.jax_backend.compute_chi_squared(params, data, sigma, t1, t2, phi, q, L, contrast, offset, dt)[source]

Compute chi-squared goodness of fit.

χ² = Σᵢ [(data_i - theory_i) / σᵢ]²

Parameters:
  • params (Array) – Physical parameters

  • data (Array) – Experimental correlation data

  • sigma (Array) – Measurement uncertainties

  • t1 (Array) – Time grids

  • t2 (Array) – Time grids

  • phi (Array) – Angle grid

  • q (float) – Wave vector magnitude

  • L (float) – Stator-rotor gap (stator_rotor_gap) [Å]

  • contrast (float) – Contrast scaling parameter

  • offset (float) – Baseline offset

  • dt (float) – Time step from configuration

Return type:

Array

Returns:

Chi-squared value
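The sum above is a weighted least-squares objective. The NumPy sketch below performs the same reduction but takes the theory array directly, whereas compute_chi_squared builds it from params internally. chi_squared is a hypothetical helper.

```python
import numpy as np


def chi_squared(data: np.ndarray, theory: np.ndarray, sigma: np.ndarray) -> float:
    """Hypothetical helper: sum of squared, sigma-weighted residuals."""
    residuals = (data - theory) / sigma
    return float(np.sum(residuals**2))


data = np.array([1.0, 2.0, 3.0])
theory = np.array([1.0, 1.0, 1.0])
sigma = np.array([1.0, 1.0, 2.0])
print(chi_squared(data, theory, sigma))  # 2.0
```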


Batched Computations

vectorized_g2_computation

homodyne.core.jax_backend.vectorized_g2_computation(params_batch, t1, t2, phi, q, L, contrast, offset, dt=None)[source]

Vectorized g2 computation for multiple parameter sets.

Uses JAX vmap for efficient parallel computation.

Parameters:
  • params_batch (Array) – Batch of parameter arrays, shape (n_batch, n_params)

  • t1 (Array) – Time arrays for correlation calculation

  • t2 (Array) – Time arrays for correlation calculation

  • phi (Array) – Scattering angles

  • q (float) – Wavevector magnitude [Å⁻¹]

  • L (float) – Beam width [Å]

  • contrast (float) – Contrast parameter

  • offset (float) – Baseline offset

  • dt (float | None) – Time step from configuration [seconds]. MUST be provided for correct physics.

Return type:

Array
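The batched call evaluates one parameter set per row of params_batch and stacks the results along a new leading n_batch axis. The NumPy loop below is a reference for that behaviour. toy_g2 is a hypothetical stand-in model, not the module's physics.

```python
import numpy as np


def toy_g2(params: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in model: 1 + 0.5 * exp(-2 * D0 * t)."""
    return 1.0 + 0.5 * np.exp(-2.0 * params[0] * t)


def batched_reference(params_batch: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Loop form of a vmap over axis 0 of params_batch (t is shared)."""
    return np.stack([toy_g2(p, t) for p in params_batch])


params_batch = np.array([[1.0, 1.5], [2.0, 1.5], [4.0, 1.5]])  # (n_batch, n_params)
t = np.linspace(0.0, 1.0, 5)
out = batched_reference(params_batch, t)
print(out.shape)  # (3, 5)
```

In JAX the loop corresponds to jax.vmap(toy_g2, in_axes=(0, None)). Here in_axes=0 maps over the rows of params_batch, and None passes the shared time array unchanged to every call.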

batch_chi_squared

homodyne.core.jax_backend.batch_chi_squared(params_batch, data, sigma, t1, t2, phi, q, L, contrast, offset, dt=None)[source]

Compute chi-squared for multiple parameter sets efficiently.

Parameters:
  • params_batch (Array) – Batch of parameter arrays, shape (n_batch, n_params)

  • data (Array) – Experimental g2 data

  • sigma (Array) – Uncertainty in data

  • t1 (Array) – Time arrays for correlation calculation

  • t2 (Array) – Time arrays for correlation calculation

  • phi (Array) – Scattering angles

  • q (float) – Wavevector magnitude [Å⁻¹]

  • L (float) – Beam width [Å]

  • contrast (float) – Contrast parameter

  • offset (float) – Baseline offset

  • dt (float | None) – Time step from configuration [seconds]. MUST be provided for correct physics.

Return type:

Array
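One use of the batched χ² is to screen candidate parameter vectors and keep the lowest-χ² row, for instance to seed an optimizer. The NumPy reference below performs the per-row reduction. It assumes the theory arrays are already available, whereas the library computes them from params_batch. batch_chi_squared_reference is hypothetical.

```python
import numpy as np


def batch_chi_squared_reference(theory_batch, data, sigma):
    """Hypothetical reference: one chi-squared value per parameter set.

    theory_batch has shape (n_batch, *data.shape).
    """
    residuals = (theory_batch - data[np.newaxis]) / sigma[np.newaxis]
    # Sum over every axis except the leading batch axis.
    return np.sum(residuals**2, axis=tuple(range(1, residuals.ndim)))


data = np.array([[1.0, 2.0], [3.0, 4.0]])
sigma = np.ones_like(data)
theory_batch = np.stack([data + 1.0, data, data - 0.5])
chi2 = batch_chi_squared_reference(theory_batch, data, sigma)
best = int(np.argmin(chi2))  # index of the best-fitting candidate
print(chi2, best)
```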


Automatic Differentiation

These are pre-built JIT-compiled derivative functions:

  • gradient_g2 – grad(compute_g2_scaled): gradient of g₂ w.r.t. params (argnums=0)

  • hessian_g2 – hessian(compute_g2_scaled): Hessian of g₂ w.r.t. params

  • gradient_chi2 – grad(compute_chi_squared): gradient of χ² w.r.t. params

  • hessian_chi2 – hessian(compute_chi_squared): Hessian of χ² w.r.t. params
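gradient_chi2 differentiates a scalar with respect to the 1-D params array, so you can spot-check it against a central finite-difference estimate. The helper below is library-independent. Finite differences serve only as a check here; they are not how the module computes derivatives. A toy quadratic stands in for χ², and central_difference_grad is a hypothetical name. To check the real function, wrap compute_chi_squared so that only params varies, then compare the two results at the same point.

```python
import numpy as np


def central_difference_grad(f, x, h=1e-6):
    """Central finite-difference estimate of the gradient of a scalar function f of a 1-D array."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        grad[i] = (f(x + step) - f(x - step)) / (2.0 * h)
    return grad


def quadratic(x: np.ndarray) -> float:
    """Toy scalar objective standing in for chi-squared; its gradient is 2 * x."""
    return float(np.sum(x**2))


x0 = np.array([1.0, -2.0, 0.5])
estimate = central_difference_grad(quadratic, x0)
print(estimate)  # approximately [2.0, -4.0, 1.0]
```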


Meshgrid Cache

The module maintains an LRU cache for time-grid meshgrids to avoid recomputing (t1_grid, t2_grid) on every function call.

homodyne.core.jax_backend.get_cached_meshgrid(t1, t2)[source]

Get or create cached meshgrid for time arrays.

For repeated calls with the same time arrays (common in optimization loops), this avoids recreating the same meshgrid ~23 times per iteration (once per phi).

When called inside a JIT context (traced arrays), caching is skipped and the meshgrid is created directly. Traced arrays have no concrete values to use as a cache key, so the meshgrid simply becomes part of the compiled computation.

Performance Optimization (Spec 006 - FR-010, T041): Increments hit/miss counters for cache monitoring.

Parameters:
  • t1 (Array) – First time array (1D)

  • t2 (Array) – Second time array (1D)

Return type:

tuple

Returns:

Tuple of (t1_grid, t2_grid) with indexing="ij"
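The caching behaviour described above can be sketched as a small LRU store keyed on array contents. This is a hypothetical NumPy illustration, not the module's implementation. MeshgridCache, max_entries and max_elements are invented names. The skipped_traced counter is omitted because NumPy arrays are never traced; the other counters mirror the keys returned by get_cache_stats.

```python
from collections import OrderedDict

import numpy as np


class MeshgridCache:
    """Hypothetical LRU meshgrid cache with hit/miss counters (not the module's code)."""

    def __init__(self, max_entries: int = 8, max_elements: int = 10_000_000):
        self.max_entries = max_entries      # hypothetical capacity limit
        self.max_elements = max_elements    # hypothetical size limit for caching
        self._store: OrderedDict = OrderedDict()
        self.stats = {"hits": 0, "misses": 0, "evictions": 0, "skipped_large": 0}

    def get(self, t1: np.ndarray, t2: np.ndarray) -> tuple:
        if t1.size * t2.size > self.max_elements:
            self.stats["skipped_large"] += 1
            return tuple(np.meshgrid(t1, t2, indexing="ij"))
        # Key on shape, dtype and raw bytes so equal arrays share one entry.
        key = (t1.shape, t2.shape, t1.dtype.str, t2.dtype.str, t1.tobytes(), t2.tobytes())
        if key in self._store:
            self.stats["hits"] += 1
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        self.stats["misses"] += 1
        grids = tuple(np.meshgrid(t1, t2, indexing="ij"))
        self._store[key] = grids
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)  # drop the least recently used entry
            self.stats["evictions"] += 1
        return grids

    def hit_rate(self) -> float:
        lookups = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / lookups if lookups else 0.0


cache = MeshgridCache(max_entries=2)
t = np.linspace(0.0, 1.0, 4)
for _ in range(3):
    T1, T2 = cache.get(t, t)
print(cache.stats, cache.hit_rate())
```

Because the key combines raw bytes with shape and dtype, separately created arrays with identical values share one entry.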

homodyne.core.jax_backend.clear_meshgrid_cache()[source]

Clear the meshgrid cache.

Call this when switching between datasets or when memory is constrained.

Return type:

None

homodyne.core.jax_backend.get_cache_stats()[source]

Get meshgrid cache statistics.

Performance Optimization (Spec 006 - FR-010, T042): Returns cache hit/miss statistics for monitoring and optimization.

Returns:

  • hits: Number of cache hits

  • misses: Number of cache misses

  • evictions: Number of cache evictions

  • skipped_large: Number of calls not cached because the arrays were too large

  • skipped_traced: Number of calls not cached because the inputs were JIT-traced

  • hit_rate: Cache hit rate (hits / total lookups)

  • cache_size: Current number of cached entries

Return type:

dict[str, int | float]

homodyne.core.jax_backend.reset_cache_stats()[source]

Reset cache statistics counters.

Performance Optimization (Spec 006 - FR-010): Call before benchmarking to get clean statistics.

Return type:

None


Diagnostics

validate_backend

homodyne.core.jax_backend.validate_backend()[source]

Validate computational backends with comprehensive diagnostics.

Return type:

dict[str, Any]

get_device_info

homodyne.core.jax_backend.get_device_info()[source]

Get comprehensive device and backend information.

Return type:

dict[str, Any]

get_performance_summary

homodyne.core.jax_backend.get_performance_summary()[source]

Get performance summary and recommendations.

Return type:

dict[str, Any]


Usage Examples

Computing g₂ for a single angle

from homodyne.core.jax_backend import compute_g2_scaled
import jax.numpy as jnp

dt = 0.001  # frame time from the experimental configuration [s]
params = jnp.array([19231.0, 1.5, 0.1, 0.003, 0.8, 0.0, 0.0])
t = dt * jnp.arange(100)  # grid spacing matches dt exactly
t1, t2 = jnp.meshgrid(t, t, indexing="ij")

g2 = compute_g2_scaled(
    params, t1, t2, phi=jnp.array([0.0]),
    q=0.01, L=2_000_000.0,
    contrast=0.5, offset=1.0, dt=dt,
)

Computing gradients for optimization

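from homodyne.core.jax_backend import gradient_chi2

# jax.grad is defined only for scalar-valued functions; chi-squared is a scalar,
# so its gradient w.r.t. the 7 physical params has shape (7,).
data = g2 + 0.01                 # synthetic data offset from the model prediction
sigma = jnp.full_like(g2, 0.01)  # uniform measurement uncertainties
grad_params = gradient_chi2(
    params, data, sigma, t1, t2, phi=jnp.array([0.0]),
    q=0.01, L=2_000_000.0,
    contrast=0.5, offset=1.0, dt=0.001,
)
print(f"Gradient shape: {grad_params.shape}")  # (7,)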

Cache management

from homodyne.core.jax_backend import get_cache_stats, clear_meshgrid_cache

stats = get_cache_stats()
print(f"Hit rate: {stats['hit_rate']:.1%}")

# Clear when switching datasets
clear_meshgrid_cache()