homodyne.optimization.cmc.sampler
The sampler module manages NUTS (No-U-Turn Sampler) execution for each
CMC shard. It provides SamplingPlan — the single source of truth for
per-shard warmup and sample counts — and run_nuts_sampling(), which
configures NumPyro’s MCMC object and executes sampling.
Note
Always use SamplingPlan.from_config() instead of accessing
config.num_warmup / config.num_samples directly in sampling
code paths. SamplingPlan applies adaptive scaling that can
significantly reduce overhead for small shards.
Divergence Rate Constants
Centralized thresholds for NUTS convergence diagnostics:
| Constant | Value | Meaning |
|---|---|---|
| DIVERGENCE_RATE_TARGET | 0.05 | Below this: acceptable sampling quality |
| DIVERGENCE_RATE_HIGH | 0.10 | Above this: posterior may be biased |
| DIVERGENCE_RATE_CRITICAL | 0.30 | Above this: posterior likely unreliable; shard filtered |
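A minimal sketch of how these thresholds could drive per-shard triage. The values are copied from the table; the helper names classify_shard and filter_shards, the label "elevated" for rates between 0.05 and 0.10 (a range the table does not describe), and the rule of dropping shards strictly above the critical threshold are all assumptions, not homodyne's implementation.

```python
from __future__ import annotations

# Threshold values copied from the divergence table; redefined here so the
# sketch runs without homodyne installed.
DIVERGENCE_RATE_TARGET = 0.05
DIVERGENCE_RATE_HIGH = 0.10
DIVERGENCE_RATE_CRITICAL = 0.30


def classify_shard(div_rate: float) -> str:
    """Map a shard divergence rate to a quality label (hypothetical helper)."""
    if div_rate < DIVERGENCE_RATE_TARGET:
        return "acceptable"
    if div_rate <= DIVERGENCE_RATE_HIGH:
        return "elevated"  # between the target and high thresholds
    if div_rate <= DIVERGENCE_RATE_CRITICAL:
        return "may_be_biased"
    return "unreliable"


def filter_shards(div_rates: dict[int, float]) -> list[int]:
    """Keep shard ids whose rate does not exceed the critical threshold (assumed rule)."""
    return [shard for shard, rate in div_rates.items() if rate <= DIVERGENCE_RATE_CRITICAL]


rates = {0: 0.01, 1: 0.08, 2: 0.45}  # illustrative per-shard divergence rates
print({shard: classify_shard(rate) for shard, rate in rates.items()})
print(filter_shards(rates))
```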
SamplingPlan
Captures the actual warmup and sample counts used per shard after adaptive
scaling. Instantiate via SamplingPlan.from_config() — never construct
directly in hot paths.
- class homodyne.optimization.cmc.sampler.SamplingPlan
Bases: object
Adapted MCMC sampling counts for a single shard.
Captures the actual warmup/sample counts after adaptive scaling, which may differ from CMCConfig defaults for small shards.
Use SamplingPlan.from_config() instead of accessing config.num_warmup / config.num_samples in hot paths.
- n_warmup: int
- n_samples: int
- n_chains: int
- shard_size: int
- n_params: int
- was_adapted: bool
- classmethod from_config(config, shard_size, n_params)
- Return type:
SamplingPlan
- property total_samples: int
- __init__(n_warmup, n_samples, n_chains, shard_size, n_params, was_adapted)
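SamplingPlan.from_config
- classmethod SamplingPlan.from_config(config, shard_size, n_params)
Build a SamplingPlan for a shard of shard_size points with n_params sampled parameters. When adaptive sampling is enabled, small shards receive reduced warmup and sample counts.
- Return type:
SamplingPlan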
Adaptive Sampling Behaviour
When adaptive_sampling: true (default), warmup and sample counts are
scaled down automatically for small shards to reduce NUTS overhead:
| Shard size | Warmup | Samples | Total | Reduction |
|---|---|---|---|---|
| 50 pts | 140 | 350 | 490 | ~75 % |
| 5 K pts | 250 | 750 | 1 000 | ~50 % |
| 50 K+ pts | 500 | 1 500 | 2 000 | None |
```yaml
optimization:
  cmc:
    per_shard_mcmc:
      adaptive_sampling: true
      min_warmup: 100
      min_samples: 200
      max_tree_depth: 10
```
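The Total and Reduction columns of the adaptive-sampling table are plain arithmetic against the unadapted 50 K+ row (500 + 1 500 = 2 000 draws). The sketch below reproduces them and adds a hypothetical clamp_to_minimums helper showing how floors such as min_warmup and min_samples could bound a scaled plan. The linear scale factor and rounding are assumptions; homodyne's actual scaling curve is not documented on this page, and this helper does not reproduce the table's rows.

```python
from __future__ import annotations

# Rows copied from the adaptive-sampling table: shard size -> (warmup, samples).
TABLE = {
    "50 pts": (140, 350),
    "5 K pts": (250, 750),
    "50 K+ pts": (500, 1500),
}
FULL_TOTAL = sum(TABLE["50 K+ pts"])  # 2000 draws: the unadapted plan


def reduction(warmup: int, samples: int) -> float:
    """Fraction of warmup + sampling draws saved relative to the unadapted plan."""
    return 1.0 - (warmup + samples) / FULL_TOTAL


def clamp_to_minimums(warmup: int, samples: int, scale: float,
                      min_warmup: int = 100, min_samples: int = 200) -> tuple[int, int]:
    """Hypothetical: scale default counts, then enforce the configured floors."""
    return (max(min_warmup, round(warmup * scale)),
            max(min_samples, round(samples * scale)))


for label, (w, s) in TABLE.items():
    print(f"{label}: total={w + s}, reduction={reduction(w, s):.1%}")

print(clamp_to_minimums(500, 1500, scale=0.1))  # floors win for very small scales
```

The printed reductions (75.5%, 50.0%, 0.0%) match the table's approximate figures.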
run_nuts_sampling
- homodyne.optimization.cmc.sampler.run_nuts_sampling(model, model_kwargs, config, initial_values, parameter_space, n_phi, analysis_mode, rng_key=None, progress_bar=True, per_angle_mode='individual')
Configure a NumPyro MCMC object with a NUTS kernel and run sampling using the settings in config.
- Parameters:
model (Callable) – NumPyro model function.
model_kwargs (dict[str, Any]) – Keyword arguments to pass to model.
config (CMCConfig) – CMC configuration.
initial_values (dict[str, float] | None) – Initial parameter values from config.
parameter_space (ParameterSpace) – Parameter space for building init values.
n_phi (int) – Number of phi angles.
analysis_mode (str) – Analysis mode.
rng_key (jax.random.PRNGKey | None) – Random key. If None, a key is created from the seed.
progress_bar (bool) – Whether to show progress bar.
per_angle_mode (str) – Per-angle scaling mode: “individual”, “auto”, “constant”, or “constant_averaged”. Controls which parameters are sampled vs fixed.
- Returns:
Samples and timing statistics.
- Return type:
tuple[MCMCSamples, SamplingStats]
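Since per_angle_mode is a plain string, a caller may want to reject typos before launching a long sampling run. A minimal sketch; the allowed values are copied from the parameter description, while validate_per_angle_mode is a hypothetical helper and not part of homodyne (run_nuts_sampling may perform its own validation).

```python
# Allowed values copied from the per_angle_mode parameter description.
VALID_PER_ANGLE_MODES = frozenset({"individual", "auto", "constant", "constant_averaged"})


def validate_per_angle_mode(mode: str) -> str:
    """Return mode unchanged if it is a documented value, else raise ValueError."""
    if mode not in VALID_PER_ANGLE_MODES:
        allowed = ", ".join(sorted(VALID_PER_ANGLE_MODES))
        raise ValueError(f"per_angle_mode must be one of: {allowed}; got {mode!r}")
    return mode


print(validate_per_angle_mode("constant_averaged"))
```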
MCMCSamples
Return type for per-shard sampling results.
- class homodyne.optimization.cmc.sampler.MCMCSamples
Bases: object
Container for MCMC samples.
- n_chains
Number of chains.
- Type:
int
- n_samples
Number of samples per chain.
- Type:
int
- num_shards
Number of shards combined (1 for a single shard, >1 for CMC). Used for correct divergence rate calculation in CMC.
- Type:
int
- n_chains: int
- n_samples: int
- num_shards: int = 1
- bimodal_consensus: Any = None
- __init__(samples, param_names, n_chains, n_samples, extra_fields=<factory>, num_shards=1, shard_adapted_n_warmup=None, bimodal_consensus=None)
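num_shards matters because a combined result contains transitions from several shards. A minimal sketch of a pooled divergence rate, assuming the denominator counts every post-warmup transition (chains × samples per chain × shards); pooled_divergence_rate is a hypothetical name, and homodyne's exact formula is not shown on this page.

```python
def pooled_divergence_rate(num_divergent: int, n_chains: int,
                           n_samples: int, num_shards: int = 1) -> float:
    """Divergent fraction of all post-warmup transitions (assumed denominator)."""
    total_transitions = n_chains * n_samples * num_shards
    if total_transitions <= 0:
        raise ValueError("need at least one post-warmup transition")
    return num_divergent / total_transitions


# 60 divergences pooled from 10 shards, each with 4 chains x 1500 samples.
correct = pooled_divergence_rate(60, n_chains=4, n_samples=1500, num_shards=10)
# Omitting num_shards divides by a single shard's transitions: 10x too high.
inflated = pooled_divergence_rate(60, n_chains=4, n_samples=1500)
print(correct, inflated)
```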
NUTS Configuration
Key NumPyro NUTS parameters accessible via CMCConfig.per_shard_mcmc:
| Parameter | Default | Description |
|---|---|---|
|  | 0.8 | Target HMC acceptance probability |
| max_tree_depth | 10 | Max NUTS tree depth (2^10 = 1024 leapfrog steps) |
|  | 4 | Number of parallel MCMC chains |
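max_tree_depth and the chain count bound the worst-case work per shard. A back-of-the-envelope sketch, assuming every iteration (warmup or sampling) could hit the 2^max_tree_depth leapfrog-step cap quoted in the table, with one gradient evaluation per leapfrog step; worst_case_leapfrog_steps is a hypothetical helper. The result is an upper bound, not a prediction of actual cost.

```python
def worst_case_leapfrog_steps(n_warmup: int, n_samples: int,
                              n_chains: int = 4, max_tree_depth: int = 10) -> int:
    """Upper bound on leapfrog steps (one gradient evaluation each) for one shard."""
    steps_per_iteration = 2 ** max_tree_depth  # cap quoted in the NUTS table
    return (n_warmup + n_samples) * n_chains * steps_per_iteration


full = worst_case_leapfrog_steps(500, 1500)   # unadapted 50 K+ plan
small = worst_case_leapfrog_steps(140, 350)   # adapted 50-point plan
print(full, small, small / full)
```

With the defaults this gives 8 192 000 steps for the unadapted plan and 2 007 040 for the adapted 50-point plan, the same 24.5 % ratio as the draw counts.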
Warning
For the multiprocessing backend, chain_method: "parallel" is the only
recommended setting. With chain_method: "vectorized", workers drop to 1–2 CPUs,
which caused an empirically observed ~20× slowdown (4.9 s vs 101 s wall time
for identical workloads).
JAX Profiling
Enable XLA-level profiling to diagnose JIT compilation and execution bottlenecks.
py-spy profiles only Python code, and XLA executes native code that is invisible
to it, so use JAX profiling to see what happens inside XLA.
```yaml
optimization:
  cmc:
    per_shard_mcmc:
      enable_jax_profiling: true
      jax_profile_dir: ./profiles/jax
```
View the results with TensorBoard:
```shell
tensorboard --logdir=./profiles/jax
```
Usage Examples
Checking adaptive scaling
```python
from homodyne.optimization.cmc.sampler import SamplingPlan
from homodyne.optimization.cmc.config import CMCConfig

config = CMCConfig()  # default settings

# Small shard (1000 pts, 9 params)
plan = SamplingPlan.from_config(config, shard_size=1000, n_params=9)
print(f"Warmup: {plan.n_warmup}")
print(f"Samples: {plan.n_samples}")
print(f"Adapted: {plan.was_adapted}")

# Full-size shard (50K pts)
plan_full = SamplingPlan.from_config(config, shard_size=50_000, n_params=9)
print(f"Adapted: {plan_full.was_adapted}")  # False
```
Inspecting divergence constants
```python
from homodyne.optimization.cmc.sampler import (
    DIVERGENCE_RATE_TARGET,
    DIVERGENCE_RATE_HIGH,
    DIVERGENCE_RATE_CRITICAL,
)

div_rate = 0.08  # example shard divergence rate

if div_rate < DIVERGENCE_RATE_TARGET:
    print("Excellent sampling quality")
elif div_rate < DIVERGENCE_RATE_HIGH:
    print("Acceptable: monitor carefully")
elif div_rate < DIVERGENCE_RATE_CRITICAL:
    print("High divergence: posterior may be biased")
else:
    print("Critical divergence: shard likely filtered")
```
NUTS sampler wrapper for CMC analysis.
This module provides utilities for running NumPyro NUTS sampling with proper initialization and progress tracking.
- class homodyne.optimization.cmc.sampler.SamplingStats
Bases: object
Statistics from MCMC sampling.
- warmup_time
Time spent in the warmup phase (seconds).
- Type:
float
- sampling_time
Time spent in the sampling phase (seconds).
- Type:
float
- total_time
Total sampling time (seconds).
- Type:
float
- num_divergent
Number of divergent transitions.
- Type:
int
- accept_prob
Mean acceptance probability.
- Type:
float
- step_size
Final step size.
- Type:
float
- step_size_min
Minimum adapted step size across chains (if available).
- Type:
float | None
- step_size_max
Maximum adapted step size across chains (if available).
- Type:
float | None
- inverse_mass_matrix_summary
Compact summary of the adapted inverse mass matrix (if available).
- Type:
str | None
- tree_depth
Mean tree depth.
- Type:
float
- warmup_time: float = 0.0
- sampling_time: float = 0.0
- total_time: float = 0.0
- num_divergent: int = 0
- accept_prob: float = 0.0
- step_size: float = 0.0
- tree_depth: float = 0.0
- plan: SamplingPlan | None = None
- __init__(warmup_time=0.0, sampling_time=0.0, total_time=0.0, num_divergent=0, accept_prob=0.0, step_size=0.0, step_size_min=None, step_size_max=None, inverse_mass_matrix_summary=None, tree_depth=0.0, plan=None)
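Most of these fields are reductions over per-chain, per-draw diagnostics. A minimal sketch of assembling such a summary with the standard library, using a hypothetical StatsSketch dataclass that mirrors a subset of SamplingStats fields. It assumes diagnostics arrive as nested Python lists (one inner list per chain) and one adapted step size per chain; the real module works on NumPyro's sampler output, which is not reproduced here.

```python
from __future__ import annotations

from dataclasses import dataclass
from statistics import fmean


@dataclass
class StatsSketch:
    """Hypothetical stand-in for a subset of SamplingStats fields."""
    num_divergent: int
    accept_prob: float
    tree_depth: float
    step_size_min: float | None = None
    step_size_max: float | None = None


def summarize(diverging: list[list[bool]], accept: list[list[float]],
              depth: list[list[int]], step_sizes: list[float]) -> StatsSketch:
    """Reduce per-chain, per-draw diagnostics to scalar summaries."""
    flat_accept = [a for chain in accept for a in chain]
    flat_depth = [d for chain in depth for d in chain]
    return StatsSketch(
        num_divergent=sum(sum(chain) for chain in diverging),  # True counts as 1
        accept_prob=fmean(flat_accept),                        # mean over all draws
        tree_depth=fmean(flat_depth),
        step_size_min=min(step_sizes) if step_sizes else None,  # None when unavailable
        step_size_max=max(step_sizes) if step_sizes else None,
    )


# Two chains, three post-warmup draws each (illustrative numbers).
stats = summarize(
    diverging=[[False, True, False], [False, False, False]],
    accept=[[0.9, 0.7, 0.8], [0.85, 0.75, 0.8]],
    depth=[[3, 4, 3], [3, 3, 4]],
    step_sizes=[0.12, 0.09],
)
print(stats)
```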
- homodyne.optimization.cmc.sampler.create_init_strategy(initial_values, param_names, use_init_to_value=True, z_space_values=None)
Create initialization strategy for NUTS.
- Parameters:
initial_values (dict[str, float] | None) – Initial values from config (original space).
param_names (list[str]) – Expected parameter names, in order.
use_init_to_value (bool) – If True, use init_to_value when values are provided.
z_space_values (dict[str, float] | None) – Initial values in z-space (for the scaled model). If provided, these are used directly as {name}_z values.
- Returns:
NumPyro initialization function.
- Return type:
Callable
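The {name}_z naming used by create_init_strategy() for the scaled model can be illustrated with the value mapping an init strategy would be built from. A minimal sketch, assuming z_space_values is keyed by plain parameter names and takes precedence over initial_values, and that parameters missing from both are left to the sampler's default initializer; build_init_values is hypothetical and the parameter names are illustrative. In NumPyro such a dict would typically be passed to numpyro.infer.init_to_value(values=...), which is omitted so the sketch runs without NumPyro.

```python
from __future__ import annotations


def build_init_values(initial_values: dict[str, float] | None,
                      param_names: list[str],
                      z_space_values: dict[str, float] | None = None) -> dict[str, float]:
    """Map sample-site names to initial values (hypothetical helper)."""
    if z_space_values:
        # Scaled model (assumed): sites are named "{name}_z" and take z-space values as-is.
        return {f"{name}_z": z_space_values[name]
                for name in param_names if name in z_space_values}
    if initial_values:
        # Unscaled model: original-space values, in param_names order.
        return {name: initial_values[name]
                for name in param_names if name in initial_values}
    # Nothing provided: leave every site to the sampler's default initializer.
    return {}


names = ["D0", "alpha"]  # illustrative parameter names
print(build_init_values({"D0": 1000.0, "alpha": -1.2}, names))
print(build_init_values(None, names, z_space_values={"D0": 0.3}))
```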
- homodyne.optimization.cmc.sampler.run_nuts_with_retry(model, model_kwargs, config, initial_values, parameter_space, n_phi, analysis_mode, max_retries=3, rng_key=None, per_angle_mode='individual')
Run NUTS sampling with automatic retry on failure.
- Parameters:
model (Callable) – NumPyro model function.
model_kwargs (dict[str, Any]) – Model arguments.
config (CMCConfig) – Configuration.
initial_values (dict[str, float] | None) – Initial values.
parameter_space (ParameterSpace) – Parameter space.
n_phi (int) – Number of phi angles.
analysis_mode (str) – Analysis mode.
max_retries (int) – Maximum number of retry attempts.
rng_key (jax.random.PRNGKey | None) – Random key.
- Returns:
Samples and statistics.
- Return type:
tuple[MCMCSamples, SamplingStats]
- Raises:
RuntimeError – If all retries fail.
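A minimal sketch of the retry pattern, assuming a fresh seed per attempt, that any exception counts as a failure, and that max_retries bounds the total number of attempts; retry_sampling and the run_once callable are hypothetical stand-ins for run_nuts_with_retry and run_nuts_sampling, and the real function may differ on each point.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def retry_sampling(run_once: Callable[[int], T], max_retries: int = 3, seed: int = 0) -> T:
    """Call run_once(seed + attempt) until it succeeds or attempts run out."""
    errors: List[str] = []
    for attempt in range(max_retries):
        try:
            return run_once(seed + attempt)  # fresh seed per attempt
        except Exception as exc:  # broad on purpose: any failure triggers a retry
            errors.append(f"attempt {attempt + 1}: {exc}")
    raise RuntimeError(f"all {max_retries} attempts failed: " + "; ".join(errors))


# Usage: a stand-in sampler that fails on its first two seeds.
def flaky(seed: int) -> str:
    if seed < 2:
        raise ValueError(f"non-finite log density (seed {seed})")
    return f"ok with seed {seed}"


print(retry_sampling(flaky, max_retries=3))
```

Collecting every error message before raising RuntimeError keeps the failure report useful when all attempts fail for different reasons.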