"""PBS (Portable Batch System) backend for CMC HPC cluster execution.
This module provides distributed MCMC execution on HPC clusters
using PBS job scheduling.
Note: This backend requires:
- PBS/Torque job scheduler (qsub, qstat commands)
- Shared filesystem accessible from all nodes
- homodyne installed on compute nodes
"""
from __future__ import annotations
import json
import os
import re
import shutil
import subprocess # nosec B404
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
from homodyne.optimization.cmc.backends.base import CMCBackend, combine_shard_samples
from homodyne.utils.logging import get_logger
if TYPE_CHECKING:
from homodyne.optimization.cmc.config import CMCConfig
from homodyne.optimization.cmc.data_prep import PreparedData
from homodyne.optimization.cmc.sampler import MCMCSamples
# Module-level logger obtained from the project's logging helper, keyed by
# this module's import name.
logger = get_logger(__name__)
# Validation patterns for every value that ends up in a generated job script
# or is read back from the scheduler.
#
# All patterns end in ``\Z`` rather than ``$``: ``$`` also matches just before
# a trailing newline, so ``"batch\n"`` would otherwise be accepted.  The
# patterns that use ``\d`` / ``\w`` are compiled with ``re.ASCII`` so that
# non-ASCII digits and letters are rejected as well.
_SAFE_PATH_RE = re.compile(r"^[/a-zA-Z0-9._-]+\Z")
# "<sequence number>[.<server host name parts>]"; host name parts may contain
# hyphens (e.g. "12345.head-node.example.edu").
_PBS_JOB_ID_RE = re.compile(r"^\d+(\.[\w-]+)*\Z", re.ASCII)
_PBS_QUEUE_RE = re.compile(r"^[a-zA-Z0-9_-]+\Z")
_PBS_WALLTIME_RE = re.compile(r"^\d{1,3}:\d{2}:\d{2}\Z", re.ASCII)
_PBS_MEMORY_RE = re.compile(r"^\d+[kmgKMG][bB]\Z", re.ASCII)


def _validate_pbs_params(*, queue: str, walltime: str, memory: str) -> None:
    """Validate PBS parameters to prevent shell injection.

    The values are interpolated verbatim into a generated job script, so each
    one must match its whitelist pattern in full (``fullmatch``); a trailing
    newline or any other extra character is rejected.

    Parameters
    ----------
    queue : str
        PBS queue name. Must contain only alphanumeric chars, hyphens,
        and underscores.
    walltime : str
        Job walltime in HH:MM:SS format (hours may be 1-3 digits).
    memory : str
        Memory specification (e.g. "8gb", "16GB", "512mb").

    Raises
    ------
    ValueError
        If any parameter fails validation.
    """
    if not _PBS_QUEUE_RE.fullmatch(queue):
        raise ValueError(
            f"Invalid PBS queue name {queue!r}: must contain only "
            "alphanumeric characters, hyphens, and underscores"
        )
    if not _PBS_WALLTIME_RE.fullmatch(walltime):
        raise ValueError(
            f"Invalid PBS walltime {walltime!r}: must match HH:MM:SS format "
            "(e.g. '04:00:00', '120:00:00')"
        )
    if not _PBS_MEMORY_RE.fullmatch(memory):
        raise ValueError(
            f"Invalid PBS memory {memory!r}: must match pattern like "
            "'8gb', '16GB', '512mb', '4kb'"
        )
# Default PBS job template.
#
# Rendered with ``str.format`` in ``PBSBackend._submit_job``; the placeholders
# are shard_id, ppn, walltime, memory, output_dir, queue, activate_env,
# shard_file, config_file and result_file.  Every value is inserted verbatim
# (no shell quoting), so callers must validate them first.  The doubled
# backslashes below become single backslashes in the rendered script, i.e.
# shell line continuations for the worker command.
PBS_JOB_TEMPLATE = """#!/bin/bash
#PBS -N cmc_shard_{shard_id}
#PBS -l nodes=1:ppn={ppn}
#PBS -l walltime={walltime}
#PBS -l mem={memory}
#PBS -o {output_dir}/shard_{shard_id}.out
#PBS -e {output_dir}/shard_{shard_id}.err
#PBS -q {queue}
cd $PBS_O_WORKDIR
# Activate environment if specified
{activate_env}
# Run shard worker
python -m homodyne.optimization.cmc.backends.pbs_worker \\
--shard-file {shard_file} \\
--config-file {config_file} \\
--output-file {result_file}
"""
class PBSBackend(CMCBackend):
    """PBS backend for HPC cluster MCMC execution.

    Submits each data shard as a separate PBS job, polls ``qstat`` until every
    job has left the queue, then loads the per-shard result files and combines
    them with ``combine_shard_samples``.

    Parameters
    ----------
    queue : str
        PBS queue name (default: "batch").
    ppn : int
        Processors per node (default: 4).
    walltime : str
        Job walltime (default: "04:00:00").
    memory : str
        Memory per job (default: "8gb").
    poll_interval : int
        Seconds between job status checks (default: 30).
    max_wait_time : int
        Maximum wait time in seconds (default: 14400 = 4 hours).

    Notes
    -----
    NOTE(review): job scripts, shard files and result files are written to a
    ``tempfile.TemporaryDirectory``. For compute nodes to see them, the
    temporary location presumably has to live on the shared filesystem
    (e.g. by pointing ``TMPDIR`` at it) -- confirm for the target cluster.
    """

    def __init__(
        self,
        queue: str = "batch",
        ppn: int = 4,
        walltime: str = "04:00:00",
        memory: str = "8gb",
        poll_interval: int = 30,
        max_wait_time: int = 14400,
    ) -> None:
        """Initialize PBS backend.

        Raises
        ------
        ValueError
            If ``queue``, ``walltime``, ``memory`` or ``ppn`` is malformed.
        """
        _validate_pbs_params(queue=queue, walltime=walltime, memory=memory)
        # ``ppn`` is interpolated into the job script just like the values
        # checked above, so validate the exact text that will be rendered.
        if not re.fullmatch(r"[1-9][0-9]*", str(ppn)):
            raise ValueError(f"Invalid PBS ppn {ppn!r}: must be a positive integer")
        self.queue = queue
        self.ppn = ppn
        self.walltime = walltime
        self.memory = memory
        self.poll_interval = poll_interval
        self.max_wait_time = max_wait_time
        self._validate_pbs_available()

    def _validate_pbs_available(self) -> None:
        """Check if PBS commands are available.

        Purely advisory: problems are logged as warnings and never raised, so
        the backend can still be constructed on a machine without PBS.
        """
        qstat_path = shutil.which("qstat")
        if not qstat_path:
            logger.warning(
                "PBSBackend: qstat not found. PBS commands may not be available."
            )
            return
        try:
            subprocess.run(  # noqa: S603 - qstat_path from shutil.which is trusted
                [qstat_path, "--version"],
                capture_output=True,
                timeout=10,
                check=False,
            )
        except FileNotFoundError:
            # The binary vanished between which() and run().
            logger.warning(
                "PBSBackend: qstat not found. PBS commands may not be available."
            )
        except subprocess.TimeoutExpired:
            logger.warning("PBSBackend: qstat timed out")
        except OSError as exc:
            # e.g. PermissionError; keep the probe non-fatal.
            logger.warning(f"PBSBackend: qstat could not be executed: {exc}")
        else:
            logger.info("PBSBackend: PBS scheduler detected")

    def get_name(self) -> str:
        """Get backend name.

        Returns
        -------
        str
            Backend identifier.
        """
        return "pbs"

    def is_available(self) -> bool:
        """Check if PBS backend is available.

        Returns
        -------
        bool
            True if PBS commands are accessible.

        Notes
        -----
        P2-R6-06: Previously ran bare ``qsub`` with no arguments, which exits
        non-zero on all PBS/Torque versions (missing jobscript), so this method
        always returned False on valid clusters. Now checks for the presence of
        the qsub binary via shutil.which, which is sufficient to determine
        availability without triggering an error submission.
        """
        return shutil.which("qsub") is not None

    def run(
        self,
        model: Callable,
        model_kwargs: dict[str, Any],
        config: CMCConfig,
        shards: list[PreparedData] | None = None,
    ) -> MCMCSamples:
        """Run MCMC sampling via PBS job submission.

        Parameters
        ----------
        model : Callable
            NumPyro model function (not directly used - workers import it).
        model_kwargs : dict[str, Any]
            Common model arguments.
        config : CMCConfig
            CMC configuration.
        shards : list[PreparedData] | None
            Data shards for parallel execution.

        Returns
        -------
        MCMCSamples
            Combined samples from all PBS jobs.

        Raises
        ------
        ValueError
            If no shards are given, or ``qsub`` prints an unexpected job ID.
        RuntimeError
            If jobs time out or result files are missing.
        subprocess.CalledProcessError
            If ``qsub`` exits with a non-zero status.
        """
        if not shards:
            raise ValueError("PBSBackend requires sharded data")
        logger.info(f"PBSBackend: Submitting {len(shards)} PBS jobs")

        # All job files live in one temporary directory that is removed when
        # this method returns or raises.
        with tempfile.TemporaryDirectory(prefix="cmc_pbs_") as tmpdir:
            tmpdir_path = Path(tmpdir)

            config_file = tmpdir_path / "config.json"
            self._save_config(config, model_kwargs, config_file)

            job_ids: list[str] = []
            result_files: list[Path] = []
            try:
                for i, shard in enumerate(shards):
                    shard_file = tmpdir_path / f"shard_{i}.npz"
                    result_file = tmpdir_path / f"result_{i}.npz"
                    self._save_shard(shard, shard_file)
                    result_files.append(result_file)
                    job_id = self._submit_job(
                        shard_id=i,
                        shard_file=shard_file,
                        config_file=config_file,
                        result_file=result_file,
                        output_dir=tmpdir_path,
                    )
                    job_ids.append(job_id)
                    logger.info(f"Submitted shard {i} as job {job_id}")

                self._wait_for_jobs(job_ids)
            except BaseException:
                # The temporary directory is about to disappear; do not leave
                # queued/running jobs behind that would read from and write to
                # it.  BaseException so Ctrl-C during the wait cleans up too.
                self._cancel_jobs(job_ids)
                raise

            shard_results = self._load_results(result_files)

            # P2-R6-01: Use config.combination_method directly; CMCConfig always
            # has this field (defaults to "robust_consensus_mc"). The stale
            # "weighted_gaussian" fallback was misleading and incorrect.
            combined = combine_shard_samples(
                shard_results,
                method=config.combination_method,
            )

        logger.info("PBSBackend: All jobs completed successfully")
        return combined

    def _cancel_jobs(self, job_ids: list[str]) -> None:
        """Best-effort ``qdel`` of the given jobs; never raises.

        Jobs that have already finished simply make ``qdel`` report an error,
        which is ignored.
        """
        if not job_ids:
            return
        qdel_path = shutil.which("qdel")
        if not qdel_path:
            logger.warning(
                f"PBSBackend: qdel not found; {len(job_ids)} submitted job(s) "
                "were not cancelled"
            )
            return
        try:
            subprocess.run(  # noqa: S603 - qdel_path from shutil.which is trusted
                [qdel_path, *job_ids],
                capture_output=True,
                text=True,
                timeout=60,
                check=False,
            )
            logger.warning(
                f"PBSBackend: requested cancellation of {len(job_ids)} job(s)"
            )
        except (OSError, subprocess.SubprocessError) as exc:
            logger.warning(f"PBSBackend: failed to cancel jobs: {exc}")

    def _save_config(
        self,
        config: CMCConfig,
        model_kwargs: dict[str, Any],
        path: Path,
    ) -> None:
        """Save configuration for PBS workers.

        Writes a JSON document with two keys: ``"config"`` (sampler settings
        read from ``config``) and ``"model_kwargs"`` (the JSON-friendly subset
        of ``model_kwargs``).

        Raises
        ------
        TypeError
            If a list/dict kwarg contains a value ``json`` cannot encode.
        """
        config_dict = {
            "num_warmup": config.num_warmup,
            "num_samples": config.num_samples,
            "num_chains": config.num_chains,
            "target_accept_prob": config.target_accept_prob,
            "max_tree_depth": getattr(config, "max_tree_depth", 10),
        }

        serializable_kwargs: dict[str, Any] = {}
        skipped: list[str] = []
        for key, value in model_kwargs.items():
            if isinstance(value, np.ndarray):
                # Skip arrays - they're in shard files
                continue
            if isinstance(value, np.generic):
                # NumPy scalars such as np.int64 are not ``int`` subclasses;
                # unwrap them so they are not silently dropped.
                value = value.item()
            if isinstance(value, (int, float, str, bool, list, dict)):
                serializable_kwargs[key] = value
            else:
                skipped.append(key)
        if skipped:
            logger.debug(
                f"PBSBackend: model_kwargs not passed to workers: {sorted(skipped)}"
            )

        combined = {
            "config": config_dict,
            "model_kwargs": serializable_kwargs,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(combined, f, indent=2)

    def _save_shard(self, shard: PreparedData, path: Path) -> None:
        """Save shard data for PBS worker as a compressed ``.npz`` archive."""
        np.savez_compressed(
            path,
            data=shard.data,
            t1=shard.t1,
            t2=shard.t2,
            phi=shard.phi,
            phi_indices=shard.phi_indices,
            phi_unique=shard.phi_unique,
            n_phi=shard.n_phi,
            noise_scale=shard.noise_scale,
        )

    def _submit_job(
        self,
        shard_id: int,
        shard_file: Path,
        config_file: Path,
        result_file: Path,
        output_dir: Path,
    ) -> str:
        """Submit a PBS job for one shard.

        Renders ``PBS_JOB_TEMPLATE`` into ``output_dir/job_<shard_id>.pbs`` and
        hands it to ``qsub``.

        Returns
        -------
        str
            PBS job ID.

        Raises
        ------
        ValueError
            If ``CONDA_PREFIX`` / ``VIRTUAL_ENV`` contain unsafe characters or
            ``qsub`` prints something that is not a job ID.
        subprocess.CalledProcessError
            If ``qsub`` exits with a non-zero status.
        """
        # Re-activate the submitting process's environment inside the job.
        # The paths are interpolated into a shell script, so they must match
        # the whitelist in full.
        activate_env = ""
        conda_prefix = os.environ.get("CONDA_PREFIX", "")
        virtual_env = os.environ.get("VIRTUAL_ENV", "")
        if conda_prefix:
            if not _SAFE_PATH_RE.fullmatch(conda_prefix):
                raise ValueError(f"Unsafe characters in CONDA_PREFIX: {conda_prefix!r}")
            activate_env = f"source activate {conda_prefix}"
        elif virtual_env:
            if not _SAFE_PATH_RE.fullmatch(virtual_env):
                raise ValueError(f"Unsafe characters in VIRTUAL_ENV: {virtual_env!r}")
            activate_env = f"source {virtual_env}/bin/activate"

        job_script = PBS_JOB_TEMPLATE.format(
            shard_id=shard_id,
            ppn=self.ppn,
            walltime=self.walltime,
            memory=self.memory,
            queue=self.queue,
            output_dir=output_dir,
            activate_env=activate_env,
            shard_file=shard_file,
            config_file=config_file,
            result_file=result_file,
        )

        script_file = output_dir / f"job_{shard_id}.pbs"
        with open(script_file, "w", encoding="utf-8") as f:
            f.write(job_script)

        qsub_path = shutil.which("qsub") or "qsub"
        try:
            result = subprocess.run(  # noqa: S603 - qsub_path from shutil.which is trusted
                [qsub_path, str(script_file)],
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as exc:
            # Surface the scheduler's own message; the exception type callers
            # see is unchanged.
            logger.error(
                f"qsub failed for shard {shard_id} "
                f"(exit {exc.returncode}): {(exc.stderr or '').strip()}"
            )
            raise

        job_id = result.stdout.strip()
        if not _PBS_JOB_ID_RE.fullmatch(job_id):
            raise ValueError(f"Unexpected PBS job ID format: {job_id!r}")
        return job_id

    def _wait_for_jobs(self, job_ids: list[str]) -> None:
        """Wait for all PBS jobs to complete.

        Polls every ``poll_interval`` seconds until each job reports state
        ``"C"``.

        Raises
        ------
        RuntimeError
            If jobs are still pending after ``max_wait_time`` seconds.
        """
        # Monotonic clock: the timeout must not be affected by wall-clock
        # adjustments on the submit host.
        start_time = time.monotonic()
        pending_jobs = set(job_ids)
        while pending_jobs:
            elapsed = time.monotonic() - start_time
            if elapsed > self.max_wait_time:
                raise RuntimeError(
                    f"PBS jobs timed out after {self.max_wait_time}s. "
                    f"Remaining jobs: {pending_jobs}"
                )

            completed = set()
            for job_id in pending_jobs:
                # P1-R5-02: PBS "E" means "Exiting" (job completing normally),
                # NOT "Error". Every successful PBS job passes through "E"
                # during its teardown/epilogue phase, so only "C" counts as
                # done; any other state keeps the job pending.
                if self._get_job_status(job_id) == "C":
                    completed.add(job_id)
                    logger.debug(f"Job {job_id} completed")
            pending_jobs -= completed

            if pending_jobs:
                logger.info(
                    f"Waiting for {len(pending_jobs)} jobs... ({int(elapsed)}s elapsed)"
                )
                time.sleep(self.poll_interval)

    def _get_job_status(self, job_id: str) -> str:
        """Get PBS job status.

        Returns
        -------
        str
            Job status per PBS/Torque conventions:
            Q = queued, R = running, E = exiting (epilogue/cleanup, normal
            completion phase), C = completed (also returned when the job no
            longer appears in qstat output or no job_state can be parsed).
            "R" is returned when qstat itself times out. Note: there is no
            "error" job_state in standard PBS/Torque; failures are detected
            via exit_status in qstat -f output, not via job_state.
        """
        try:
            qstat_path = shutil.which("qstat") or "qstat"
            result = subprocess.run(  # noqa: S603 - qstat_path from shutil.which is trusted
                [qstat_path, "-f", job_id],
                capture_output=True,
                text=True,
                timeout=30,
                check=False,
            )
        except subprocess.TimeoutExpired:
            logger.warning(f"qstat timeout for job {job_id}")
            return "R"  # Assume still running

        if result.returncode != 0:
            # Job no longer in queue - assume completed
            return "C"

        # qstat -f prints "key = value" lines.  Compare the key exactly (and
        # case-insensitively) instead of substring-matching the whole line, so
        # unrelated attributes that merely mention these words are ignored.
        state = None
        exit_status = None
        for line in result.stdout.splitlines():
            key, sep, value = line.partition("=")
            if not sep:
                continue
            key = key.strip().lower()
            value = value.strip()
            if key == "job_state":
                state = value
            elif key == "exit_status":
                try:
                    exit_status = int(value)
                except ValueError:
                    pass  # malformed exit_status: treat as unknown

        if state is None:
            return "C"  # Assume completed if can't parse
        if state == "C" and exit_status is not None and exit_status != 0:
            logger.error(
                f"PBS job {job_id} failed with exit_status={exit_status}. "
                "Result file will be absent; shard will be treated as failed."
            )
        return state

    def _load_results(self, result_files: list[Path]) -> list[MCMCSamples]:
        """Load results from completed PBS jobs.

        Raises
        ------
        RuntimeError
            If any expected result file does not exist.
        """
        # Runtime import: the module-level import is TYPE_CHECKING-only.
        from homodyne.optimization.cmc.sampler import MCMCSamples

        missing = [str(p) for p in result_files if not p.exists()]
        if missing:
            raise RuntimeError(
                f"{len(missing)}/{len(result_files)} PBS result file(s) not found "
                f"(shard worker(s) failed or timed out): {missing[:5]}"
                + (" ..." if len(missing) > 5 else "")
            )

        results = []
        for path in result_files:
            # NpzFile keeps the archive open until closed; every array is read
            # inside the ``with`` block so the handle can be released promptly.
            with np.load(path, allow_pickle=False) as data:
                # P3-R6-01: .tolist() converts numpy string scalars to Python
                # str, matching the convention in io.py and downstream string
                # comparisons.
                param_names = data["param_names"].tolist()
                samples_dict = {
                    name: data[f"sample_{name}"]
                    for name in param_names
                    if f"sample_{name}" in data
                }
                # Fallback for old format with single "samples" key
                if not samples_dict and "samples" in data:
                    samples_arr = data["samples"]
                    samples_dict = {
                        name: samples_arr[..., i]
                        for i, name in enumerate(param_names)
                    }

                # Arrays stored as "extra_<field>" become extra_fields[<field>].
                prefix = "extra_"
                extra_fields: dict[str, Any] = {
                    key[len(prefix):]: data[key]
                    for key in data.files
                    if key.startswith(prefix)
                }

                # Derive actual chain/sample counts from array shape (not
                # serialized metadata); fall back to the metadata only when
                # the shape cannot provide them.
                first_param = next(iter(samples_dict.values()), None)
                if first_param is not None and first_param.ndim >= 1:
                    actual_n_chains = first_param.shape[0]
                else:
                    actual_n_chains = int(data["n_chains"])
                if first_param is not None and first_param.ndim >= 2:
                    actual_n_samples = first_param.shape[1]
                else:
                    actual_n_samples = int(data["n_samples"])

            results.append(
                MCMCSamples(
                    samples=samples_dict,
                    param_names=param_names,
                    n_chains=actual_n_chains,
                    n_samples=actual_n_samples,
                    extra_fields=extra_fields,
                )
            )
        return results