Source code for homodyne.device.config
"""Hardware detection and configuration helpers for CMC.
=======================================================
This module now only detects hardware characteristics to size shards and
recommend the execution backend for Consensus Monte Carlo (CMC). Method
selection is handled upstream and CMC is always used for MCMC paths.
Usage
-----
from homodyne.device.config import detect_hardware
hw_config = detect_hardware()
print(f"Platform: {hw_config.platform}")
print(f"Recommended backend: {hw_config.recommended_backend}")
Integration
-----------
- CMC coordinator reads :class:`HardwareConfig` for backend selection and
shard sizing.
- No method-selection logic remains here; CMC is the only MCMC path.
"""
import multiprocessing
import os
from dataclasses import dataclass
from typing import Literal
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
from homodyne.utils.logging import get_logger
logger = get_logger(__name__)
[docs]
@dataclass(frozen=True)
class HardwareConfig:
"""Hardware configuration for CMC optimization.
This dataclass encapsulates all detected hardware information needed
for intelligent CMC decision-making and backend selection.
Attributes
----------
platform : {'cpu'}
Primary compute platform (CPU-only in v2.3.0+)
num_devices : int
Number of available CPU devices
memory_per_device_gb : float
Available system memory in GB
num_nodes : int
Number of cluster nodes (1 for standalone)
cores_per_node : int
Number of physical CPU cores per node
total_memory_gb : float
Total system memory in GB
cluster_type : {'pbs', 'slurm', 'standalone', None}
Detected cluster scheduler type
recommended_backend : str
Recommended CMC backend based on hardware
Options: 'pjit', 'multiprocessing', 'pbs', 'slurm'
max_parallel_shards : int
Maximum number of shards that can run in parallel
- Multi-node cluster: num_nodes * cores_per_node
- CPU: cores_per_node
Examples
--------
>>> hw = detect_hardware()
>>> print(hw.platform)
'cpu'
>>> print(hw.max_parallel_shards)
4
>>> print(hw.recommended_backend)
'multiprocessing'
"""
platform: Literal["cpu"]
num_devices: int
memory_per_device_gb: float
num_nodes: int
cores_per_node: int
total_memory_gb: float
cluster_type: Literal["pbs", "slurm", "standalone"] | None
recommended_backend: str
max_parallel_shards: int
[docs]
def detect_hardware() -> HardwareConfig:
    """Auto-detect the hardware configuration used to plan CMC runs.

    Detection steps
    ---------------
    1. **JAX devices**: ask the active JAX backend for its platform name and
       device list; on any failure fall back to ``("cpu", 1)``.
    2. **System memory**: total system memory via ``psutil`` (in units of
       1e9 bytes); 32 GB is assumed when ``psutil`` is unavailable.
    3. **Cluster environment**, from environment variables:
       - PBS: ``PBS_JOBID`` present; node count is the number of distinct
         non-blank lines in the file named by ``PBS_NODEFILE``.
       - Slurm: ``SLURM_JOB_NUM_NODES`` present and parsed as the node count.
       - Standalone: neither variable present.
       The node count is never reported as less than 1.
    4. **CPU cores**: physical cores via ``psutil``; otherwise
       ``multiprocessing.cpu_count()``, falling back to 1 if even that is
       undetermined.
    5. **Backend recommendation**:
       - PBS/Slurm with more than one node -> that scheduler's backend, with
         ``num_nodes * cores_per_node`` parallel shards.
       - Anything else -> ``'multiprocessing'`` with ``cores_per_node``
         parallel shards.

    Returns
    -------
    HardwareConfig
        Frozen record of everything detected above.

    Examples
    --------
    >>> hw = detect_hardware()          # doctest: +SKIP
    >>> hw.platform                     # doctest: +SKIP
    'cpu'
    >>> hw.recommended_backend          # doctest: +SKIP
    'multiprocessing'

    Notes
    -----
    - This function does not raise for missing JAX, missing ``psutil``, or
      unreadable/garbled scheduler information; each case degrades to a
      logged default.
    - ``memory_per_device_gb`` and ``total_memory_gb`` carry the same value:
      both are the total system memory.
    - The platform string is passed through from the JAX backend unchanged.
      NOTE(review): ``HardwareConfig.platform`` is annotated ``Literal["cpu"]``;
      confirm no deployment can surface a non-CPU backend here.
    """
    logger.info("Detecting hardware configuration for CMC...")

    # Helper call order matches the order of the log messages they emit.
    platform, num_devices = _detect_jax_devices()
    memory_gb = _detect_system_memory_gb()
    cluster_type, num_nodes = _detect_cluster()
    cores_per_node = _detect_cores_per_node()

    # Total memory is the very figure already measured above, so reuse it
    # instead of querying psutil a second time.
    total_memory_gb = memory_gb

    recommended_backend: str
    if cluster_type in ("pbs", "slurm") and num_nodes > 1:
        # Multi-node cluster: hand the work to the scheduler backend.
        recommended_backend = cluster_type
        max_parallel_shards = num_nodes * cores_per_node
        logger.info(
            f"Recommended backend: {recommended_backend} "
            f"(max {max_parallel_shards} parallel shards)"
        )
    else:
        # Single node (standalone or a one-node scheduler job).
        recommended_backend = "multiprocessing"
        max_parallel_shards = cores_per_node
        logger.info(
            f"Recommended backend: multiprocessing "
            f"(max {max_parallel_shards} parallel shards)"
        )

    hw_config = HardwareConfig(
        platform=platform,
        num_devices=num_devices,
        memory_per_device_gb=memory_gb,
        num_nodes=num_nodes,
        cores_per_node=cores_per_node,
        total_memory_gb=total_memory_gb,
        cluster_type=cluster_type,
        recommended_backend=recommended_backend,
        max_parallel_shards=max_parallel_shards,
    )
    logger.info(f"Hardware detection complete: {hw_config.platform} platform")
    return hw_config


def _detect_jax_devices() -> tuple[str, int]:
    """Return ``(platform, device_count)`` for the active JAX backend."""
    # Query the active backend rather than the first entry of a device list:
    # with JAX_PLATFORMS="cpu,gpu" the first listed device need not belong to
    # the backend that is actually in use.
    try:
        try:
            # Newer JAX exposes the backend through jax.extend.
            from jax.extend import backend as jax_backend

            backend = jax_backend.get_backend()
        except (ImportError, AttributeError):
            # Older JAX: legacy xla_bridge location.
            import importlib

            xla_bridge = importlib.import_module("jax.lib.xla_bridge")
            backend = xla_bridge.get_backend()
        platform = backend.platform
        num_devices = len(backend.devices())
    except Exception as e:
        # Deliberately broad: JAX may be missing or fail to initialise in
        # many ways, and detection must not abort because of it.
        logger.warning(f"JAX device detection failed: {e}. Falling back to CPU.")
        return "cpu", 1
    logger.info(f"JAX devices detected: {num_devices} {platform} device(s)")
    return platform, num_devices


def _detect_system_memory_gb() -> float:
    """Return total system memory in GB (32.0 when psutil is unavailable)."""
    if not HAS_PSUTIL:
        logger.warning("psutil not available. Assuming 32 GB system memory")
        return 32.0
    memory_gb = psutil.virtual_memory().total / 1e9
    logger.info(f"System memory detected: {memory_gb:.2f} GB")
    return memory_gb


def _detect_cluster() -> tuple[Literal["pbs", "slurm", "standalone"], int]:
    """Return ``(cluster_type, num_nodes)`` from scheduler environment variables."""
    if "PBS_JOBID" in os.environ:
        return "pbs", _count_pbs_nodes(os.environ.get("PBS_NODEFILE"))
    if "SLURM_JOB_NUM_NODES" in os.environ:
        return "slurm", _parse_slurm_node_count(os.environ["SLURM_JOB_NUM_NODES"])
    logger.info("Standalone system detected (no cluster scheduler)")
    return "standalone", 1


def _count_pbs_nodes(nodefile: str | None) -> int:
    """Return the number of distinct hosts in a PBS nodefile, at least 1."""
    if not nodefile or not os.path.exists(nodefile):
        logger.debug("PBS_JOBID present but PBS_NODEFILE not found")
        return 1
    try:
        with open(nodefile, encoding="utf-8") as f:
            # Deduplicate after stripping and skip blank lines, so repeated
            # host entries and trailing newlines do not inflate the count.
            hosts = {name for line in f if (name := line.strip())}
    except (OSError, ValueError) as e:
        # OSError: unreadable path; ValueError: undecodable bytes / bad path.
        logger.warning(f"Failed to parse PBS_NODEFILE: {e}")
        return 1
    if not hosts:
        # An empty nodefile would otherwise report zero nodes.
        logger.warning("PBS_NODEFILE lists no hosts. Assuming 1 node")
        return 1
    logger.info(f"PBS cluster detected: {len(hosts)} nodes")
    return len(hosts)


def _parse_slurm_node_count(raw: str) -> int:
    """Return the Slurm node count parsed from *raw*, at least 1."""
    try:
        num_nodes = int(raw)
    except ValueError:
        logger.warning("Failed to parse SLURM_JOB_NUM_NODES")
        return 1
    if num_nodes < 1:
        # A zero or negative node count is meaningless downstream.
        logger.warning(f"SLURM_JOB_NUM_NODES={raw!r} is not positive. Assuming 1 node")
        return 1
    logger.info(f"Slurm cluster detected: {num_nodes} nodes")
    return num_nodes


def _detect_cores_per_node() -> int:
    """Return the CPU core count for this node, at least 1."""
    if HAS_PSUTIL:
        # Physical cores only; psutil may return None when undetermined.
        cores = psutil.cpu_count(logical=False) or 1
        logger.info(f"CPU cores detected: {cores} physical cores")
        return cores
    logger.warning("psutil not available. Using multiprocessing for CPU count")
    try:
        return multiprocessing.cpu_count()
    except NotImplementedError:
        # Raised when the platform cannot report a CPU count at all.
        logger.warning("CPU count could not be determined. Assuming 1 core")
        return 1
# Names re-exported by ``from homodyne.device.config import *``.
__all__ = ["HardwareConfig", "detect_hardware"]