ADR-001: JAX CPU-Only Backend¶
- Status:
Accepted
- Date:
2024
- Deciders:
Core team
Context¶
Homodyne performs large-scale numerical optimization over two-time correlation matrices from XPCS experiments. The correlation matrices can reach \(N_t^2 \times N_\phi\) elements (e.g., \(1000^2 \times 23 \approx 23 \times 10^6\) values), and the CMC backend spawns multiple worker processes, each running NUTS chains via JAX-compiled NumPyro models.
The key question: which computational backend should homodyne target — CPU, GPU, or both?
The target deployment environment is synchrotron beamlines, which typically provide:
Multi-core CPU workstations (32–128 cores), often with NUMA topology.
No GPUs at most beamlines, or GPUs shared across multiple users.
Strict software environment controls (no arbitrary CUDA versions).
Requirements for reproducibility and portability across institutions.
Decision¶
Homodyne targets CPU-only JAX (no GPU/TPU support). The JAX backend is used for:
JIT compilation of the physics kernel (
jax_backend.py,physics_cmc.py).Automatic differentiation for Jacobians (NLSQ) and NUTS gradients (CMC).
Vectorized operations via
jax.vmapover angles and time points.
The CMC backend explicitly sets:
# In multiprocessing.py worker initialization
os.environ["JAX_ENABLE_X64"] = "1"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
This creates 4 virtual JAX devices per worker process, enabling NumPyro’s parallel
chain execution mode (pmap over 4 devices) without requiring a real GPU.
Rationale¶
1. Deployment environment
Synchrotron beamline workstations universally have multi-core CPUs. GPU availability is inconsistent across facilities and cannot be assumed. A CPU-only implementation guarantees portability.
2. Performance is adequate on CPU
The JIT compilation boundary means that after the first call, numerical kernels run as compiled XLA code at near-native speed. For the problem sizes typical in XPCS (correlation matrices up to \(\sim 10^8\) elements per shard), CPU throughput is sufficient when parallelism is structured correctly (worker pool + virtual devices).
Empirical benchmark (laminar_flow, 5K points/shard, 23 angles, 4 virtual devices):
parallelchain execution: 4.9 s wall time per shard.vectorizedchain execution: 101 s wall time per shard.
The parallel mode with virtual CPU devices achieves 20x speedup over the naive
vectorized mode, making CPU-only performance viable.
3. Reproducibility
CPU execution is deterministic (given the same PRNG seed) on a given machine. GPU execution is non-deterministic by default due to non-deterministic floating-point reductions, which would compromise result reproducibility across runs.
4. 64-bit precision
The physics model requires 64-bit floating-point arithmetic for accurate computation of
transport integrals (differences of large numbers). On CPU, JAX 64-bit is enabled by
setting JAX_ENABLE_X64=1. On GPU, 64-bit operations are significantly slower and
often unavailable on consumer-grade hardware.
5. Simplicity
Maintaining a single backend (CPU) avoids the complexity of:
Conditional code paths for GPU/CPU.
GPU memory management (sharding strategies differ on GPU).
GPU-specific debugging tools.
The codebase remains simpler and easier to audit.
Consequences¶
Positive:
Portability: works on any Python 3.12+ system with JAX installed.
Reproducibility: deterministic computation with explicit PRNG keys.
Full 64-bit precision throughout.
Simple deployment (
uv syncis sufficient; no CUDA setup).
Negative / Accepted trade-offs:
Cannot use GPU-accelerated matrix operations for very large correlation matrices.
The multiprocessing backend must spawn N separate processes rather than using GPU thread parallelism, which has higher per-process startup overhead.
For extremely large datasets (\(> 10^9\) data points), GPU acceleration would provide significant speedup that the current CPU implementation cannot match.
Alternatives Considered¶
A. GPU-first JAX
Would enable faster matrix operations for the NLSQ Jacobian computation and NUTS leapfrog steps. Rejected because: GPU not universally available at beamlines; 32-bit precision limitation; non-deterministic results.
B. NumPy / SciPy only
Would eliminate the JAX dependency, simplifying installation. Rejected because: no automatic differentiation (would require finite-difference Jacobians, ~100x slower); no JIT compilation; cannot leverage the NumPyro ecosystem for NUTS.
C. PyTorch backend
PyTorch has mature CPU and GPU support with autograd. Rejected because: the NumPyro
probabilistic programming library is JAX-native, and rewriting NUTS from scratch would
be a significant engineering effort; JAX’s jit/vmap/pmap composability is
a better fit for the parallelism patterns in CMC.
D. Mixed CPU/GPU
Support both, dispatch at runtime. Rejected for the current version because: doubles the testing matrix; GPU code path is untested and risky; adds complexity without clear benefit for the target deployment environment. May be reconsidered in a future version.
See also
ADR-006: No GPU Acceleration on Consumer Hardware — quantitative GPU feasibility assessment (RTX 4090)
System Architecture — overall system design
ADR-002: NLSQ / CMC Architectural Split — NLSQ optimizer choice
ADR-004: Consensus Monte Carlo for Bayesian Inference — CMC multiprocessing architecture