Source code for homodyne.cli.main
"""Minimal CLI Entry Point for Homodyne
=======================================
Simplified command-line interface for homodyne scattering analysis with
JAX-first optimization methods.
Entry point for console script: homodyne [args]
"""
import os
import sys
# ============================================================================
# Process environment defaults (MUST be in place before JAX is imported)
# ============================================================================
# Everything below is only a default: a value the user has already exported
# in the environment is never overwritten.

# P2-A: request JAX_ENABLE_X64 explicitly, ahead of any JAX import.
os.environ.setdefault("JAX_ENABLE_X64", "1")

# XLA flags that are applied unless the user already passes the same flag.
# The device-count flag exposes 4 CPU devices so MCMC chains can run in
# parallel; it has to be present before JAX/XLA initializes.
_DEFAULT_XLA_FLAGS = [
    "--xla_force_host_platform_device_count=4",
    "--xla_disable_hlo_passes=constant_folding",
]
_user_xla_flags = os.environ.get("XLA_FLAGS")
if _user_xla_flags is None:
    # Nothing set by the user: install the full default flag string.
    os.environ["XLA_FLAGS"] = " ".join(_DEFAULT_XLA_FLAGS)
else:
    # Keep the user's flags and append only the defaults whose flag name
    # (the part before "=") does not already occur in the user's string.
    _absent_flags = [
        _flag
        for _flag in _DEFAULT_XLA_FLAGS
        if _flag.partition("=")[0] not in _user_xla_flags
    ]
    if _absent_flags:
        os.environ["XLA_FLAGS"] = _user_xla_flags + "".join(
            " " + _flag for _flag in _absent_flags
        )

# Suppress NLSQ GPU warnings (v2.3.0 is CPU-only).
os.environ.setdefault("NLSQ_SKIP_GPU_CHECK", "1")

# Force the non-interactive matplotlib backend for the CLI (renders to files
# only). This has to happen before any matplotlib import: with the default
# TkAgg backend, tkinter objects are created that crash during GC when
# background threads (e.g. performance_engine) are still alive at shutdown.
os.environ.setdefault("MPLBACKEND", "Agg")
# Quieten the JAX backend loggers: only ERROR and above gets through, which
# hides the GPU fallback warnings. This has to run before any import that
# triggers JAX initialization.
import logging  # noqa: E402 - Must import after os.environ configuration

for _jax_logger_name in ("jax._src.xla_bridge", "jax._src.compiler"):
    logging.getLogger(_jax_logger_name).setLevel(logging.ERROR)
# Note: GPU support removed in v2.3.0 (CPU-only)
from homodyne.cli.args_parser import create_parser # noqa: E402
from homodyne.cli.commands import dispatch_command # noqa: E402
from homodyne.utils.logging import ( # noqa: E402
LogConfiguration,
get_logger,
log_exception,
)
logger = get_logger(__name__)
[docs]
def main() -> None:
    """Main CLI entry point.

    Parses the command line, applies the ``--threads`` and ``--no-jit``
    switches, configures logging through
    ``LogConfiguration.from_cli_args()`` (``--verbose``/``-v`` and
    ``--quiet``/``-q``; a timestamped log file is auto-generated per run) and
    hands the parsed arguments to ``dispatch_command``.

    The function does not return normally; it always ends in ``sys.exit``:

    * ``0``   - ``dispatch_command`` returned a result whose ``"success"``
      entry is truthy.
    * ``1``   - the result was falsy, reported no success, or an unexpected
      exception was raised (it is logged via ``log_exception``).
    * ``130`` - a ``KeyboardInterrupt`` was received.

    Raises:
        SystemExit: Always, with one of the codes above. A ``SystemExit``
            raised while parsing arguments is not intercepted and propagates
            unchanged.
    """
    try:
        # Parse arguments
        args = create_parser().parse_args()

        # Apply --threads: set XLA intra-op parallelism.
        # NOTE(review): the second token is emitted without a leading "--";
        # confirm that XLA accepts it in this form.
        threads = getattr(args, "threads", None)
        if threads is not None:
            os.environ["XLA_FLAGS"] = (
                os.environ.get("XLA_FLAGS", "")
                + " --xla_cpu_multi_thread_eigen=true"
                + f" intra_op_parallelism_threads={threads}"
            )

        # Apply --no-jit: disable JIT compilation for debugging.
        # jax is imported lazily so it is only loaded here when requested.
        if getattr(args, "no_jit", False):
            import jax

            jax.config.update("jax_disable_jit", True)

        # The attribute can exist and still be None (argparse's "count"
        # action does that when no default is declared); comparing None with
        # an int raises TypeError, so normalise to a number once and reuse it.
        verbose_level = getattr(args, "verbose", 0) or 0

        # Configure logging using LogConfiguration (T017, T018).
        # This handles --verbose, --quiet, and creates timestamped log files.
        log_config = LogConfiguration.from_cli_args(
            verbose=verbose_level > 0,
            quiet=getattr(args, "quiet", False),
            log_file=None,  # Auto-generate timestamped log file
        )

        # Apply granular verbosity: -vv = DEBUG, -vvv = TRACE.
        # NOTE(review): this runs before log_config.apply(); verify that
        # apply() does not reset the "homodyne" logger level afterwards.
        if verbose_level >= 3:
            logging.getLogger("homodyne").setLevel(5)  # TRACE
        elif verbose_level >= 2:
            logging.getLogger("homodyne").setLevel(logging.DEBUG)

        log_file = log_config.apply()
        if log_file:
            logger.debug(f"Log file created: {log_file}")

        # Log startup
        logger.info("Starting homodyne analysis...")
        logger.debug(f"Arguments: {vars(args)}")

        # Dispatch command (device configuration happens inside
        # dispatch_command). Note: GPU status is checked and logged during
        # device configuration, not here at startup, to avoid
        # premature/inaccurate warnings.
        result = dispatch_command(args)

        # Handle result
        if result and result.get("success", False):
            logger.info("Analysis completed successfully")
            sys.exit(0)
        else:
            error_msg = (
                result.get("error", "Unknown error") if result else "Command failed"
            )
            logger.error(f"Analysis failed: {error_msg}")
            sys.exit(1)
    except KeyboardInterrupt:
        logger.info("Analysis interrupted by user")
        sys.exit(130)
    except Exception as e:
        # Use structured exception logging (T003). SystemExit is not an
        # Exception subclass, so the sys.exit() calls above pass through.
        log_exception(logger, e, context={"command": "main"})
        sys.exit(1)
[docs]
def main_hexp() -> None:
    """Entry point for ``hexp`` — plot experimental data.

    Places ``--plot-experimental-data`` in front of the user-supplied
    arguments in ``sys.argv`` and then runs :func:`main`.
    """
    # Index 1 is the first slot after the program name.
    sys.argv.insert(1, "--plot-experimental-data")
    main()
[docs]
def main_hsim() -> None:
    """Entry point for ``hsim`` — plot simulated data.

    Places ``--plot-simulated-data`` in front of the user-supplied
    arguments in ``sys.argv`` and then runs :func:`main`.
    """
    # Index 1 is the first slot after the program name.
    sys.argv.insert(1, "--plot-simulated-data")
    main()
# Run the CLI when this file is executed directly as a script.
if __name__ == "__main__":
    main()