homodyne.optimization.cmc.sampler

The sampler module manages NUTS (No-U-Turn Sampler) execution for each CMC shard. It provides SamplingPlan — the single source of truth for per-shard warmup and sample counts — and run_nuts_sampling(), which configures NumPyro’s MCMC object and executes sampling.

Note

Always use SamplingPlan.from_config() instead of accessing config.num_warmup / config.num_samples directly in sampling code paths. SamplingPlan applies adaptive scaling that, for small shards, cuts total warmup plus sampling iterations by roughly 50–75 % (see Adaptive Sampling Behaviour).


Divergence Rate Constants

Centralized thresholds for NUTS convergence diagnostics:

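========================  =====  =======================================================
Constant                  Value  Meaning
========================  =====  =======================================================
DIVERGENCE_RATE_TARGET    0.05   Below this: acceptable sampling quality
DIVERGENCE_RATE_HIGH      0.10   Above this: posterior may be biased
DIVERGENCE_RATE_CRITICAL  0.30   Above this: posterior likely unreliable; shard filtered
========================  =====  =======================================================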
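A minimal sketch of how a shard's divergence rate can be computed and checked against these thresholds. It assumes the rate is the fraction of post-warmup transitions flagged divergent across all chains, using a boolean array shaped like NumPyro's "diverging" extra field. The threshold values are local copies of the table entries, and the shard data is synthetic.

```python
import numpy as np

# Local copies of the documented thresholds (the real constants are importable
# from homodyne.optimization.cmc.sampler).
DIVERGENCE_RATE_HIGH = 0.10
DIVERGENCE_RATE_CRITICAL = 0.30


def divergence_rate(diverging: np.ndarray) -> float:
    """Fraction of post-warmup transitions flagged divergent.

    `diverging` is a boolean array of shape (n_chains, n_samples).
    """
    return float(np.mean(diverging))


# Synthetic shard: 4 chains x 1500 samples with exactly 90 divergent transitions.
rng = np.random.default_rng(0)
flags = np.zeros(4 * 1500, dtype=bool)
flags[rng.choice(flags.size, size=90, replace=False)] = True
flags = flags.reshape(4, 1500)

rate = divergence_rate(flags)                  # 90 / 6000 = 0.015
keep_shard = rate <= DIVERGENCE_RATE_CRITICAL  # shards above CRITICAL are filtered
needs_review = rate > DIVERGENCE_RATE_HIGH     # posterior may be biased above HIGH
print(f"rate={rate:.3f} keep={keep_shard} review={needs_review}")
```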


SamplingPlan

Captures the actual warmup and sample counts used per shard after adaptive scaling. Instantiate via SamplingPlan.from_config() — never construct directly in hot paths.

class homodyne.optimization.cmc.sampler.SamplingPlan

Bases: object

Adapted MCMC sampling counts for a single shard.

Captures the actual warmup/sample counts after adaptive scaling, which may differ from CMCConfig defaults for small shards.

Use SamplingPlan.from_config() instead of accessing config.num_warmup / config.num_samples in hot paths.

n_warmup: int
n_samples: int
n_chains: int
shard_size: int
n_params: int
was_adapted: bool
classmethod from_config(config, shard_size, n_params)
Return type:

SamplingPlan

property total_samples: int
__init__(n_warmup, n_samples, n_chains, shard_size, n_params, was_adapted)

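SamplingPlan.from_config

classmethod SamplingPlan.from_config(config, shard_size, n_params)

Build the SamplingPlan for one shard, scaling warmup and sample counts down from the CMCConfig defaults when adaptive sampling is enabled and the shard is small.

Parameters:
  • config (CMCConfig) – CMC configuration.

  • shard_size (int) – Number of data points in the shard.

  • n_params (int) – Number of sampled parameters.

Return type:

SamplingPlan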


Adaptive Sampling Behaviour

When adaptive_sampling: true (default), warmup and sample counts are scaled down automatically for small shards to reduce NUTS overhead:

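==========  ======  =======  =====  =========
Shard size  Warmup  Samples  Total  Reduction
==========  ======  =======  =====  =========
50 pts      140     350      490    ~75 %
5K pts      250     750      1,000  ~50 %
50K+ pts    500     1,500    2,000  None
==========  ======  =======  =====  =========

The corresponding settings live under optimization.cmc.per_shard_mcmc: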

optimization:
  cmc:
    per_shard_mcmc:
      adaptive_sampling: true
      min_warmup: 100
      min_samples: 200
      max_tree_depth: 10
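The Reduction column is measured against the unadapted 2,000 warmup-plus-sample iterations of the 50K+ row. The short calculation below reproduces it from the table's own numbers and checks the adapted counts against min_warmup and min_samples, assuming those two settings act as lower bounds on the scaled counts.

```python
# Adaptive-sampling table rows: shard size -> (warmup, samples).
table = {
    "50 pts": (140, 350),
    "5K pts": (250, 750),
    "50K+ pts": (500, 1500),
}
full_total = sum(table["50K+ pts"])  # unadapted warmup + samples = 2000

# Values from the YAML above; treating them as floors is an assumption.
min_warmup, min_samples = 100, 200

reductions = {}
for label, (warmup, samples) in table.items():
    total = warmup + samples
    reductions[label] = 1 - total / full_total
    assert warmup >= min_warmup and samples >= min_samples
    print(f"{label:>8}: total={total:>5,}  reduction={reductions[label]:.1%}")
```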

run_nuts_sampling

homodyne.optimization.cmc.sampler.run_nuts_sampling(model, model_kwargs, config, initial_values, parameter_space, n_phi, analysis_mode, rng_key=None, progress_bar=True, per_angle_mode='individual')

Configure NumPyro's MCMC object with a NUTS kernel according to config, then run sampling for one shard.

Parameters:
  • model (Callable) – NumPyro model function.

  • model_kwargs (dict[str, Any]) – Keyword arguments to pass to model.

  • config (CMCConfig) – CMC configuration.

  • initial_values (dict[str, float] | None) – Initial parameter values from config.

  • parameter_space (ParameterSpace) – Parameter space for building init values.

  • n_phi (int) – Number of phi angles.

  • analysis_mode (str) – Analysis mode.

  • rng_key (jax.random.PRNGKey | None) – Random key. If None, a key is created from the seed.

  • progress_bar (bool) – Whether to show a progress bar.

  • per_angle_mode (str) – Per-angle scaling mode: "individual", "auto", "constant", or "constant_averaged". Controls which parameters are sampled and which are fixed.

Returns:

Samples and timing statistics.

Return type:

tuple[MCMCSamples, SamplingStats]


MCMCSamples

Return type for per-shard sampling results.

class homodyne.optimization.cmc.sampler.MCMCSamples

Bases: object

Container for MCMC samples.

samples

Parameter samples, shape (n_chains, n_samples) per parameter.

Type:

dict[str, np.ndarray]

param_names

Parameter names in sampling order.

Type:

list[str]

n_chains

Number of chains.

Type:

int

n_samples

Number of samples per chain.

Type:

int

extra_fields

Additional MCMC info (divergences, energy, etc.).

Type:

dict[str, Any]

num_shards

Number of shards combined (1 for single shard, >1 for CMC). Used for correct divergence rate calculation in CMC.

Type:

int

samples: dict[str, ndarray]
param_names: list[str]
n_chains: int
n_samples: int
extra_fields: dict[str, Any]
num_shards: int = 1
shard_adapted_n_warmup: int | None = None
bimodal_consensus: Any = None
__init__(samples, param_names, n_chains, n_samples, extra_fields=<factory>, num_shards=1, shard_adapted_n_warmup=None, bimodal_consensus=None)
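Because each entry of MCMCSamples.samples has shape (n_chains, n_samples), per-parameter posterior summaries usually pool the chain and sample axes. A minimal numpy sketch; posterior_summary and the parameter names theta_0 / theta_1 are hypothetical, and the draws are synthetic.

```python
from __future__ import annotations

import numpy as np


def posterior_summary(samples: dict[str, np.ndarray]) -> dict[str, tuple[float, float]]:
    """Pool chains and return (mean, std) per parameter.

    Each value in `samples` has shape (n_chains, n_samples).
    """
    summary = {}
    for name, draws in samples.items():
        pooled = np.asarray(draws).reshape(-1)  # (n_chains * n_samples,)
        summary[name] = (float(pooled.mean()), float(pooled.std(ddof=1)))
    return summary


# Hypothetical draws for two parameters, 4 chains x 1000 samples each.
rng = np.random.default_rng(42)
fake = {
    "theta_0": rng.normal(1.0, 0.5, size=(4, 1000)),
    "theta_1": rng.normal(-1.5, 0.02, size=(4, 1000)),
}
result_summary = posterior_summary(fake)
print(result_summary)
```

With a real result, the same call would take the .samples attribute of the MCMCSamples returned by run_nuts_sampling().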

NUTS Configuration

Key NumPyro NUTS parameters accessible via CMCConfig.per_shard_mcmc:

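==================  ==========  ================================================
Parameter           Default     Description
==================  ==========  ================================================
target_accept_prob  0.8         Target HMC acceptance probability
max_tree_depth      10          Max NUTS tree depth (2^10 = 1024 leapfrog steps)
chain_method        "parallel"  One of parallel, vectorized, sequential
num_chains          4           Number of MCMC chains
==================  ==========  ================================================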

Warning

For the multiprocessing backend, chain_method: "parallel" is the only recommended setting. vectorized causes workers to drop to 1–2 CPUs, resulting in an empirically observed 20× slowdown (4.9 s vs 101 s wall time for identical workloads).
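To put max_tree_depth in perspective: each NUTS iteration builds a trajectory of at most 2^max_tree_depth leapfrog steps, so a shard's worst-case cost is bounded by chains × iterations × that cap. The arithmetic below uses the table defaults and the adaptive-sampling rows, assuming warmup and sample counts are per chain as in NumPyro's MCMC; max_leapfrog_steps is a hypothetical helper.

```python
def max_leapfrog_steps(n_chains: int, n_iterations: int, max_tree_depth: int) -> int:
    """Upper bound on leapfrog steps (gradient evaluations) for one shard."""
    return n_chains * n_iterations * 2**max_tree_depth


# Defaults: 4 chains, max_tree_depth=10; iterations = warmup + samples per chain.
full = max_leapfrog_steps(n_chains=4, n_iterations=500 + 1500, max_tree_depth=10)
small = max_leapfrog_steps(n_chains=4, n_iterations=140 + 350, max_tree_depth=10)
print(f"50K+ shard: <= {full:,} steps; 50-pt shard: <= {small:,} steps")
```

In practice most iterations terminate well before the cap; trajectories that routinely hit max_tree_depth usually point to poor step-size or mass-matrix adaptation.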


JAX Profiling

Enable XLA-level profiling to diagnose JIT compilation and execution bottlenecks. py-spy samples only Python frames, so time spent in XLA-compiled native code is invisible to it; the JAX profiler records XLA-level activity instead.

optimization:
  cmc:
    per_shard_mcmc:
      enable_jax_profiling: true
      jax_profile_dir: ./profiles/jax

View results with TensorBoard:

tensorboard --logdir=./profiles/jax
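As a sketch of the standard JAX mechanism behind this kind of setting (whether homodyne calls jax.profiler.trace internally is an assumption), the context manager below writes a trace into the same directory that TensorBoard is pointed at; viewing it typically also requires the tensorboard-plugin-profile package. The jitted step function is a stand-in workload.

```python
import jax
import jax.numpy as jnp

log_dir = "./profiles/jax"  # same directory as jax_profile_dir above


@jax.jit
def step(x):
    # Stand-in workload; in homodyne the profiled region would be the NUTS run.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((512, 512))

with jax.profiler.trace(log_dir):
    first = step(x).block_until_ready()  # first call includes JIT compilation
    value = step(x).block_until_ready()  # second call reuses the compiled executable
```

Tracing both calls makes the compile-time cost of the first call visible next to the pure execution cost of the second.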

Usage Examples

Checking adaptive scaling

from homodyne.optimization.cmc.sampler import SamplingPlan
from homodyne.optimization.cmc.config import CMCConfig

config = CMCConfig()   # default settings

# Small shard (1000 pts, 9 params)
plan = SamplingPlan.from_config(config, shard_size=1000, n_params=9)
print(f"Warmup:  {plan.n_warmup}")
print(f"Samples: {plan.n_samples}")
print(f"Adapted: {plan.was_adapted}")

# Full-size shard (50K pts)
plan_full = SamplingPlan.from_config(config, shard_size=50_000, n_params=9)
print(f"Adapted: {plan_full.was_adapted}")   # False

Inspecting divergence constants

from homodyne.optimization.cmc.sampler import (
    DIVERGENCE_RATE_CRITICAL,
    DIVERGENCE_RATE_HIGH,
    DIVERGENCE_RATE_TARGET,
)

div_rate = 0.08   # example shard divergence rate

if div_rate > DIVERGENCE_RATE_CRITICAL:
    print("Critical: posterior likely unreliable; shard filtered")
elif div_rate > DIVERGENCE_RATE_HIGH:
    print("High: posterior may be biased")
elif div_rate < DIVERGENCE_RATE_TARGET:
    print("Acceptable sampling quality")
else:
    print("Elevated: between target and high thresholds; monitor carefully")

NUTS sampler wrapper for CMC analysis.

This module provides utilities for running NumPyro NUTS sampling with proper initialization and progress tracking.

class homodyne.optimization.cmc.sampler.SamplingStats

Bases: object

Statistics from MCMC sampling.

warmup_time

Time spent in warmup phase (seconds).

Type:

float

sampling_time

Time spent in sampling phase (seconds).

Type:

float

total_time

Total sampling time (seconds).

Type:

float

num_divergent

Number of divergent transitions.

Type:

int

accept_prob

Mean acceptance probability.

Type:

float

step_size

Final step size.

Type:

float

step_size_min

Minimum adapted step size across chains (if available).

Type:

float

step_size_max

Maximum adapted step size across chains (if available).

Type:

float

inverse_mass_matrix_summary

Compact summary of the adapted inverse mass matrix (if available).

Type:

str | None

tree_depth

Mean tree depth.

Type:

float

warmup_time: float = 0.0
sampling_time: float = 0.0
total_time: float = 0.0
num_divergent: int = 0
accept_prob: float = 0.0
step_size: float = 0.0
step_size_min: float | None = None
step_size_max: float | None = None
inverse_mass_matrix_summary: str | None = None
tree_depth: float = 0.0
plan: SamplingPlan | None = None
__init__(warmup_time=0.0, sampling_time=0.0, total_time=0.0, num_divergent=0, accept_prob=0.0, step_size=0.0, step_size_min=None, step_size_max=None, inverse_mass_matrix_summary=None, tree_depth=0.0, plan=None)
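Since step_size_min, step_size_max, and inverse_mass_matrix_summary are only populated when available, code consuming SamplingStats should tolerate None. A minimal sketch of a one-line summary built from the documented fields; summarize_stats is hypothetical, and SimpleNamespace stands in for a real SamplingStats instance.

```python
from types import SimpleNamespace


def summarize_stats(stats) -> str:
    """One-line summary of a SamplingStats-like object (None-safe step sizes)."""
    warmup_share = stats.warmup_time / stats.total_time if stats.total_time > 0 else 0.0
    parts = [
        f"total={stats.total_time:.1f}s",
        f"warmup={warmup_share:.0%}",
        f"divergent={stats.num_divergent}",
        f"accept={stats.accept_prob:.2f}",
    ]
    # Per-chain step-size range is optional in SamplingStats.
    if stats.step_size_min is not None and stats.step_size_max is not None:
        parts.append(f"step=[{stats.step_size_min:.3g}, {stats.step_size_max:.3g}]")
    return " ".join(parts)


# Hypothetical values standing in for a real SamplingStats instance.
example = SimpleNamespace(
    warmup_time=3.0, sampling_time=9.0, total_time=12.0,
    num_divergent=2, accept_prob=0.81,
    step_size_min=None, step_size_max=None,
)
line = summarize_stats(example)
print(line)
```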

homodyne.optimization.cmc.sampler.create_init_strategy(initial_values, param_names, use_init_to_value=True, z_space_values=None)

Create a NumPyro initialization strategy for NUTS.

Parameters:
  • initial_values (dict[str, float] | None) – Initial values from config (original space).

  • param_names (list[str]) – Expected parameter names in order.

  • use_init_to_value (bool) – If True, use init_to_value when values are provided.

  • z_space_values (dict[str, float] | None) – Initial values in z-space (for scaled model). If provided, these are used directly as {name}_z values.

Returns:

NumPyro initialization function.

Return type:

Callable
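One decision create_init_strategy() has to make is which values to hand to NumPyro's init_to_value. A minimal sketch of that selection step under the documented rules; build_init_values is a hypothetical helper, and restricting to param_names, giving z_space_values precedence, and the None fallback are assumptions.

```python
from __future__ import annotations


def build_init_values(
    initial_values: dict[str, float] | None,
    param_names: list[str],
    z_space_values: dict[str, float] | None = None,
) -> dict[str, float] | None:
    """Pick the values an init_to_value-style strategy would receive.

    Hypothetical helper: z-space values map to "{name}_z" sites (as documented
    for the scaled model); otherwise original-space values are kept, restricted
    to the expected parameter names.
    """
    if z_space_values is not None:
        return {f"{n}_z": z_space_values[n] for n in param_names if n in z_space_values}
    if initial_values is None:
        return None  # caller would fall back to another strategy (assumption)
    return {n: initial_values[n] for n in param_names if n in initial_values}


names = ["theta_0", "theta_1"]
orig_space = build_init_values({"theta_0": 1.0, "theta_1": -1.5, "unused": 9.0}, names)
z_space = build_init_values(None, names, z_space_values={"theta_0": 0.2, "theta_1": -0.1})
# Either dict could then be passed as numpyro.infer.init_to_value(values=...).
```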

homodyne.optimization.cmc.sampler.run_nuts_with_retry(model, model_kwargs, config, initial_values, parameter_space, n_phi, analysis_mode, max_retries=3, rng_key=None, per_angle_mode='individual')

Run NUTS sampling with automatic retry on failure.

Parameters:
  • model (Callable) – NumPyro model function.

  • model_kwargs (dict[str, Any]) – Model arguments.

  • config (CMCConfig) – Configuration.

  • initial_values (dict[str, float] | None) – Initial values.

  • parameter_space (ParameterSpace) – Parameter space.

  • n_phi (int) – Number of phi angles.

  • analysis_mode (str) – Analysis mode.

  • max_retries (int) – Maximum number of retry attempts.

  • rng_key (jax.random.PRNGKey | None) – Random key.

Returns:

Samples and statistics.

Return type:

tuple[MCMCSamples, SamplingStats]

Raises:

RuntimeError – If all retries fail.
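The retry behaviour can be sketched as a generic loop. run_with_retry is hypothetical: it treats max_retries as the total number of attempts and passes the attempt index so a caller can derive a fresh seed or RNG key per attempt. How run_nuts_with_retry() actually varies its attempts is not documented here.

```python
from __future__ import annotations

from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_retry(run_once: Callable[[int], T], max_retries: int = 3) -> T:
    """Call run_once(attempt) until it succeeds or the attempts are exhausted."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    last_error: Exception | None = None
    for attempt in range(max_retries):
        try:
            return run_once(attempt)
        except Exception as exc:  # a real wrapper would likely catch narrower errors
            last_error = exc
    raise RuntimeError(f"All {max_retries} attempts failed") from last_error


calls: list[int] = []


def flaky(attempt: int) -> str:
    # Fails on the first two attempts, then succeeds.
    calls.append(attempt)
    if attempt < 2:
        raise FloatingPointError("non-finite log density")
    return "ok"


result = run_with_retry(flaky, max_retries=3)
```

Chaining the final RuntimeError to the last underlying exception keeps the original failure visible in the traceback.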