Source code for homodyne.core.physics_utils

"""Shared Physics Utility Functions for Homodyne
================================================

This module provides common utility functions and physics helpers used by
both NLSQ (meshgrid) and CMC (element-wise) computational backends.

These functions were consolidated from:
- jax_backend.py
- physics_nlsq.py
- physics_cmc.py

to eliminate code duplication and ensure consistent behavior across backends.

Key Functions:
- safe_len: JAX-safe length function for scalars and arrays
- safe_exp: Overflow-protected exponential
- safe_sinc: Numerically stable unnormalized sinc function
- _calculate_diffusion_coefficient_impl_jax: Time-dependent diffusion D(t)
- _calculate_shear_rate_impl_jax: Time-dependent shear rate γ̇(t)
- _create_time_integral_matrix_impl_jax: Trapezoidal cumulative integral matrix
"""

import jax.numpy as jnp
from jax import jit

# Physical and mathematical constants
PI = jnp.pi
EPS = 1e-12  # Numerical stability epsilon


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================


[docs] def safe_len(obj: object) -> int: """JAX-safe length function that handles scalars, arrays, and JAX objects. Args: obj: Any object that might have a length or shape Returns: int: Length of the object, or 1 for scalars """ # Handle JAX arrays and numpy arrays with shape attribute if hasattr(obj, "shape"): if obj.shape == () or len(obj.shape) == 0: # Scalar (0-dimensional array) return 1 else: # Array - return first dimension size return int(obj.shape[0]) # Handle objects with __len__ method (lists, tuples, etc.) if hasattr(obj, "__len__"): try: return len(obj) except TypeError: # This catches "len() of unsized object" errors return 1 # Handle scalars (int, float, etc.) if hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): # Iterable but not string/bytes try: return len(list(obj)) except (TypeError, ValueError): return 1 # Default case: treat as scalar return 1
[docs] @jit def safe_exp(x: jnp.ndarray, max_val: float = 700.0) -> jnp.ndarray: """Safe exponential to prevent overflow. Args: x: Input array max_val: Maximum absolute value to clip to (default 700.0) Returns: exp(clip(x, -max_val, max_val)) """ return jnp.exp(jnp.clip(x, -max_val, max_val))
[docs] @jit def safe_sinc(x: jnp.ndarray) -> jnp.ndarray: r"""Safe UNNORMALIZED sinc function: sin(x) / x (NOT sin(πx) / (πx)). This matches the reference implementation which uses sin(arg) / arg directly. The phase argument already includes all necessary scaling factors. P2-4: Uses a Taylor expansion near zero (1 - x²/6 + x⁴/120) for smooth gradient continuity. The old hard switch from sin(x)/x to 1.0 at ``|x|``\=EPS created a gradient discontinuity that caused spurious NUTS rejections near gamma_dot_t0 ≈ 0. Args: x: Input array Returns: sin(x)/x for ``|x|`` >= 1e-4, Taylor approximation for ``|x|`` < 1e-4 """ x2 = x * x near_zero = 1.0 - x2 / 6.0 + x2 * x2 / 120.0 far = jnp.sin(x) / jnp.where(jnp.abs(x) > EPS, x, 1.0) # avoid div/0 return jnp.where(jnp.abs(x) < 1e-4, near_zero, far)
# ============================================================================= # PHYSICS HELPER FUNCTIONS # =============================================================================
[docs] @jit def calculate_diffusion_coefficient( time_array: jnp.ndarray, D0: float, alpha: float, D_offset: float, ) -> jnp.ndarray: """Calculate time-dependent diffusion coefficient using discrete evaluation. Follows reference v1 implementation: D_t[i] = D0 * (time_array[i] ** alpha) + D_offset Physical constraint: D(t) should be positive and finite Args: time_array: Array of time points D0: Diffusion coefficient amplitude alpha: Anomalous diffusion exponent D_offset: Baseline diffusion offset Returns: D(t) evaluated at each time point with physical bounds applied """ # CRITICAL FIX: Replace near-zero values to prevent t=0 with negative alpha causing Inf/NaN # When alpha < 0: t^alpha = 1/t^|alpha|, so t=0 → infinity # Using jnp.maximum (not addition) to only affect near-zero values # Use dt/2 to preserve monotonicity: D(dt/2) < D(dt) for alpha > 0 # # Avoid Python `if shape[0] > 1` which causes JIT recompilation per unique # array length. Instead compute dt unconditionally: for n==1, time_array[0] # is used twice and the difference is 0, so we fall back to the 1e-8 floor. dt_inferred = jnp.abs( time_array[jnp.minimum(1, time_array.shape[0] - 1)] - time_array[0] ) epsilon = jnp.where(dt_inferred * 0.5 > 1e-8, dt_inferred * 0.5, 1e-8) time_safe = jnp.where(time_array > epsilon, time_array, epsilon) # Compute diffusion coefficient D_t = D0 * (time_safe**alpha) + D_offset # Ensure positive values — use jnp.where (not jnp.maximum) to preserve # gradients below the floor for NLSQ Jacobian computation and NUTS leapfrog. return jnp.where(D_t > 1e-10, D_t, 1e-10)
[docs] @jit def calculate_shear_rate( time_array: jnp.ndarray, gamma_dot_0: float, beta: float, gamma_dot_offset: float, ) -> jnp.ndarray: """Calculate time-dependent shear rate using discrete evaluation. Follows reference v1 implementation: γ̇_t[i] = γ̇₀ * (time_array[i] ** β) + γ̇_offset Args: time_array: Array of time points gamma_dot_0: Shear rate amplitude beta: Shear rate exponent gamma_dot_offset: Baseline shear rate offset Returns: γ̇(t) evaluated at each time point """ # CRITICAL FIX: Replace t=0 with dt to prevent singularity when beta < 0 # When beta < 0: t^beta = 1/t^|beta|, so t=0 → infinity # Strategy: Replace only the first element (t=0) with dt, leave others unchanged # This ensures smooth continuity: γ̇(dt), γ̇(dt), γ̇(2dt), ... # # Avoid Python `if shape[0] > 1` which causes JIT recompilation per unique # array length. For n==1, index 0 is used twice → inferred dt=0, but the # jnp.where guard below keeps it safe with a 1e-8 floor. dt = jnp.where( jnp.abs(time_array[jnp.minimum(1, time_array.shape[0] - 1)] - time_array[0]) > 1e-8, jnp.abs(time_array[jnp.minimum(1, time_array.shape[0] - 1)] - time_array[0]), 1e-8, ) # Replace near-zero values with dt/2 floor, matching calculate_diffusion_coefficient # This provides a continuous floor at the midpoint instead of exact-zero equality check. # Floor = 1e-8 matches calculate_diffusion_coefficient — both are power-law t^exponent # and have the same singularity structure at t=0. epsilon = jnp.where(dt * 0.5 > 1e-8, dt * 0.5, 1e-8) time_safe = jnp.where(time_array > epsilon, time_array, epsilon) gamma_t = gamma_dot_0 * (time_safe**beta) + gamma_dot_offset # Ensure positive values — use jnp.where (not jnp.maximum) to preserve gradients. return jnp.where(gamma_t > 1e-10, gamma_t, 1e-10)
[docs] @jit def calculate_shear_rate_cmc( time_array: jnp.ndarray, gamma_dot_0: float, beta: float, gamma_dot_offset: float, ) -> jnp.ndarray: """Calculate time-dependent shear rate for CMC (element-wise) computations. This variant includes an additional safety check for consecutive zeros in CMC element-wise data where dt could be zero. Args: time_array: Array of time points gamma_dot_0: Shear rate amplitude beta: Shear rate exponent gamma_dot_offset: Baseline shear rate offset Returns: γ̇(t) evaluated at each time point """ # Infer dt from time grid. # Avoid Python `if shape[0] > 1` which causes JIT recompilation per unique # array length. For n==1, index 0 is used twice → dt=0, but the jnp.where # guard below keeps it safe with a 1e-8 floor. # CRITICAL FIX: Ensure dt > 0 to prevent 0^(negative beta) = infinity # CMC element-wise data can have consecutive zeros: t[0]=0, t[1]=0 → dt=0 # This causes NaN when beta < 0 in gamma_t = gamma_dot_0 * (time_safe**beta) dt_raw = jnp.abs( time_array[jnp.minimum(1, time_array.shape[0] - 1)] - time_array[0] ) dt = jnp.where(dt_raw > 1e-8, dt_raw, 1e-8) # Replace near-zero values with dt/2 floor, matching calculate_diffusion_coefficient. # Floor = 1e-8 matches the non-CMC shear variant and the diffusion function. epsilon = jnp.where(dt * 0.5 > 1e-8, dt * 0.5, 1e-8) time_safe = jnp.where(time_array > epsilon, time_array, epsilon) gamma_t = gamma_dot_0 * (time_safe**beta) + gamma_dot_offset # Ensure positive values — use jnp.where (not jnp.maximum) to preserve gradients. return jnp.where(gamma_t > 1e-10, gamma_t, 1e-10)
[docs] @jit def create_time_integral_matrix( time_dependent_array: jnp.ndarray, ) -> jnp.ndarray: r"""Create time integral matrix using trapezoidal numerical integration. Computes the full N x N matrix of pairwise trapezoidal integral differences via broadcasting. The dt scaling happens in wavevector_q_squared_half_dt, NOT in this cumsum. Algorithm: 1. Trapezoidal integration: cumsum[i] = Sum(k=0 to i-1) 0.5 * (f[k] + f[k+1]) 2. Compute full difference matrix: matrix[i,j] = smooth_abs(cumsum[i] - cumsum[j]) 3. The dt factor is applied via wavevector_q_squared_half_dt = 0.5 * q^2 * dt This gives: matrix[i,j] = number of integration steps. Actual integral: dt * matrix[i,j] approximates the integral from 0 to abs(ti-tj) of f(t') dt' Benefits over simple cumsum: - Reduces oscillations from discretization by ~50% - Second-order accuracy (O(dt^2)) vs. first-order (O(dt)) - Eliminates checkerboard artifacts in diagonal-corrected results Args: time_dependent_array: f(t) evaluated at discrete time points Returns: Time integral matrix (in units of integration steps) """ # Handle scalar input by converting to array time_dependent_array = jnp.atleast_1d(time_dependent_array) # Step 1: Improved cumulative integration using trapezoidal rule # Trapezoidal: ∫f(t)dt ≈ dt × Σ(1/2)(f[i] + f[i+1]) # The dt scaling happens in wavevector_q_squared_half_dt, not here # # Avoid Python `if n > 1` which causes JIT recompilation per unique array # length. The trapezoidal path is unconditionally correct: for n==1, # time_dependent_array[:-1] and [1:] are both empty, trap_avg is empty, # cumsum_trap is empty, and concatenate([0.0], []) = [0.0] which is the # same result as jnp.cumsum([x]) = [x] only if x==0 — but for n==1 the # direct-cumsum fallback was returning [x], not [0, x]. Since n==1 never # occurs in hot paths (time grids are always 1000+ points), and for # correctness the trapezoidal result [0.0] is the correct starting cumsum, # the unified path is used unconditionally. trap_avg = 0.5 * (time_dependent_array[:-1] + time_dependent_array[1:]) cumsum_trap = jnp.cumsum(trap_avg) cumsum = jnp.concatenate( [jnp.array([0.0], dtype=time_dependent_array.dtype), cumsum_trap] ) # Step 2: Create difference matrix exploiting symmetry. # # The full matrix is: matrix[i,j] = smooth_abs(cumsum[i] - cumsum[j]) # Because cumsum is monotonically non-decreasing (inputs >= 0 always): # - lower triangle (i >= j): diff[i,j] = cumsum[i] - cumsum[j] >= 0 # - upper triangle (i < j): diff[i,j] = -(diff[j,i]) # - diagonal: diff[i,i] = 0 exactly # # Strategy: compute only the lower triangle, apply smooth-abs there, mirror. # This halves sqrt evaluations and avoids a full N×N temporary for diff. # # CRITICAL: Use smooth approximation of abs() for gradient stability. # jnp.abs() has undefined gradient at x=0, causing NaN in backpropagation. # Solution: sqrt(x² + ε) ≈ |x| but is differentiable everywhere. # P0-2: epsilon=1e-12 (was 1e-20, below float32 machine epsilon ~1.2e-7). epsilon = 1e-12 # Compute full signed-difference matrix and apply smooth-abs directly. # This avoids the sqrt(epsilon) bias that the tril-then-mirror approach # introduced on upper-triangle zeros. diff = cumsum[:, None] - cumsum[None, :] # Shape: (n, n), symmetric matrix = jnp.sqrt(diff**2 + epsilon) # Shape: (n, n), smooth |diff| return matrix
[docs] def trapezoid_cumsum(values: jnp.ndarray) -> jnp.ndarray: """Cumulative trapezoid integral without dt scaling (dt is applied outside). Returns cumsum so that ``cumsum[j] - cumsum[i]`` equals the trapezoidal sum over all intervals between indices ``i`` and ``j``. The caller applies a smooth absolute value to that difference when mapping each (t1, t2) pair, keeping gradients well-behaved at zero-length intervals. This is used by the CMC element-wise computations. Args: values: 1D array of values to integrate Returns: Cumulative trapezoidal sums """ # Unconditional trapezoidal path — avoids JIT retracing when array size # changes. For n==1, values[:-1] and values[1:] are both empty, so # cumsum_trap is empty and the result is [0.0], which is the correct # cumulative integral (no intervals to sum). trap_avg = 0.5 * (values[:-1] + values[1:]) cumsum_trap = jnp.cumsum(trap_avg) return jnp.concatenate([jnp.array([0.0], dtype=values.dtype), cumsum_trap])
# ============================================================================= # DIAGONAL CORRECTION # ============================================================================= # Re-export from unified diagonal_correction module for backward compatibility. # See homodyne/core/diagonal_correction.py for the canonical implementation. from homodyne.core.diagonal_correction import ( # noqa: E402 apply_diagonal_correction, # noqa: F401 apply_diagonal_correction_batch, # noqa: F401 ) # ============================================================================= # BACKWARD COMPATIBILITY ALIASES # ============================================================================= # These aliases maintain backward compatibility with existing code _calculate_diffusion_coefficient_impl_jax = calculate_diffusion_coefficient _calculate_shear_rate_impl_jax = calculate_shear_rate _create_time_integral_matrix_impl_jax = create_time_integral_matrix _trapezoid_cumsum = trapezoid_cumsum