Source code for homodyne.cli.xla_config
#!/usr/bin/env python3
"""Configure XLA_FLAGS for Homodyne optimization workflows."""
import argparse
import os
import sys
from pathlib import Path
# Mode names accepted by set_mode() and offered as the choices of the
# --mode command-line option in main().
VALID_MODES = ["cmc", "cmc-hpc", "nlsq", "auto"]
def _get_config_file() -> Path:
    """Resolve the path of the file that stores the XLA mode preference.

    Resolution order:

    1. ``$VIRTUAL_ENV`` or, failing that, ``$CONDA_PREFIX`` (empty values are
       skipped): ``<prefix>/etc/homodyne/xla_mode``.
    2. ``$XDG_CONFIG_HOME`` when set and non-empty, otherwise ``~/.config``:
       ``<config dir>/homodyne/xla_mode``.

    NOTE(review): meant to mirror post_install.get_xla_mode_path(), which is
    not visible from this module -- confirm the two stay in sync.

    Returns:
        The path of the mode file. Nothing is created or checked on disk.
    """
    env_prefix = os.environ.get("VIRTUAL_ENV") or os.environ.get("CONDA_PREFIX")
    if not env_prefix:
        # No active environment: use the user-level XDG config directory.
        config_root = os.environ.get("XDG_CONFIG_HOME", "") or str(
            Path.home() / ".config"
        )
        return Path(config_root) / "homodyne" / "xla_mode"

    # An active venv/conda environment keeps its own setting.
    return Path(env_prefix) / "etc" / "homodyne" / "xla_mode"
[docs]
def detect_optimal_devices():
    """Choose an XLA device count from the number of CPU cores.

    The core count comes from ``psutil.cpu_count(logical=False)`` when psutil
    can be imported, otherwise from ``os.cpu_count()``. Either source is
    replaced by 4 when it reports nothing.

    Returns:
        2 for up to 7 cores, 4 for up to 15, 6 for up to 35, and 8 beyond.
    """
    try:
        import psutil

        core_count = psutil.cpu_count(logical=False) or 4
    except ImportError:
        core_count = os.cpu_count() or 4

    # (inclusive upper core bound, device count); the original comment says
    # these tiers follow the project's shell scripts.
    tiers = ((7, 2), (15, 4), (35, 6))
    for upper_bound, device_count in tiers:
        if core_count <= upper_bound:
            return device_count
    return 8
[docs]
def set_mode(mode: str) -> bool:
    """Save the XLA mode preference and print what the mode implies.

    Args:
        mode: Mode name; must be one of ``VALID_MODES``.

    Returns:
        True when the mode was written to the config file. False when
        ``mode`` is not a valid mode or the file could not be written; in
        both cases an error message is printed to stderr.
    """
    if mode not in VALID_MODES:
        print(f"Error: Invalid mode '{mode}'", file=sys.stderr)
        print(f"Valid modes: {', '.join(VALID_MODES)}", file=sys.stderr)
        return False

    config_file = _get_config_file()
    try:
        config_file.parent.mkdir(parents=True, exist_ok=True)
        config_file.write_text(mode + "\n", encoding="utf-8")
    except OSError as e:
        print(
            f"Error: Cannot write XLA mode config to {config_file}: {e}",
            file=sys.stderr,
        )
        return False

    print(f"OK: XLA mode set to: {mode}")

    # Show what this means
    if mode == "cmc":
        print(" -> 4 CPU devices for parallel CMC chains")
    elif mode == "cmc-hpc":
        print(" -> 8 CPU devices for HPC clusters (36+ cores)")
    elif mode == "nlsq":
        print(" -> 1 CPU device (NLSQ doesn't need parallelism)")
    elif mode == "auto":
        devices = detect_optimal_devices()
        # Report the core count the same way detect_optimal_devices() obtains
        # it (physical cores via psutil when importable, else os.cpu_count()).
        # Printing the logical count from os.cpu_count() unconditionally could
        # contradict the device count chosen from the physical count.
        try:
            import psutil

            cores = psutil.cpu_count(logical=False) or 4
        except ImportError:
            cores = os.cpu_count() or 4
        print(f" -> Auto-detect: {devices} devices (detected {cores} CPU cores)")

    # Suggest reactivation to pick up the new mode
    print("\nDeactivate and reactivate your venv to apply the new mode.")
    return True
[docs]
def show_config():
    """Print the current XLA setup to stdout.

    Reports the stored mode, the ``XLA_FLAGS`` environment variable, the
    config file location and, when JAX can be imported, the devices it sees.
    A missing config file is shown as the default ``cmc`` mode; a file that
    exists but cannot be read is shown as the default with a note.
    """
    mode_path = _get_config_file()

    # Start from the default label and replace it only when a file is present.
    mode_label = "cmc (default)"
    if mode_path.exists():
        try:
            mode_label = mode_path.read_text(encoding="utf-8").strip()
        except OSError:
            mode_label = "cmc (default, config unreadable)"

    summary = (
        "Current XLA Configuration:",
        f" Mode: {mode_label}",
        f" XLA_FLAGS: {os.environ.get('XLA_FLAGS', 'Not set')}",
        f" Config file: {mode_path}",
    )
    for line in summary:
        print(line)

    # JAX is optional: a failed import or device query is reported, not raised.
    try:
        import jax

        jax_devices = jax.devices()
        print(f" JAX devices: {len(jax_devices)} ({jax_devices[0].platform})")
        for index, device in enumerate(jax_devices):
            print(f" [{index}] {device}")
    except ImportError:
        print(" JAX: Not installed")
    except Exception as exc:
        print(f" JAX devices: Error - {exc}")
[docs]
def main() -> None:
    """Main entry point for homodyne-config-xla.

    Parses the command line and dispatches:

    * ``--show`` prints the current configuration. It takes precedence when
      ``--mode`` is given as well.
    * ``--mode MODE`` stores the mode via ``set_mode``.
    * With neither option, the help text is printed.

    Raises:
        SystemExit: With status 1 when ``set_mode`` reports failure, so that
            shells and scripts can detect that the mode was not saved.
            argparse also exits on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Configure XLA_FLAGS for Homodyne workflows",
        epilog="""
Examples:
homodyne-config-xla --mode cmc # 4 devices for CMC
homodyne-config-xla --mode auto # Auto-detect based on CPU
homodyne-config-xla --show # Show current configuration
Modes:
cmc 4 devices (multi-core workstations)
cmc-hpc 8 devices (HPC with 36+ cores)
nlsq 1 device (NLSQ-only workflows)
auto Auto-detect based on CPU cores
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--mode",
        choices=VALID_MODES,
        help="Set XLA mode (cmc, cmc-hpc, nlsq, auto)",
    )
    parser.add_argument(
        "--show", action="store_true", help="Show current configuration"
    )
    args = parser.parse_args()

    if args.show:
        show_config()
    elif args.mode:
        # set_mode() returns False when the config could not be saved;
        # surface that as a non-zero exit status instead of ignoring it.
        if not set_mode(args.mode):
            sys.exit(1)
    else:
        parser.print_help()


# Run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()