Performance Tuning: CPU/NUMA Optimization¶
Learning Objectives
By the end of this section you will understand:
How JAX uses CPU resources on multi-core machines
XLA flags for optimal thread and device configuration
NUMA-aware execution for multi-socket systems
Memory management for large datasets
Using homodyne-config-xla for configuration
—
Overview¶
Homodyne is CPU-only by design. Performance on modern multi-core CPUs is determined by:
XLA thread count: how many threads JAX uses per device
Virtual JAX device count: how many logical devices are exposed
NUMA topology: memory locality on multi-socket nodes
Memory allocation: preventing OOM errors for large datasets
CMC worker count: how many processes the multiprocessing backend spawns
—
homodyne-config-xla¶
The homodyne-config-xla command provides pre-configured XLA settings
for common CPU configurations:
# Show recommended settings for your hardware
homodyne-config-xla --show
# Configure for CMC Bayesian inference
homodyne-config-xla --mode cmc
# Configure for CMC on HPC nodes
homodyne-config-xla --mode cmc-hpc
# Configure for NLSQ fitting
homodyne-config-xla --mode nlsq
# Auto-detect best settings
homodyne-config-xla --mode auto
Example output:
Detected CPU: 16 logical cores, 8 physical cores, 1 NUMA node
Recommended settings:
XLA_FLAGS=--xla_cpu_multi_thread_eigen=false
XLA_FLAGS+=--xla_force_host_platform_device_count=4
OMP_NUM_THREADS=2
OPENBLAS_NUM_THREADS=2
—
XLA Flags¶
The most important XLA flags for homodyne:
xla_force_host_platform_device_count:
Controls how many virtual CPU devices JAX creates. The CMC multiprocessing
backend uses --xla_force_host_platform_device_count=4 per worker process
so that parallel chain execution can distribute 4 chains across 4 devices.
import os
# Set BEFORE importing jax
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax
print(f"JAX devices: {jax.devices()}") # [CpuDevice(id=0), ..., CpuDevice(id=3)]
xla_cpu_multi_thread_eigen:
Controls Eigen’s internal multithreading. Setting it to false prevents
thread oversubscription when homodyne’s own process-level parallelism is used:
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=false --xla_force_host_platform_device_count=4"
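When XLA_FLAGS may already be set in the environment, overwriting it can silently drop other flags. The following is a minimal sketch of merging the two flags above into an existing value from Python. The helper name merge_xla_flags is hypothetical and is not part of homodyne or JAX; it only manipulates the string.

```python
import os


def merge_xla_flags(existing: str, **overrides: str) -> str:
    """Return an XLA_FLAGS string with `overrides` applied on top of `existing`.

    Hypothetical helper for illustration; not part of homodyne or JAX.
    """
    flags: dict[str, str] = {}
    for token in existing.split():
        # "--name=value" -> ("name", "value"); bare "--name" -> ("name", "")
        name, _, value = token.lstrip("-").partition("=")
        flags[name] = value
    flags.update(overrides)
    return " ".join(
        f"--{name}={value}" if value else f"--{name}"
        for name, value in flags.items()
    )


# This must run before the first `import jax` in the process.
os.environ["XLA_FLAGS"] = merge_xla_flags(
    os.environ.get("XLA_FLAGS", ""),
    xla_cpu_multi_thread_eigen="false",
    xla_force_host_platform_device_count="4",
)
```

Flags that were already present keep their position and only their values change, so the merged string stays easy to compare against the original.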
—
OMP Thread Configuration¶
Homodyne’s CMC multiprocessing backend carefully manages OMP threads to prevent oversubscription:
Before spawning workers: clears OMP_PROC_BIND and OMP_PLACES
Each worker: sets OMP_NUM_THREADS=1 or 2
After spawning: restores parent environment
You can set the default manually:
# For 8 physical cores, 4 workers (2 cores per worker)
export OMP_NUM_THREADS=2
Or let homodyne configure it automatically (recommended):
homodyne-config-xla --mode auto
source ~/.homodyne_xla_config # Apply the generated config
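The clear, set, and restore sequence described in this section can be sketched as a context manager. This is an illustration of the pattern only, not homodyne's actual implementation: the name worker_omp_env is hypothetical, and for simplicity it sets OMP_NUM_THREADS in the parent so that child processes started inside the block inherit it.

```python
import os
from contextlib import contextmanager

# Variables touched while workers are being spawned.
_OMP_VARS = ("OMP_PROC_BIND", "OMP_PLACES", "OMP_NUM_THREADS")


@contextmanager
def worker_omp_env(threads_per_worker: int = 1):
    """Temporarily adjust OMP settings, then restore the parent's values.

    Hypothetical helper for illustration; not homodyne's implementation.
    """
    saved = {var: os.environ.get(var) for var in _OMP_VARS}
    try:
        # Clear binding hints that would pin every worker to the same cores.
        os.environ.pop("OMP_PROC_BIND", None)
        os.environ.pop("OMP_PLACES", None)
        os.environ["OMP_NUM_THREADS"] = str(threads_per_worker)
        yield
    finally:
        # Restore exactly what the parent had, including "unset".
        for var, value in saved.items():
            if value is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = value
```

Worker processes would be started inside a `with worker_omp_env(2):` block; once the block exits, the parent environment is back to its original state.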
—
CMC Worker Count¶
The number of CMC worker processes is determined automatically:
n_workers = max(1, physical_cores // 2 - 1)
For a 16-physical-core machine: n_workers = 7
Override manually:
optimization:
cmc:
num_workers: 4 # Use 4 workers (useful on shared nodes)
Resource reservation: Always leave at least 1 physical core for the main process and OS.
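To see what the automatic rule yields on a given machine before deciding whether to override it, the formula above can be evaluated directly. The function name default_cmc_workers is hypothetical; only the formula itself comes from this section.

```python
def default_cmc_workers(physical_cores: int) -> int:
    """Evaluate the documented rule: max(1, physical_cores // 2 - 1)."""
    return max(1, physical_cores // 2 - 1)


for cores in (2, 8, 16, 64):
    print(f"{cores} physical cores -> {default_cmc_workers(cores)} workers")
# 2 physical cores -> 1 workers
# 8 physical cores -> 3 workers
# 16 physical cores -> 7 workers
# 64 physical cores -> 31 workers
```

The max(1, ...) guard matters on small machines: with 2 physical cores the raw expression is 0, and the rule still yields one worker.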
—
NUMA Awareness¶
On dual-socket servers (2 NUMA nodes), memory locality matters. Homodyne does not currently implement explicit NUMA pinning, but you can use numactl to pin the process:
# Run on NUMA node 0 only
numactl --cpunodebind=0 --membind=0 \
uv run homodyne --config config.yaml --output results/
# For a 64-core dual-socket (32 cores per node):
numactl --cpunodebind=0 --membind=0 \
uv run homodyne --config config_half.yaml --output results_node0/ &
numactl --cpunodebind=1 --membind=1 \
uv run homodyne --config config_half.yaml --output results_node1/ &
wait
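Pinning each process to a single node keeps its memory local to the cores that use it. This avoids the slower cross-socket memory accesses that can occur when the OS spreads allocations across NUMA nodes.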
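Whether pinning is worth considering depends on how many NUMA nodes the machine exposes. The sketch below counts them on Linux by listing the node directories under /sys/devices/system/node and builds the numactl prefix used in the commands above. Both function names are hypothetical, and the sysfs lookup is Linux-specific; on other platforms the function falls back to reporting a single node.

```python
import re
from pathlib import Path


def count_numa_nodes(sysfs_root: str = "/sys/devices/system/node") -> int:
    """Count NUMA nodes via Linux sysfs; return 1 if the path is unavailable."""
    root = Path(sysfs_root)
    if not root.is_dir():
        return 1
    nodes = [p for p in root.iterdir() if re.fullmatch(r"node\d+", p.name)]
    return max(1, len(nodes))


def numactl_prefix(node: int) -> list[str]:
    """Build the numactl arguments that bind CPU and memory to one node."""
    return ["numactl", f"--cpunodebind={node}", f"--membind={node}"]


if __name__ == "__main__":
    n_nodes = count_numa_nodes()
    print(f"NUMA nodes detected: {n_nodes}")
    for node in range(n_nodes):
        print(" ".join(numactl_prefix(node)))
```

On a single-node machine this prints one prefix for node 0, in which case pinning brings no locality benefit.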
—
JAX Compilation Caching¶
JAX JIT compilation can dominate execution time for small datasets. Cache compiled functions between runs.
Warning
In JAX 0.8+, setting the JAX_COMPILATION_CACHE_DIR environment variable
alone does NOT enable the persistent cache. You must use
jax.config.update() after importing JAX. The CMC multiprocessing backend
handles this automatically for worker processes.
import jax
jax.config.update("jax_compilation_cache_dir", "/path/to/cache")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
The jax_persistent_cache_min_compile_time_secs default of 1.0 s is too high for CMC physics
functions (which compile in 0.07–0.15 s). Setting it to 0 ensures all JIT
compilations are cached. With proper caching, the first CMC worker compiles
all functions; subsequent workers load from the disk cache, saving ~10 s
per run across workers.
—
Memory Profiling¶
Profile memory usage to tune the memory_fraction threshold:
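import tracemalloc
from homodyne.config import ConfigManager
from homodyne.data import load_xpcs_data
from homodyne.optimization.nlsq import fit_nlsq_jax
# Note: tracemalloc only counts allocations made through Python's allocators;
# buffers that XLA allocates natively may not be included in the reported peak.
tracemalloc.start()
config = ConfigManager("config.yaml")
data = load_xpcs_data("config.yaml")
result = fit_nlsq_jax(data, config)
# get_traced_memory() returns (current, peak) in bytes
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f"Peak memory: {peak / 1e9:.2f} GB")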
Alternatively, use the memory estimator:
from homodyne.optimization.nlsq import estimate_peak_memory_gb
n_points = len(data['c2_exp'].ravel())
n_params = 9 # For laminar_flow with auto mode
peak = estimate_peak_memory_gb(n_points, n_params)
print(f"Estimated peak memory: {peak:.1f} GB")
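For a quick sanity check without homodyne installed, the size of a single dense float64 Jacobian gives a rough lower bound on the memory a least-squares fit needs. This is a back-of-envelope sketch and not what estimate_peak_memory_gb computes; the function name and the 10-million-point example are assumptions, and real peak usage is typically higher because of intermediate buffers.

```python
def dense_jacobian_gb(n_points: int, n_params: int, bytes_per_value: int = 8) -> float:
    """Size in GB of one dense (n_points x n_params) Jacobian.

    Illustrative lower bound only; not homodyne's estimator.
    """
    return n_points * n_params * bytes_per_value / 1e9


# Example: 10 million data points, 9 parameters, float64
print(f"Jacobian alone: {dense_jacobian_gb(10_000_000, 9):.2f} GB")  # 0.72 GB
```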
—
JAX Profiling (Advanced)¶
For CMC, enable JAX-level profiling to see XLA kernel execution times:
optimization:
cmc:
per_shard_mcmc:
enable_jax_profiling: true
jax_profile_dir: ./profiles/jax
Then view with TensorBoard:
pip install tensorboard
tensorboard --logdir=./profiles/jax
This shows XLA operation timelines, useful for identifying bottlenecks in the NUTS sampler.
—
Performance Checklist¶
Before running a long analysis:
[ ] Check available RAM: free -h (Linux) or vm_stat (macOS)
[ ] Set appropriate memory_fraction in config
[ ] Use per_angle_mode: "auto" not individual for large n_phi
[ ] Set chain_method: "parallel" in CMC config
[ ] Use homodyne-config-xla to set XLA flags
[ ] Ensure JIT cache uses jax.config.update() (not just env var) for JAX 0.8+
[ ] Consider numactl for multi-socket systems
Note
CMC shards are dispatched using noise-weighted LPT (Longest Processing Time first) scheduling. The most expensive shards run first to minimize tail latency. Per-shard data is placed in shared memory to avoid serialization overhead. These optimizations are automatic and require no configuration.
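The idea behind LPT scheduling can be illustrated with the classic greedy version of the algorithm: sort jobs by estimated cost, largest first, and give each one to the currently least-loaded worker. The sketch below is a generic textbook form, not homodyne's scheduler. The function name, the cost values, and the static assignment are assumptions, and how homodyne derives its noise-weighted costs is not shown here.

```python
import heapq


def lpt_assign(costs: list[float], n_workers: int) -> list[list[int]]:
    """Greedy LPT: return, per worker, the indices of the shards assigned to it."""
    # Most expensive shards first.
    order = sorted(range(len(costs)), key=lambda i: costs[i], reverse=True)
    # Min-heap of (current load, worker id); the least-loaded worker pops first.
    heap = [(0.0, w) for w in range(n_workers)]
    heapq.heapify(heap)
    assignment: list[list[int]] = [[] for _ in range(n_workers)]
    for i in order:
        load, w = heapq.heappop(heap)
        assignment[w].append(i)
        heapq.heappush(heap, (load + costs[i], w))
    return assignment


# Hypothetical per-shard cost estimates (arbitrary units)
shard_costs = [7, 5, 4, 3, 1]
print(lpt_assign(shard_costs, n_workers=2))  # [[0, 3], [1, 2, 4]]
```

In this example both workers end up with a total load of 10, whereas assigning shards in their original order would leave one worker with 12 and the other with 8.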
—
See Also¶
YAML Configuration Reference — Full YAML configuration reference
Large Dataset Handling and Streaming — Large dataset handling
Bayesian Inference with CMC — CMC worker configuration
Troubleshooting Guide — Memory and performance issues