From 70b19708d9cb7ef13316440694466c9cd9860da0 Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Mon, 6 Apr 2026 09:33:12 -0700 Subject: [PATCH 1/6] feat: ft_launcher integration of log analysis for restart decisions --- docs/source/fault_tolerance/usage_guide.rst | 88 +++- pyproject.toml | 1 + services/nvrx_attrsvc/ATTRSVC_SPEC.md | 10 +- services/nvrx_attrsvc/README.md | 14 +- services/nvrx_attrsvc/config.py | 12 +- services/nvrx_attrsvc/deploy/Dockerfile | 2 +- services/nvrx_attrsvc/deploy/kubernetes.yaml | 6 +- .../nvrx_attrsvc/deploy/nvrx-attrsvc.service | 2 +- services/nvrx_attrsvc/deploy/run_attrsvc.sh | 6 +- services/nvrx_attrsvc/deploy/slurm.sbatch | 8 +- services/scripts/README.md | 14 +- services/scripts/build_enroot_image.sh | 4 +- services/scripts/common.sh | 54 +-- services/scripts/nvrx_services.sbatch | 8 +- services/scripts/run_services.sh | 52 +-- services/scripts/setup_systemd.sh | 18 +- .../attribution/api_keys.py | 28 +- .../combined_log_fr/combined_log_fr.py | 6 +- .../attribution/combined_log_fr/llm_merge.py | 10 +- .../attribution/log_analyzer/__init__.py | 3 +- .../log_analyzer/analysis_pipeline.py | 12 +- .../attribution/log_analyzer/nvrx_logsage.py | 8 +- .../attribution/log_analyzer/runner.py | 77 +++- .../trace_analyzer/fr_attribution.py | 16 +- .../fault_tolerance/__init__.py | 2 + .../fault_tolerance/config.py | 135 +++++- .../fault_tolerance/ft_attribution.py | 417 ++++++++++++++++++ .../fault_tolerance/ft_rendezvous_barrier.py | 65 +-- .../fault_tolerance/launcher.py | 268 ++++++++--- .../fault_tolerance/utils.py | 14 +- tests/attribution/unit/test_api_keys.py | 31 +- 31 files changed, 1095 insertions(+), 296 deletions(-) create mode 100644 src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py diff --git a/docs/source/fault_tolerance/usage_guide.rst b/docs/source/fault_tolerance/usage_guide.rst index 64bfee41..cda2d451 100644 --- a/docs/source/fault_tolerance/usage_guide.rst +++ b/docs/source/fault_tolerance/usage_guide.rst @@ -109,33 +109,91 @@ Validation behavior: - Other existing types (e.g., devices/symlinks): performs ``stat`` access -Attribution service integration -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Attribution integration +^^^^^^^^^^^^^^^^^^^^^^ -Enable artifact analysis (e.g., logs) during rendezvous health checks by pointing to a running attribution service. -The feature is enabled by specifying both host and port. +Enable artifact analysis (e.g., logs) during rendezvous to make RESTART/STOP decisions. +You can configure **one or more backends** (e.g. ``mcp`` for LogSage + FR via MCP, plus an HTTP URL for a third-party service). The run stops the workload (no restart) if **any** backend reports do not restart. -* CLI: +Use ``--ft-attribution-backend`` (repeatable) and/or YAML ``attribution_backends``. + +* ``mcp``: Log analysis via MCP subprocess (``nvrx-mcp-analysis``). +* **HTTP URL** (no separate keyword): pass the URL as the flag value, e.g. + ``--ft-attribution-backend http://127.0.0.1:8000`` or ``--ft-attribution-backend host:port`` + (``http://`` is added when you use ``host:port`` form). - - ``--ft-attrsvc-host `` (alias: ``--ft_attrsvc_host``) - - ``--ft-attrsvc-port `` (alias: ``--ft_attrsvc_port``) +* CLI: - Example: + - ``--ft-attribution-backend`` (alias: ``--ft_attribution_backend``): Add one backend; repeat for multiple. + Each value is ``mcp`` or an HTTP URL. Combined with YAML ``attribution_backends``. + - ``--ft-attribution-timeout`` (alias: ``--ft_attribution_timeout``): Wait/timeout in seconds; + skip result if exceeded (default: 60). + - ``--ft-attribution-dry-run`` (alias: ``--ft_attribution_dry_run``): Dry run. Run the full + attribution chain (log analysis, Slack, dataflow) but do not apply the restart/stop decision. + Log what would happen instead. Useful for validating the pipeline without affecting behavior. + - ``--ft-slack-token-file`` (alias: ``--ft_slack_token_file``): Path to file containing Slack bot token. + When not set, uses ``SLACK_BOT_TOKEN`` or ``SLACK_BOT_TOKEN_FILE`` env vars. + - ``--ft-slack-channel`` (alias: ``--ft_slack_channel``): Slack channel for alerts. + When not set, uses ``SLACK_CHANNEL`` env var. + - ``--ft-dataflow-index`` (alias: ``--ft_dataflow_index``): Elasticsearch/dataflow index for posting + attribution results (mcp/URL). Requires ``nvdataflow`` (install via ``pip install nvidia-resiliency-ext[dataflow]``). + When not set, dataflow posting is disabled. + - ``--ft-llm-api-key-file`` (alias: ``--ft_llm_api_key_file``): Path to a file containing the LLM API key. + Sets ``LLM_API_KEY_FILE`` in the process before MCP attribution starts. Overrides YAML ``llm_api_key_file`` when both are set. + + Examples: .. code-block:: bash - ft_launcher \ - --ft-attrsvc-host 127.0.0.1 \ - --ft-attrsvc-port 8000 \ - train.py + # MCP: log analysis via nvrx-mcp-analysis + ft_launcher --ft-attribution-backend mcp train.py + + # URL mode (HTTP attribution service) + ft_launcher --ft-attribution-backend http://127.0.0.1:8000 train.py + + # Service with custom timeout + ft_launcher --ft-attribution-backend http://127.0.0.1:8000 --ft-attribution-timeout 90 train.py + + # MCP with Slack and dataflow (token from file; channel from env) + ft_launcher --ft-attribution-backend mcp --ft-slack-token-file /etc/secrets/slack-token train.py + + # MCP with explicit Slack channel and dataflow index + ft_launcher --ft-attribution-backend mcp \ + --ft-slack-token-file /etc/secrets/slack-token --ft-slack-channel "#alerts" \ + --ft-dataflow-index my-attribution-index train.py + + # Dry run: exercise full attribution chain without applying restart/stop decision + ft_launcher --ft-attribution-backend mcp --ft-attribution-dry-run train.py + + # Multiple backends: MCP plus third-party HTTP service + ft_launcher --ft-attribution-backend mcp --ft-attribution-backend http://127.0.0.1:8000 train.py -* YAML: under the ``fault_tolerance`` section +* YAML: under the ``fault_tolerance`` section use ``attribution_backends`` (list of ``mcp`` and/or URLs), + ``attribution_timeout_seconds``, ``slack``, ``dataflow_index``, and optional ``llm_api_key_file``: .. code-block:: yaml fault_tolerance: - attrsvc_host: "127.0.0.1" - attrsvc_port: 8000 + # Prefer explicit list for multiple backends: + attribution_backends: + - "mcp" + - "http://127.0.0.1:8000" + attribution_timeout_seconds: 60 + attribution_dry_run: false # true = run chain but don't apply action; log only + slack: + bot_token_file: "/etc/secrets/slack-token" # or bot_token for inline (less secure) + channel: "#alerts" + dataflow_index: "my-attribution-index" # optional; requires nvdataflow + llm_api_key_file: "/etc/secrets/llm-api-key" # optional; sets LLM_API_KEY_FILE for MCP + +* Environment (fallback when CLI/YAML not set): + + - ``SLACK_BOT_TOKEN`` or ``SLACK_BOT_TOKEN_FILE``: Slack bot token for mcp/URL alerts. + - ``SLACK_CHANNEL``: Slack channel for alerts. + - **LLM / LogSage API key** (MCP backend): ``LLM_API_KEY`` or ``LLM_API_KEY_FILE``, or default files + ``~/.llm_api_key`` / ``~/.config/nvrx/llm_api_key`` (see ``load_llm_api_key`` in + ``nvidia_resiliency_ext.attribution.api_keys``). For ``ft_launcher``, use YAML ``llm_api_key_file`` or + ``--ft-llm-api-key-file``. GPU Memory Reclaim ^^^^^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 344abaa6..e552a9aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ setproctitle = ">=1.3.0" logsage = ">=0.1.7" grpcio = "^1.76.0" grpcio-tools = "^1.76.0" +httpx = ">=0.24.0" protobuf = ">=4.22.0" [tool.poetry.scripts] diff --git a/services/nvrx_attrsvc/ATTRSVC_SPEC.md b/services/nvrx_attrsvc/ATTRSVC_SPEC.md index 9500e380..251327c4 100644 --- a/services/nvrx_attrsvc/ATTRSVC_SPEC.md +++ b/services/nvrx_attrsvc/ATTRSVC_SPEC.md @@ -94,10 +94,10 @@ Two layers: **library** (`nvidia_resiliency_ext.attribution`) and **service** **3.1 Environment variables** — Full table and defaults: **README.md** (source of truth). Summary: -- Prefix **`NVRX_ATTRSVC_`** for service settings (see README for exceptions: NVIDIA - API key, Slack tokens, optional `NVIDIA_API_KEY_FILE` / file paths in `api_keys.py`). -- **`NVIDIA_API_KEY`**: required for attribution; loaded in `config.setup()` after - logging — **empty/missing → log error and process exit**. Slack is optional. +- Prefix **`NVRX_ATTRSVC_`** for service settings (see README for exceptions: LLM + API key, Slack tokens, optional `LLM_API_KEY_FILE` / file paths in `api_keys.py`). +- **`LLM_API_KEY`** / **`LLM_API_KEY_FILE`**: required for attribution (or default key files); + loaded in `config.setup()` after logging — **empty/missing → log error and process exit**. Slack is optional. - LLM-related env vars are optional; unset → library defaults (`LogAnalyzerConfig`). - Rate limits: slowapi, `RATE_LIMIT_SUBMIT` / `RATE_LIMIT_ANALYZE` / `RATE_LIMIT_PREVIEW`. @@ -144,7 +144,7 @@ Patterns tried in order (scheduler-agnostic where possible): `_(\d+)_date_`, -------------------------------------------------------------------------------- **Startup (conceptual)** -Load `Settings` → configure logging → **require non-empty NVIDIA API key** → wire +Load `Settings` → configure logging → **require non-empty LLM API key** → wire postprocessing (`configure`, poster, dataflow index, Slack) → construct `AttributionService` / **`Analyzer`** → background poll → Uvicorn. Optional cache import. diff --git a/services/nvrx_attrsvc/README.md b/services/nvrx_attrsvc/README.md index 012c9ee1..9fb3d28b 100644 --- a/services/nvrx_attrsvc/README.md +++ b/services/nvrx_attrsvc/README.md @@ -24,8 +24,8 @@ pip install -e . # Run export NVRX_ATTRSVC_ALLOWED_ROOT=/path/to/logs -# API key: set env var OR create ~/.nvidia_api_key file -export NVIDIA_API_KEY=nvapi-... +# API key: set env var OR create ~/.llm_api_key file +export LLM_API_KEY=your-llm-api-key-here nvrx-attrsvc ``` @@ -57,11 +57,11 @@ Environment variables (prefix: `NVRX_ATTRSVC_`): | `NVRX_ATTRSVC_COMPUTE_TIMEOUT` | Timeout for analysis in seconds | | `NVRX_ATTRSVC_ANALYSIS_BACKEND` | `mcp` (subprocess MCP, default) or `lib` (in-process LogSage and flight-recorder analysis). Same setting for both; library behavior: **ARCHITECTURE.md §7**. Legacy env: `NVRX_ATTRSVC_LOG_ANALYSIS_BACKEND`. | -**NVIDIA API Key** (required, checked in order): -1. `NVIDIA_API_KEY` environment variable -2. `NVIDIA_API_KEY_FILE` environment variable (path to file) -3. `~/.nvidia_api_key` file -4. `~/.config/nvrx/nvidia_api_key` file +**LLM API Key** (required, checked in order — see `api_keys.load_llm_api_key`): +1. `LLM_API_KEY` environment variable +2. `LLM_API_KEY_FILE` environment variable (path to file) +3. `~/.llm_api_key` file +4. `~/.config/nvrx/llm_api_key` file **Slack Notifications** (optional; no `NVRX_ATTRSVC_` prefix): diff --git a/services/nvrx_attrsvc/config.py b/services/nvrx_attrsvc/config.py index 0cab5787..d035ba02 100644 --- a/services/nvrx_attrsvc/config.py +++ b/services/nvrx_attrsvc/config.py @@ -262,14 +262,14 @@ def setup() -> Settings: logging.getLogger("nvidia_resiliency_ext.attribution.mcp_integration").setLevel(_root_lvl) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) - from nvidia_resiliency_ext.attribution.api_keys import load_nvidia_api_key, load_slack_bot_token + from nvidia_resiliency_ext.attribution.api_keys import load_llm_api_key, load_slack_bot_token - nvidia_key = load_nvidia_api_key() - if not nvidia_key: + llm_key = load_llm_api_key() + if not llm_key: logger.error( - "NVIDIA API key not found or empty. Attribution requires a key. Set NVIDIA_API_KEY " - "or NVIDIA_API_KEY_FILE, or place a key in ~/.nvidia_api_key or " - "~/.config/nvrx/nvidia_api_key. Slack notifications remain optional (SLACK_BOT_TOKEN)." + "LLM API key not found or empty. Attribution requires a key. Set LLM_API_KEY or " + "LLM_API_KEY_FILE, or default key files (~/.llm_api_key, ~/.config/nvrx/llm_api_key). " + "Slack notifications remain optional (SLACK_BOT_TOKEN)." ) raise SystemExit(1) diff --git a/services/nvrx_attrsvc/deploy/Dockerfile b/services/nvrx_attrsvc/deploy/Dockerfile index 20478843..7ea43660 100644 --- a/services/nvrx_attrsvc/deploy/Dockerfile +++ b/services/nvrx_attrsvc/deploy/Dockerfile @@ -7,7 +7,7 @@ # docker run -d \ # -p 8000:8000 \ # -e NVRX_ATTRSVC_ALLOWED_ROOT=/data/logs \ -# -e NVIDIA_API_KEY=nvapi-... \ +# -e LLM_API_KEY=your-llm-api-key-here \ # -v /path/to/logs:/data/logs:ro \ # nvrx-attrsvc # diff --git a/services/nvrx_attrsvc/deploy/kubernetes.yaml b/services/nvrx_attrsvc/deploy/kubernetes.yaml index 815a9cb1..fb34fe25 100644 --- a/services/nvrx_attrsvc/deploy/kubernetes.yaml +++ b/services/nvrx_attrsvc/deploy/kubernetes.yaml @@ -4,7 +4,7 @@ # kubectl apply -f services/nvrx_attrsvc/deploy/kubernetes.yaml # # Prerequisites: -# - Create secret: kubectl create secret generic nvidia-api-key --from-literal=api-key=nvapi-... +# - Create secret: kubectl create secret generic llm-api-key --from-literal=api-key=your-llm-api-key-here # - Ensure log volume is accessible (update hostPath as needed) # # Deployment considerations: @@ -54,10 +54,10 @@ spec: - configMapRef: name: nvrx-attrsvc-config env: - - name: NVIDIA_API_KEY + - name: LLM_API_KEY valueFrom: secretKeyRef: - name: nvidia-api-key + name: llm-api-key key: api-key volumeMounts: - name: logs diff --git a/services/nvrx_attrsvc/deploy/nvrx-attrsvc.service b/services/nvrx_attrsvc/deploy/nvrx-attrsvc.service index 87900541..be5daa93 100644 --- a/services/nvrx_attrsvc/deploy/nvrx-attrsvc.service +++ b/services/nvrx_attrsvc/deploy/nvrx-attrsvc.service @@ -7,7 +7,7 @@ # Manual installation: # 1. Create venv: python3 -m venv /opt/nvrx/venv # 2. Install: /opt/nvrx/venv/bin/pip install -e services -# 3. Create API key: echo "nvapi-xxx" | sudo tee /etc/nvrx/nvidia_api_key +# 3. Create API key: echo "your-llm-api-key-here" | sudo tee /etc/nvrx/llm_api_key # 4. Copy service: sudo cp nvrx-attrsvc.service /etc/systemd/system/ # 5. Reload: sudo systemctl daemon-reload # 6. Enable: sudo systemctl enable nvrx-attrsvc diff --git a/services/nvrx_attrsvc/deploy/run_attrsvc.sh b/services/nvrx_attrsvc/deploy/run_attrsvc.sh index 13befead..3ef7afc7 100644 --- a/services/nvrx_attrsvc/deploy/run_attrsvc.sh +++ b/services/nvrx_attrsvc/deploy/run_attrsvc.sh @@ -6,7 +6,7 @@ # # Required environment variables: # NVRX_ATTRSVC_ALLOWED_ROOT - Root path for log files to analyze -# NVIDIA_API_KEY - API key for LLM (or NVIDIA_API_KEY_FILE) +# LLM_API_KEY - API key for LLM (or LLM_API_KEY_FILE) # # Optional environment variables: # NVRX_ATTRSVC_PORT - Listen port (default: 8000) @@ -17,7 +17,7 @@ # # Example: # export NVRX_ATTRSVC_ALLOWED_ROOT=/lustre/logs -# export NVIDIA_API_KEY=nvapi-... +# export LLM_API_KEY=your-llm-api-key-here # ./run_attrsvc.sh ~/nvrx_logs set -e @@ -38,7 +38,7 @@ PID_FILE="${OUTPUT_DIR}/${PREFIX}_attrsvc.pid" validate_attrsvc_allowed_root || exit 1 # Setup API key -setup_nvidia_api_key || exit 1 +setup_llm_api_key || exit 1 # Create output directory ensure_directory "${OUTPUT_DIR}" "logs directory" || exit 1 diff --git a/services/nvrx_attrsvc/deploy/slurm.sbatch b/services/nvrx_attrsvc/deploy/slurm.sbatch index 914c2d28..3931affd 100644 --- a/services/nvrx_attrsvc/deploy/slurm.sbatch +++ b/services/nvrx_attrsvc/deploy/slurm.sbatch @@ -18,9 +18,9 @@ # NVRX_ATTRSVC_ALLOWED_ROOT - Root path for log files to analyze # # API Key Options (in priority order): -# 1. NVIDIA_API_KEY env var (direct key) -# 2. NVIDIA_API_KEY_FILE env var (path to file containing key) -# 3. Default: ~/.nvidia_api_key +# 1. LLM_API_KEY env var (direct key) +# 2. LLM_API_KEY_FILE env var (path to file containing key) +# 3. Default: ~/.llm_api_key or ~/.config/nvrx/llm_api_key # # Example: # NVRX_ATTRSVC_ALLOWED_ROOT=/lustre/logs sbatch --account=myaccount slurm.sbatch @@ -48,7 +48,7 @@ export NVRX_ATTRSVC_NVDATAFLOW_PROJECT="${NVRX_ATTRSVC_NVDATAFLOW_PROJECT:-}" export NVRX_ATTRSVC_CLUSTER_NAME="${NVRX_ATTRSVC_CLUSTER_NAME:-${SLURM_CLUSTER_NAME:-unknown}}" # Setup API key -setup_nvidia_api_key || exit 1 +setup_llm_api_key || exit 1 # Install packages install_nvrx_packages "attrsvc" diff --git a/services/scripts/README.md b/services/scripts/README.md index 5f664bc7..c8b40280 100644 --- a/services/scripts/README.md +++ b/services/scripts/README.md @@ -20,8 +20,8 @@ Shared shell scripts for deployment and monitoring. ```bash # Set required environment export NVRX_ATTRSVC_ALLOWED_ROOT=/lustre/logs -# API key: set env var OR create ~/.nvidia_api_key file -export NVIDIA_API_KEY=nvapi-... +# API key: set env var OR create ~/.llm_api_key file +export LLM_API_KEY=your-llm-api-key-here # Install, start, and manage ./scripts/run_services.sh install # Install packages @@ -50,10 +50,10 @@ sudo ./scripts/setup_systemd.sh start ### API Key The API key can be provided in multiple ways (checked in order): -1. `NVIDIA_API_KEY` environment variable -2. `NVIDIA_API_KEY_FILE` environment variable (path to key file) -3. `~/.nvidia_api_key` file -4. `~/.config/nvrx/nvidia_api_key` file +1. `LLM_API_KEY` environment variable +2. `LLM_API_KEY_FILE` environment variable (path to key file) +3. `~/.llm_api_key` file +4. `~/.config/nvrx/llm_api_key` file **Output files** (in `~/nvrx_logs/` by default): - `_attrsvc.log` - Attribution service stdout/stderr @@ -129,7 +129,7 @@ Shared functions sourced by other scripts: | Function | Description | |----------|-------------| -| `setup_nvidia_api_key` | Load API key from env, file, or default location | +| `setup_llm_api_key` | Load LLM API key from env, file, or default location | | `install_nvrx_packages` | Install NVRX packages from local repo | | `validate_commands` | Check required commands exist | diff --git a/services/scripts/build_enroot_image.sh b/services/scripts/build_enroot_image.sh index f1b30ce7..0402d06e 100755 --- a/services/scripts/build_enroot_image.sh +++ b/services/scripts/build_enroot_image.sh @@ -14,7 +14,7 @@ # # Run attribution service # srun --container-image=/path/to/nvrx-services.sqsh \ # --container-env=NVRX_ATTRSVC_ALLOWED_ROOT=/data \ -# --container-env=NVIDIA_API_KEY=${NVIDIA_API_KEY} \ +# --container-env=LLM_API_KEY=${LLM_API_KEY} \ # --container-mounts=/path/to/logs:/data:ro \ # nvrx-attrsvc # @@ -150,7 +150,7 @@ echo "" echo " # Attribution service" echo " srun --container-image=${OUTPUT_PATH} \\" echo " --container-env=NVRX_ATTRSVC_ALLOWED_ROOT=/data \\" -echo " --container-env=NVIDIA_API_KEY=\${NVIDIA_API_KEY} \\" +echo " --container-env=LLM_API_KEY=\${LLM_API_KEY} \\" echo " --container-mounts=/path/to/logs:/data:ro \\" echo " nvrx-attrsvc" echo "" diff --git a/services/scripts/common.sh b/services/scripts/common.sh index 75bfdcf6..dcfea9cf 100755 --- a/services/scripts/common.sh +++ b/services/scripts/common.sh @@ -5,51 +5,41 @@ # Get the directory where this script lives COMMON_SETUP_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -# Setup NVIDIA API key from environment, file, or default locations +# Setup LLM API key for processes that read nvidia_resiliency_ext.attribution.api_keys.load_llm_api_key # Checks in order: -# 1. NVIDIA_API_KEY environment variable -# 2. NVIDIA_API_KEY_FILE environment variable (path to file) -# 3. ~/.nvidia_api_key file -# 4. ~/.config/nvrx/nvidia_api_key file -# Sets NVIDIA_API_KEY environment variable -setup_nvidia_api_key() { - if [[ -n "${NVIDIA_API_KEY}" ]]; then - echo "Using NVIDIA_API_KEY from environment" +# 1. LLM_API_KEY environment variable +# 2. LLM_API_KEY_FILE (path to file; must exist) +# 3. ~/.llm_api_key — sets export LLM_API_KEY_FILE to that path +# 4. ~/.config/nvrx/llm_api_key +setup_llm_api_key() { + if [[ -n "${LLM_API_KEY}" ]]; then + echo "Using LLM_API_KEY from environment" return 0 fi - - # Check NVIDIA_API_KEY_FILE - if [[ -n "${NVIDIA_API_KEY_FILE}" ]]; then - if [[ ! -f "${NVIDIA_API_KEY_FILE}" ]]; then - echo "ERROR: NVIDIA_API_KEY_FILE specified but not found: ${NVIDIA_API_KEY_FILE}" - return 1 - fi - export NVIDIA_API_KEY=$(cat "${NVIDIA_API_KEY_FILE}" | tr -d '[:space:]') - if [[ -z "${NVIDIA_API_KEY}" ]]; then - echo "ERROR: NVIDIA_API_KEY_FILE is empty: ${NVIDIA_API_KEY_FILE}" + + if [[ -n "${LLM_API_KEY_FILE}" ]]; then + if [[ ! -f "${LLM_API_KEY_FILE}" ]]; then + echo "ERROR: LLM_API_KEY_FILE specified but not found: ${LLM_API_KEY_FILE}" return 1 fi - echo "Using API key from: ${NVIDIA_API_KEY_FILE}" + echo "Using LLM_API_KEY_FILE=${LLM_API_KEY_FILE}" return 0 fi - - # Check default locations + local KEY_LOCATIONS=( - "${HOME}/.nvidia_api_key" - "${HOME}/.config/nvrx/nvidia_api_key" + "${HOME}/.llm_api_key" + "${HOME}/.config/nvrx/llm_api_key" ) for key_file in "${KEY_LOCATIONS[@]}"; do if [[ -f "${key_file}" ]]; then - export NVIDIA_API_KEY=$(cat "${key_file}" | tr -d '[:space:]') - if [[ -n "${NVIDIA_API_KEY}" ]]; then - echo "Using API key from: ${key_file}" - return 0 - fi + export LLM_API_KEY_FILE="${key_file}" + echo "Using API key from: ${key_file}" + return 0 fi done - - echo "WARNING: NVIDIA_API_KEY not found - LLM analysis may fail" - echo " Set NVIDIA_API_KEY, NVIDIA_API_KEY_FILE, or create ~/.nvidia_api_key" + + echo "WARNING: LLM API key not found - LLM analysis may fail" + echo " Set LLM_API_KEY, LLM_API_KEY_FILE, or create ~/.llm_api_key" return 1 } diff --git a/services/scripts/nvrx_services.sbatch b/services/scripts/nvrx_services.sbatch index 5cb4d70d..94e76c30 100755 --- a/services/scripts/nvrx_services.sbatch +++ b/services/scripts/nvrx_services.sbatch @@ -26,9 +26,9 @@ # NVRX_LOGS_DIR - Directory for service logs (default: current directory) # # API Key Options (in priority order): -# 1. NVIDIA_API_KEY env var (direct key) -# 2. NVIDIA_API_KEY_FILE env var (path to file containing key) -# 3. Default: ~/.nvidia_api_key +# 1. LLM_API_KEY env var (direct key) +# 2. LLM_API_KEY_FILE env var (path to file containing key) +# 3. Default: ~/.llm_api_key or ~/.config/nvrx/llm_api_key # # Example with custom output directory: # export NVRX_LOGS_DIR=/my/logs @@ -69,7 +69,7 @@ export NVRX_ATTRSVC_NVDATAFLOW_PROJECT="${NVRX_ATTRSVC_NVDATAFLOW_PROJECT:-}" export NVRX_ATTRSVC_CLUSTER_NAME="${NVRX_ATTRSVC_CLUSTER_NAME:-${SLURM_CLUSTER_NAME:-unknown}}" # Setup API key -setup_nvidia_api_key || exit 1 +setup_llm_api_key || exit 1 # Configuration - SLURM Monitor Service (can be overridden via env vars) export NVRX_SMONSVC_PORT="${NVRX_SMONSVC_PORT:-8100}" diff --git a/services/scripts/run_services.sh b/services/scripts/run_services.sh index aaaa8379..0217ac4b 100755 --- a/services/scripts/run_services.sh +++ b/services/scripts/run_services.sh @@ -15,7 +15,7 @@ # # Required environment variables: # NVRX_ATTRSVC_ALLOWED_ROOT - Root path for log files to analyze -# NVIDIA_API_KEY - API key for LLM (or NVIDIA_API_KEY_FILE) +# LLM_API_KEY - API key for LLM (or LLM_API_KEY_FILE) # # Optional environment variables: # NVRX_LOGS_DIR - Output directory for logs (default: ~/nvrx_logs) @@ -28,7 +28,7 @@ # # Example: # export NVRX_ATTRSVC_ALLOWED_ROOT=/lustre/logs -# export NVIDIA_API_KEY=nvapi-... +# export LLM_API_KEY=your-llm-api-key-here # ./run_services.sh install # ./run_services.sh start # ./run_services.sh status @@ -104,29 +104,8 @@ cmd_start() { validate_attrsvc_allowed_root || exit 1 ensure_directory "${NVRX_LOGS_DIR}" "logs directory" || exit 1 - # Load API key from file if not already set - if [[ -z "$NVIDIA_API_KEY" ]]; then - # Check NVIDIA_API_KEY_FILE first - if [[ -n "$NVIDIA_API_KEY_FILE" && -f "$NVIDIA_API_KEY_FILE" ]]; then - export NVIDIA_API_KEY=$(cat "$NVIDIA_API_KEY_FILE") - echo "Loaded API key from: $NVIDIA_API_KEY_FILE" - # Check common locations - elif [[ -f "${HOME}/.nvidia_api_key" ]]; then - export NVIDIA_API_KEY=$(cat "${HOME}/.nvidia_api_key") - echo "Loaded API key from: ~/.nvidia_api_key" - elif [[ -f "${HOME}/.config/nvrx/nvidia_api_key" ]]; then - export NVIDIA_API_KEY=$(cat "${HOME}/.config/nvrx/nvidia_api_key") - echo "Loaded API key from: ~/.config/nvrx/nvidia_api_key" - else - echo -e "${RED}Error: NVIDIA_API_KEY not set and no key file found${NC}" - echo "Either:" - echo " export NVIDIA_API_KEY=nvapi-xxx" - echo " export NVIDIA_API_KEY_FILE=/path/to/keyfile" - echo " Or create ~/.nvidia_api_key" - exit 1 - fi - fi - + setup_llm_api_key || exit 1 + # Check if already running if is_running "${ATTRSVC_PID_FILE}"; then echo -e "${YELLOW}Attribution service already running (PID: $(cat ${ATTRSVC_PID_FILE}))${NC}" @@ -283,27 +262,8 @@ cmd_run() { validate_attrsvc_allowed_root || exit 1 ensure_directory "${NVRX_LOGS_DIR}" "logs directory" || exit 1 - # Load API key from file if not already set - if [[ -z "$NVIDIA_API_KEY" ]]; then - if [[ -n "$NVIDIA_API_KEY_FILE" && -f "$NVIDIA_API_KEY_FILE" ]]; then - export NVIDIA_API_KEY=$(cat "$NVIDIA_API_KEY_FILE") - echo "Loaded API key from: $NVIDIA_API_KEY_FILE" - elif [[ -f "${HOME}/.nvidia_api_key" ]]; then - export NVIDIA_API_KEY=$(cat "${HOME}/.nvidia_api_key") - echo "Loaded API key from: ~/.nvidia_api_key" - elif [[ -f "${HOME}/.config/nvrx/nvidia_api_key" ]]; then - export NVIDIA_API_KEY=$(cat "${HOME}/.config/nvrx/nvidia_api_key") - echo "Loaded API key from: ~/.config/nvrx/nvidia_api_key" - else - echo -e "${RED}Error: NVIDIA_API_KEY not set and no key file found${NC}" - echo "Either:" - echo " export NVIDIA_API_KEY=nvapi-xxx" - echo " export NVIDIA_API_KEY_FILE=/path/to/keyfile" - echo " Or create ~/.nvidia_api_key" - exit 1 - fi - fi - + setup_llm_api_key || exit 1 + install_nvrx_packages "both" "${LIB_ROOT}" echo "" diff --git a/services/scripts/setup_systemd.sh b/services/scripts/setup_systemd.sh index e6a6e438..a502dc4c 100755 --- a/services/scripts/setup_systemd.sh +++ b/services/scripts/setup_systemd.sh @@ -299,7 +299,7 @@ cmd_install() { # Shared by nvrx-attrsvc and nvrx-smonsvc # ─── Common ─── -NVIDIA_API_KEY_FILE=${CONFIG_DIR}/nvidia_api_key +LLM_API_KEY_FILE=${CONFIG_DIR}/llm_api_key NVRX_LOGS_DIR=${LOG_DIR} # ─── Attribution Service (nvrx-attrsvc) ─── @@ -392,9 +392,9 @@ EOF echo "" if [[ "$USER_MODE" == true ]]; then - echo "1. Create NVIDIA API key file:" - echo " echo 'nvapi-xxx' > ${CONFIG_DIR}/nvidia_api_key" - echo " chmod 600 ${CONFIG_DIR}/nvidia_api_key" + echo "1. Create LLM API key file:" + echo " echo 'your-llm-api-key-here' > ${CONFIG_DIR}/llm_api_key" + echo " chmod 600 ${CONFIG_DIR}/llm_api_key" echo "" echo "2. Edit configuration (set NVRX_ATTRSVC_ALLOWED_ROOT):" echo " vim ${CONFIG_DIR}/nvrx.env" @@ -409,10 +409,10 @@ EOF echo "5. View logs:" echo " $0 --user logs" else - echo "1. Create NVIDIA API key file:" - echo " echo 'nvapi-xxx' | sudo tee ${CONFIG_DIR}/nvidia_api_key" - echo " sudo chmod 640 ${CONFIG_DIR}/nvidia_api_key" - echo " sudo chown root:${NVRX_GROUP} ${CONFIG_DIR}/nvidia_api_key" + echo "1. Create LLM API key file:" + echo " echo 'your-llm-api-key-here' | sudo tee ${CONFIG_DIR}/llm_api_key" + echo " sudo chmod 640 ${CONFIG_DIR}/llm_api_key" + echo " sudo chown root:${NVRX_GROUP} ${CONFIG_DIR}/llm_api_key" echo "" echo "2. Edit configuration (set NVRX_ATTRSVC_ALLOWED_ROOT):" echo " sudo vim ${CONFIG_DIR}/nvrx.env" @@ -431,7 +431,7 @@ EOF echo "" echo "Configuration:" echo " ${CONFIG_DIR}/nvrx.env - All settings for both services" - echo " ${CONFIG_DIR}/nvidia_api_key - API key (create this)" + echo " ${CONFIG_DIR}/llm_api_key - LLM API key (create this)" echo "" echo "Installed:" echo " ${INSTALL_DIR}/venv/bin/nvrx-attrsvc" diff --git a/src/nvidia_resiliency_ext/attribution/api_keys.py b/src/nvidia_resiliency_ext/attribution/api_keys.py index 3849b900..febac0c6 100644 --- a/src/nvidia_resiliency_ext/attribution/api_keys.py +++ b/src/nvidia_resiliency_ext/attribution/api_keys.py @@ -3,8 +3,8 @@ """Load API tokens from environment and well-known file paths. -**NVIDIA API key** — Required for LogSage and LLM merge paths; callers that embed attribution -should fail startup or analysis if :func:`load_nvidia_api_key` returns empty. +**LLM API key** — Required for LogSage and LLM merge paths; callers that embed attribution +should fail startup or analysis if :func:`load_llm_api_key` returns empty. **Slack bot token** — Optional; empty means notifications are disabled (postprocessing no-ops). """ @@ -23,31 +23,33 @@ def _read_key_file(path: str) -> str: return "" -def load_nvidia_api_key() -> str: - """Load NVIDIA API key from environment or file. +def load_llm_api_key() -> str: + """Load LLM API key from environment or file. Required for LLM-based attribution. Checks in order: - 1. ``NVIDIA_API_KEY`` environment variable - 2. ``NVIDIA_API_KEY_FILE`` environment variable (path to key file) - 3. ``~/.nvidia_api_key`` - 4. ``~/.config/nvrx/nvidia_api_key`` + 1. ``LLM_API_KEY`` environment variable + 2. ``LLM_API_KEY_FILE`` environment variable (path to key file) + 3. ``~/.llm_api_key`` + 4. ``~/.config/nvrx/llm_api_key`` Returns: API key string, or empty string if not found or unreadable. """ - api_key = os.getenv("NVIDIA_API_KEY") + api_key = os.getenv("LLM_API_KEY") if api_key: return api_key.strip() - key_file = os.getenv("NVIDIA_API_KEY_FILE") + key_file = os.getenv("LLM_API_KEY_FILE") if key_file and os.path.isfile(key_file): - return _read_key_file(key_file) + v = _read_key_file(key_file) + if v: + return v home = os.path.expanduser("~") for path in ( - os.path.join(home, ".nvidia_api_key"), - os.path.join(home, ".config", "nvrx", "nvidia_api_key"), + os.path.join(home, ".llm_api_key"), + os.path.join(home, ".config", "nvrx", "llm_api_key"), ): if os.path.isfile(path): v = _read_key_file(path) diff --git a/src/nvidia_resiliency_ext/attribution/combined_log_fr/combined_log_fr.py b/src/nvidia_resiliency_ext/attribution/combined_log_fr/combined_log_fr.py index 92cce705..c8230fef 100644 --- a/src/nvidia_resiliency_ext/attribution/combined_log_fr/combined_log_fr.py +++ b/src/nvidia_resiliency_ext/attribution/combined_log_fr/combined_log_fr.py @@ -5,7 +5,7 @@ import logging from typing import Any, Mapping, Union -from nvidia_resiliency_ext.attribution.api_keys import load_nvidia_api_key +from nvidia_resiliency_ext.attribution.api_keys import load_llm_api_key from nvidia_resiliency_ext.attribution.base import ( AttributionState, NVRxAttribution, @@ -37,7 +37,7 @@ def __init__( ) self._init_config = ad # Resolve once per CombinedLogFR instance so merge_log_fr_llm does not re-read env/files each call. - self._nvidia_api_key = load_nvidia_api_key() + self._llm_api_key = load_llm_api_key() async def preprocess_input(self) -> dict: cfg = merged_attribution_config(self._init_config) @@ -59,7 +59,7 @@ async def collective_analysis(self, output: dict) -> str: return await merge_log_fr_llm( log_result, fr_result, - nvidia_api_key=self._nvidia_api_key, + llm_api_key=self._llm_api_key, model=cfg.get("model", DEFAULT_LLM_MODEL), base_url=cfg.get("base_url", DEFAULT_LLM_BASE_URL), temperature=float(cfg.get("temperature", 0.2)), diff --git a/src/nvidia_resiliency_ext/attribution/combined_log_fr/llm_merge.py b/src/nvidia_resiliency_ext/attribution/combined_log_fr/llm_merge.py index a738c74b..bb28a501 100644 --- a/src/nvidia_resiliency_ext/attribution/combined_log_fr/llm_merge.py +++ b/src/nvidia_resiliency_ext/attribution/combined_log_fr/llm_merge.py @@ -91,7 +91,7 @@ async def merge_log_fr_llm( log_result: Any, fr_result: Any, *, - nvidia_api_key: str, + llm_api_key: str, model: str, base_url: str, temperature: float = 0.2, @@ -100,19 +100,19 @@ async def merge_log_fr_llm( ) -> str: """Run the Nemotron-style fusion prompt; ``fr_result`` may be :class:`FRAnalysisResult` or raw text. - Callers should pass a key obtained once (e.g. from :func:`~nvidia_resiliency_ext.attribution.api_keys.load_nvidia_api_key` + Callers should pass a key obtained once (e.g. from :func:`~nvidia_resiliency_ext.attribution.api_keys.load_llm_api_key` at startup or pipeline entry) so the merge step does not re-read env/files on every call. """ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI - if not (nvidia_api_key and nvidia_api_key.strip()): + if not (llm_api_key and llm_api_key.strip()): raise ValueError( - "NVIDIA API key is empty. Load it once via load_nvidia_api_key() and pass nvidia_api_key=... " + "LLM API key is empty. Load it once via load_llm_api_key() and pass llm_api_key=... " "Required for log+FR LLM merge." ) - api_key = nvidia_api_key.strip() + api_key = llm_api_key.strip() log_payload, _ = unpack_run_result(log_result) fr_payload, _ = unpack_run_result(fr_result) diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/__init__.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/__init__.py index 1ed25819..34811da3 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/__init__.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/__init__.py @@ -85,7 +85,7 @@ extract_job_metadata, ) from .parser_base import BaseParser, ParseResult -from .runner import ensure_analyzer_ready, run_log_analysis_sync +from .runner import ensure_analyzer_ready, notify_log_path_sync, run_log_analysis_sync from .slurm_parser import ( SlurmOutputInfo, SlurmParser, @@ -189,6 +189,7 @@ "build_dataflow_record", # Sync lib/MCP runner (e.g. FT path) "ensure_analyzer_ready", + "notify_log_path_sync", "run_log_analysis_sync", # Attribution decision helper "attribution_no_restart", diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/analysis_pipeline.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/analysis_pipeline.py index bcca32f6..458bd0b7 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/analysis_pipeline.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/analysis_pipeline.py @@ -34,7 +34,7 @@ from dataclasses import dataclass from typing import Any, Awaitable, Callable, Dict, Optional, Tuple -from nvidia_resiliency_ext.attribution.api_keys import load_nvidia_api_key +from nvidia_resiliency_ext.attribution.api_keys import load_llm_api_key from nvidia_resiliency_ext.attribution.trace_analyzer.fr_support import ( FRAnalysisResult, analyze_fr_dump, @@ -94,7 +94,7 @@ async def run_attribution_pipeline( llm_temperature: float = 0.2, llm_top_p: float = 0.7, llm_max_tokens: int = 16384, - nvidia_api_key: Optional[str] = None, + llm_api_key: Optional[str] = None, ) -> CombinedAnalysisResult: """Run attribution according to ``mode``. @@ -113,8 +113,8 @@ async def run_attribution_pipeline( llm_model: Model id for **LOG_AND_TRACE_WITH_LLM** (required when that mode runs the merge). llm_base_url: Base url for **LOG_AND_TRACE_WITH_LLM** (required when that mode runs the merge). llm_temperature / llm_top_p / llm_max_tokens: Passed to the merge LLM when applicable. - nvidia_api_key: NVIDIA API key for **LOG_AND_TRACE_WITH_LLM** host merge when MCP did not - merge. If ``None``, resolved once per pipeline run via :func:`load_nvidia_api_key`. + llm_api_key: API key for **LOG_AND_TRACE_WITH_LLM** host merge when MCP did not + merge. If ``None``, resolved once per pipeline run via :func:`load_llm_api_key`. """ discover = discover_fr_dump_path or extract_fr_dump_path run_fr = run_fr_analysis or analyze_fr_dump @@ -176,11 +176,11 @@ async def run_attribution_pipeline( raise ValueError("llm_base_url is required for LOG_AND_TRACE_WITH_LLM when merging") from nvidia_resiliency_ext.attribution.combined_log_fr.llm_merge import merge_log_fr_llm - merge_key = nvidia_api_key if nvidia_api_key is not None else load_nvidia_api_key() + merge_key = llm_api_key if llm_api_key is not None else load_llm_api_key() llm_merged_summary = await merge_log_fr_llm( log_result, fr_analysis, - nvidia_api_key=merge_key, + llm_api_key=merge_key, model=llm_model, base_url=llm_base_url, temperature=llm_temperature, diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/nvrx_logsage.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/nvrx_logsage.py index d4235654..3a2a318f 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/nvrx_logsage.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/nvrx_logsage.py @@ -113,14 +113,14 @@ def chunk_logs_strict(lines): class NVRxLogAnalyzer(NVRxAttribution): def __init__(self, args: Union[argparse.Namespace, Mapping[str, Any]]): - from nvidia_resiliency_ext.attribution.api_keys import load_nvidia_api_key + from nvidia_resiliency_ext.attribution.api_keys import load_llm_api_key self._init_config = normalize_attribution_args(args) - self.api_key = load_nvidia_api_key() + self.api_key = load_llm_api_key() if not self.api_key: raise ValueError( - "NVIDIA_API_KEY not found. Set NVIDIA_API_KEY env var, " - "NVIDIA_API_KEY_FILE env var, or create ~/.nvidia_api_key" + "LLM API key not found. Set LLM_API_KEY or LLM_API_KEY_FILE, " + "or create ~/.llm_api_key or ~/.config/nvrx/llm_api_key" ) logger.debug("API key loaded (length=%d)", len(self.api_key)) logger.debug( diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py index 7b13c98a..dcd4c913 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py @@ -9,10 +9,20 @@ results are cached per file path—same mapping as the HTTP service. We run a dedicated thread with an event loop and submit work to it from sync code. +**MCP backend** (``use_lib_log_analysis=False``): after :meth:`~nvidia_resiliency_ext.attribution.analyzer.engine.Analyzer.connect_mcp`, +:class:`~nvidia_resiliency_ext.attribution.log_analyzer.log_analyzer.LogAnalyzer` runs +:class:`~nvidia_resiliency_ext.attribution.log_analyzer.analysis_pipeline.AnalysisPipelineMode.LOG_AND_TRACE` +with a :class:`~nvidia_resiliency_ext.attribution.trace_analyzer.trace_analyzer.TraceAnalyzer`. +When an FR dump path is discovered from the log, analysis uses a **single** MCP round-trip to the +``log_fr_analyzer`` module (LogSage + FR in the MCP server), not separate ``log_analyzer`` + ``fr_analyzer`` calls. + The HTTP service (nvrx_attrsvc) does not use this module; it builds :class:`~nvidia_resiliency_ext.attribution.analyzer.engine.Analyzer` with ``ALLOWED_ROOT`` from config, and ``analyze()`` / ``submit()`` validate paths under that root. + +:func:`notify_log_path_sync` runs ``submit()`` only (job registration), for early +notification parity with HTTP ``POST /logs`` before full analysis. """ import asyncio @@ -100,18 +110,26 @@ def _get_or_create_analyzer( use_lib = use_lib_log_analysis if use_lib_log_analysis is not None else True try: from ..analyzer import Analyzer + from ..trace_analyzer.trace_analyzer import TraceAnalyzer # Permissive root: only enforces absolute paths under /. Callers must # restrict paths themselves if needed (e.g. FT host validates before calling). + # Pass TraceAnalyzer explicitly so LOG_AND_TRACE always registers FR discovery; in MCP + # mode the pipeline uses the ``log_fr_analyzer`` tool when a dump path exists. _lib_analyzer = Analyzer( allowed_root="/", use_lib_log_analysis=use_lib, compute_timeout=timeout_seconds, + trace_analyzer=TraceAnalyzer(allowed_root="/"), ) _lib_analyzer.set_event_loop(_lib_loop) if not use_lib: future = asyncio.run_coroutine_threadsafe(_lib_analyzer.connect_mcp(), _lib_loop) - future.result(timeout=30) + future.result(timeout=120) + logger.info( + "log analysis MCP: connected; LOG_AND_TRACE will use MCP module " + "'log_fr_analyzer' when an FR dump path is found in the log" + ) except Exception as e: _lib_analyzer = None logger.warning("log analysis lib: failed to create analyzer: %s", e) @@ -141,6 +159,9 @@ def run_log_analysis_sync( wl_restart: Optional[int] = None, user: str = "", job_id: str = "", + *, + timeout_seconds: Optional[float] = None, + use_lib_log_analysis: Optional[bool] = None, ) -> Optional[Dict[str, Any]]: """Run log analysis synchronously with a timeout. @@ -151,18 +172,25 @@ def run_log_analysis_sync( as the HTTP service). Repeat calls for the same path return the cached result; wl_restart selects the cycle when one file has multiple cycles. + Call :func:`ensure_analyzer_ready` first with the same ``use_lib_log_analysis`` as used here + (``False`` for MCP, ``True`` for in-process LogSage). If omitted, ``use_lib_log_analysis`` + defaults to in-process when creating the analyzer on first use. + Args: log_path: Path to the cycle log file to analyze. With this module's ``allowed_root="/"``, enforce a stricter allowed prefix at the call site if required. wl_restart: Workload restart index within file (None = first or all). When a file contains multiple cycles, use this to select which cycle's result. + timeout_seconds: Coalescer compute timeout when creating the analyzer (if not already created). + use_lib_log_analysis: ``False`` for MCP backend; ``True`` for in-process LogSage. Returns: Result dict from the analyzer on success, or None on timeout/error/skip. """ from .types import LogAnalyzerError - if not _get_or_create_analyzer(): + ts = timeout_seconds if timeout_seconds is not None else 60.0 + if not _get_or_create_analyzer(ts, use_lib_log_analysis): return None validated = _lib_analyzer.validate_path(log_path, require_regular_file=True, reject_empty=False) @@ -190,3 +218,48 @@ async def _run() -> Any: logger.debug("log analysis lib: analysis error for %s: %s", log_path, raw.message) return None return _raw_to_result_dict(raw) + + +_NOTIFY_SUBMIT_TIMEOUT_S = 30.0 + + +def notify_log_path_sync( + log_path: str, + user: str = "", + job_id: str = "", + *, + timeout_seconds: Optional[float] = None, + use_lib_log_analysis: Optional[bool] = None, +) -> None: + """Register ``log_path`` for job tracking via :meth:`~nvidia_resiliency_ext.attribution.analyzer.engine.Analyzer.submit` only. + + Parallels HTTP ``POST /logs`` used by :class:`~nvidia_resiliency_ext.fault_tolerance.ft_attribution.AttributionServiceClient` + before full analysis: creates/updates the job record without running LLM analysis. Intended for a + short fire-and-forget call (e.g. daemon thread) before workers start; failures are logged only. + """ + from .types import LogAnalyzerError + + ts = timeout_seconds if timeout_seconds is not None else 60.0 + if not _get_or_create_analyzer(ts, use_lib_log_analysis): + logger.debug("notify_log_path_sync: analyzer not ready; skip %s", log_path) + return + + validated = _lib_analyzer.validate_path(log_path, require_regular_file=True, reject_empty=False) + if isinstance(validated, LogAnalyzerError): + logger.debug("notify_log_path_sync: skip (path validation): %s", validated.message) + return + + async def _submit_only() -> None: + await _lib_analyzer.submit(validated, user=user, job_id=job_id or None) + + try: + future = asyncio.run_coroutine_threadsafe(_submit_only(), _lib_loop) + future.result(timeout=_NOTIFY_SUBMIT_TIMEOUT_S) + except FuturesTimeoutError: + logger.warning( + "notify_log_path_sync: submit timed out after %.0fs: %s", + _NOTIFY_SUBMIT_TIMEOUT_S, + log_path, + ) + except Exception as e: + logger.warning("notify_log_path_sync: failed for %s: %s: %s", log_path, type(e).__name__, e) diff --git a/src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py b/src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py index 664db223..9fcef874 100644 --- a/src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py +++ b/src/nvidia_resiliency_ext/attribution/trace_analyzer/fr_attribution.py @@ -298,7 +298,7 @@ async def collective_analysis(self, analysis_output: str) -> Optional[str]: None: Results are printed to standard output Note: - Requires the NVIDIA_API_KEY environment variable to be set + Requires an LLM API key (LLM_API_KEY / LLM_API_KEY_FILE, or default ~/.llm_api_key paths). """ result = analysis_output cfg = effective_run_or_init_config(self._init_config) @@ -337,11 +337,14 @@ async def collective_analysis(self, analysis_output: str) -> Optional[str]: prompt = PromptTemplate(template=template, input_variables=["analysis_output"]) # Check for API key - from nvidia_resiliency_ext.attribution.api_keys import load_nvidia_api_key + from nvidia_resiliency_ext.attribution.api_keys import load_llm_api_key - api_key = load_nvidia_api_key() + api_key = load_llm_api_key() if not api_key: - eprint("NVIDIA_API_KEY not found. Set env var or create ~/.nvidia_api_key") + eprint( + "LLM API key not found. Set LLM_API_KEY / LLM_API_KEY_FILE " + "or default key files (~/.llm_api_key, ~/.config/nvrx/llm_api_key)" + ) return default_values = { @@ -364,7 +367,10 @@ async def collective_analysis(self, analysis_output: str) -> Optional[str]: eprint("pip install langchain langchain-nvidia-ai-endpoints") except Exception as e: eprint(f"\nError using LangChain: {e}") - eprint("Set NVIDIA_API_KEY env var or create ~/.nvidia_api_key") + eprint( + "Set LLM_API_KEY / LLM_API_KEY_FILE or default key files " + "(~/.llm_api_key, ~/.config/nvrx/llm_api_key)" + ) return result """ diff --git a/src/nvidia_resiliency_ext/fault_tolerance/__init__.py b/src/nvidia_resiliency_ext/fault_tolerance/__init__.py index c3a9cc28..3e43e6a6 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/__init__.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nvidia_resiliency_ext.fault_tolerance.ft_attribution import AttributionRunConfig # noqa: F401 + from .config import FaultToleranceConfig # noqa: F401 from .data import WorkloadAction # noqa: F401 from .data import WorkloadControlRequest # noqa: F401 diff --git a/src/nvidia_resiliency_ext/fault_tolerance/config.py b/src/nvidia_resiliency_ext/fault_tolerance/config.py index f6ecec01..1e5b3b15 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/config.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/config.py @@ -19,11 +19,79 @@ import logging import signal from dataclasses import dataclass, fields -from typing import Mapping, Optional +from typing import List, Mapping, Optional import yaml +def _slack_bot_token_repr(token: Optional[str]) -> str: + """Safe token form for :meth:`SlackConfig.__repr__` (no raw secret).""" + if not token or not str(token).strip(): + return "None" + t = str(token).strip() + if len(t) <= 4: + return "'***'" + return repr(f"…{t[-4:]}") + + +def _read_token_from_file(path: str) -> Optional[str]: + """Read token from file path. Returns stripped content or None on error.""" + if not path or not path.strip(): + return None + try: + with open(path.strip(), "r") as f: + return f.read().strip() or None + except OSError: + return None + + +@dataclass(frozen=True) +class SlackConfig: + """Slack notification config. Reusable by attribution and other FT modules.""" + + bot_token: Optional[str] = None + channel: Optional[str] = None + + def __repr__(self) -> str: + return ( + f"SlackConfig(bot_token={_slack_bot_token_repr(self.bot_token)}, " + f"channel={self.channel!r})" + ) + + def to_dict(self, *, include_secrets: bool = False) -> dict: + """Serialize Slack settings. + + Args: + include_secrets: If True, include ``bot_token`` in plaintext (required for rendezvous + wire format and round-trip with :meth:`from_dict`). If False, omit the raw token + (``bot_token`` is None; ``bot_token_present`` indicates a token was configured) — + safe for logging or diagnostics. + + Warning: + Never log the return value when ``include_secrets`` is True. + """ + ch = self.channel + if include_secrets: + return {"bot_token": self.bot_token, "channel": ch} + out: dict = {"bot_token": None, "channel": ch} + if self.bot_token: + out["bot_token_present"] = True + return out + + @classmethod + def from_dict(cls, d: Optional[dict]) -> Optional["SlackConfig"]: + if not d: + return None + tok = d.get("bot_token") + token_file = d.get("bot_token_file") + if token_file: + tok = _read_token_from_file(token_file) or tok + ch = d.get("channel") + if tok is None and ch is None: + return None + return cls(bot_token=tok, channel=ch) + + @dataclass class FaultToleranceConfig: """ @@ -95,9 +163,18 @@ class FaultToleranceConfig: out-of-section timeouts. The first N iterations (relative to cycle start) are excluded from timeout monitoring as they can be significantly slower than steady-state iterations. Default: 5. Can be overridden by workload (e.g., Megatron-LM via init_workload_monitoring). - * Attribution service (optional): - - `attrsvc_host` [str] hostname/IP of the attribution service - - `attrsvc_port` [int] port of the attribution service + * Attribution (optional): `attribution_backends` [list[str]|None] — backends to query in order + (each ``"mcp"`` or an HTTP URL). Use this for multiple backends (e.g. MCP + third-party URL). + None = disabled. ``"mcp"`` = LogSage + FR via MCP; URL = HTTP attribution service. + `attribution_timeout_seconds` [int] = wait/timeout in seconds (default 60). + `attribution_dry_run` [bool] = if True, run attribution chain but do not apply the action + (log what would happen; useful for validation). Default: False. + * Slack (shared by attribution and other FT modules): `slack` [SlackConfig|None]. + Token via `bot_token_file` (CLI/yaml) or SLACK_BOT_TOKEN/SLACK_BOT_TOKEN_FILE env. + * `dataflow_index` [str|None] = Elasticsearch/dataflow index for attribution posting (mcp/URL). None = disabled. + * `llm_api_key_file` [str|None] = Path to a file containing the LLM API key for MCP attribution. + When set, the FT attribution client sets ``LLM_API_KEY_FILE`` before initializing the analyzer. + When None, only ``LLM_API_KEY`` / ``LLM_API_KEY_FILE`` and default key files apply (see ``api_keys.load_llm_api_key``). * `cycle_info_dir` [str|None] Full path to the NVRx cycle info directory (e.g. /nvrx/). If set, the TCPStore host writes cycle info JSON files and the @@ -144,9 +221,14 @@ class FaultToleranceConfig: num_warmup_iterations: int = ( 5 # Number of warmup iterations before monitoring step section and out-of-section timeouts ) - # Attribution service configuration (optional) - attrsvc_host: Optional[str] = None - attrsvc_port: Optional[int] = None + # Attribution: None = off; use attribution_backends (list of mcp and/or URLs) + attribution_backends: Optional[List[str]] = None + attribution_timeout_seconds: int = 60 + attribution_dry_run: bool = False # Run attribution chain but don't apply action; log only + # Slack (shared by attribution and other FT modules) + slack: Optional["SlackConfig"] = None + dataflow_index: Optional[str] = None + llm_api_key_file: Optional[str] = None # NVRx cycle info: base directory for cycle_info JSON files cycle_info_dir: Optional[str] = None @@ -171,6 +253,25 @@ def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultTolerance Raises: ValueError: If there are unrecognized arguments and ignore_not_recognized is False. """ + # Preprocess slack: build from nested slack: {...} or flat slack_bot_token/slack_channel/slack_bot_token_file + kwargs = dict(kwargs) + if "slack" not in kwargs and ( + "slack_bot_token" in kwargs + or "slack_bot_token_file" in kwargs + or "slack_channel" in kwargs + ): + tok = kwargs.pop("slack_bot_token", None) + token_file = kwargs.pop("slack_bot_token_file", None) + if token_file: + tok = _read_token_from_file(token_file) or tok + kwargs["slack"] = SlackConfig( + bot_token=tok, + channel=kwargs.pop("slack_channel", None), + ) + slack_val = kwargs.get("slack") + if isinstance(slack_val, dict): + kwargs["slack"] = SlackConfig.from_dict(slack_val) + fields_set = {f.name for f in fields(FaultToleranceConfig) if f.init} matching_args = {k: v for k, v in kwargs.items() if k in fields_set} extra_args = {k: v for k, v in kwargs.items() if k not in fields_set} @@ -289,6 +390,10 @@ def from_args(args: argparse.Namespace): 'gpu_memory_poll_interval', ] for field in fields(FaultToleranceConfig): + if field.name == "slack": + continue # Handled below from ft_slack_bot_token / ft_slack_channel + if field.name == "attribution_backends": + continue # Merged below so YAML + CLI combine correctly cli_field_name = f"ft_{field.name}" val = getattr(args, cli_field_name, None) if val is not None: @@ -298,10 +403,26 @@ def from_args(args: argparse.Namespace): val = FaultToleranceConfig._parse_timeout_arg(val) cli_ft_args[field.name] = val + # Slack from --ft-slack-token-file / --ft-slack-channel (token from file only; env fallback in ft_attribution) + slack_token_file = getattr(args, "ft_slack_bot_token_file", None) + slack_tok = _read_token_from_file(slack_token_file) if slack_token_file else None + slack_ch = getattr(args, "ft_slack_channel", None) + if slack_tok is not None or slack_ch is not None: + cli_ft_args["slack"] = SlackConfig(bot_token=slack_tok, channel=slack_ch) + # Update config with CLI args for arg_name, arg_val in cli_ft_args.items(): setattr(ft_cfg, arg_name, arg_val) + # Merge attribution: YAML/file base + --ft-attribution-backend (repeatable append) + from nvidia_resiliency_ext.fault_tolerance.ft_attribution import dedupe_attribution_backends + + ab_cli = getattr(args, "ft_attribution_backends", None) + if ab_cli is not None: + base = list(ft_cfg.attribution_backends or []) + base.extend(ab_cli) + ft_cfg.attribution_backends = dedupe_attribution_backends(base) if base else None + # Fix any type issues ft_cfg._fix_log_level_type() ft_cfg._fix_rank_termination_signal_type() diff --git a/src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py b/src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py new file mode 100644 index 00000000..49b3414c --- /dev/null +++ b/src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fault-tolerance integration with attribution. + +Integrates :mod:`nvidia_resiliency_ext.attribution.log_analyzer` with ``ft_launcher``. +Supports **multiple backends** per run: e.g. ``mcp`` (LogSage + FR via MCP) and one or more +HTTP URLs (third-party or attrsvc). Restart/stop uses **any** backend that reports +``STOP`` / no-restart (:func:`attribution_no_restart`). +""" + +import logging +import os +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple +from urllib.parse import quote_plus + +import httpx + +from nvidia_resiliency_ext.attribution.log_analyzer.llm_output import attribution_no_restart +from nvidia_resiliency_ext.fault_tolerance.config import SlackConfig +from nvidia_resiliency_ext.fault_tolerance.utils import job_id_from_env, job_user_from_env +from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig + +logger = logging.getLogger(LogConfig.name) + +__all__ = [ + "AttributionRunConfig", + "LogAnalysisClient", + "AttributionServiceClient", + "SlackConfig", + "attribution_no_restart", + "dedupe_attribution_backends", + "validate_backend_entry", +] + + +def _validate_attribution_url(url: str) -> str: + """Validate attribution URL and return normalized form (with scheme if missing).""" + if not url or not url.strip(): + raise ValueError("attribution backend URL must be non-empty") + s = url.strip() + if "://" in s: + if not s.startswith(("http://", "https://")): + raise ValueError(f"attribution backend: expected http(s) URL, got: {url!r}") + return s + if ":" in s: + return f"http://{s}" + raise ValueError( + f"attribution backend: expected host:port or http(s)://host:port, got: {url!r}" + ) + + +def validate_backend_entry(entry: str) -> str: + """Normalize one backend: ``mcp`` or a validated HTTP URL string.""" + v = entry.strip() + if not v: + raise ValueError("empty attribution backend entry") + low = v.lower() + if low == "lib": + raise ValueError("attribution from ft_launcher requires 'mcp' or an HTTP URL") + if low == "mcp": + return "mcp" + return _validate_attribution_url(v) + + +def dedupe_attribution_backends(backends: List[str]) -> List[str]: + """Deduplicate backend list: one ``mcp``, unique URLs (order preserved).""" + out: List[str] = [] + seen_mcp = False + seen_urls: set = set() + for raw in backends: + b = validate_backend_entry(raw) + if b == "mcp": + if not seen_mcp: + seen_mcp = True + out.append("mcp") + else: + logger.debug("duplicate mcp backend entry ignored") + else: + if b not in seen_urls: + seen_urls.add(b) + out.append(b) + return out + + +@dataclass(frozen=True) +class AttributionRunConfig: + """FT attribution: one or more backends (``mcp``, HTTP URL(s)). + + Each backend is either the literal ``mcp`` (LogSage + optional FR via MCP) or an HTTP + base URL for a log/attribution service. :class:`LogAnalysisClient` queries backends in + order; :meth:`~LogAnalysisClient.should_stop` is True if **any** backend says do not restart. + + ``llm_api_key_file`` (optional): if set, :class:`LogAnalysisClient` sets ``LLM_API_KEY_FILE`` before MCP init. + """ + + backends: Tuple[str, ...] + timeout_seconds: int = 60 + slack: Optional[SlackConfig] = None + dataflow_index: Optional[str] = None + llm_api_key_file: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + d: Dict[str, Any] = { + "backends": list(self.backends), + "timeout_seconds": self.timeout_seconds, + } + if self.slack is not None: + d["slack"] = self.slack.to_dict(include_secrets=True) + if self.dataflow_index is not None: + d["dataflow_index"] = self.dataflow_index + if self.llm_api_key_file is not None: + d["llm_api_key_file"] = self.llm_api_key_file + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "AttributionRunConfig": + """Load from serialized dict. Requires a non-empty ``backends`` list (same shape as :meth:`to_dict`).""" + raw = d.get("backends") if d else None + if not raw: + raise ValueError( + "attribution_config requires non-empty 'backends' (e.g. ['mcp'] or HTTP URL strings)" + ) + backends = tuple(validate_backend_entry(x) for x in raw) + return cls( + backends=backends, + timeout_seconds=int(d.get("timeout_seconds", 60)), + slack=SlackConfig.from_dict(d.get("slack")), + dataflow_index=d.get("dataflow_index"), + llm_api_key_file=d.get("llm_api_key_file"), + ) + + @classmethod + def from_backend_strings( + cls, + entries: List[str], + timeout_seconds: int = 60, + slack: Optional[SlackConfig] = None, + dataflow_index: Optional[str] = None, + llm_api_key_file: Optional[str] = None, + ) -> "AttributionRunConfig": + """Build from CLI/YAML string list (``mcp`` and/or URLs).""" + norm = dedupe_attribution_backends(entries) + if not norm: + raise ValueError("at least one attribution backend is required") + return cls( + backends=tuple(norm), + timeout_seconds=timeout_seconds, + slack=slack, + dataflow_index=dataflow_index, + llm_api_key_file=llm_api_key_file, + ) + + +# --- HTTP client (URL backend) --- +class AttributionServiceClient: + """ + HTTP client for an attribution service (URL backend). + Talks to HTTP APIs that expose log analysis results (e.g. nvrx_attrsvc). + """ + + def __init__(self, base_url: str, timeout_seconds: float = 60.0): + self._base_url = base_url.rstrip("/") + self._timeout = max(1.0, float(timeout_seconds)) + + def path_notify(self, log_path: str) -> None: + """Notify path before workers start (fire-and-forget POST).""" + threading.Thread( + target=self._do_submit_log, + args=(log_path,), + daemon=True, + ).start() + + def _do_submit_log(self, log_path: str) -> None: + try: + with httpx.Client(timeout=10.0) as client: + url = f"{self._base_url}/logs" + logger.debug("AttributionServiceClient POST: %s (log_path=%s)", url, log_path) + client.post( + url, + json={"log_path": log_path}, + headers={"accept": "application/json"}, + ) + except Exception as e: + logger.warning( + "AttributionServiceClient POST %s failed: %s: %s", log_path, type(e).__name__, e + ) + + def get_result_sync(self, log_path: str) -> Optional[Dict[str, Any]]: + """Get analysis results via GET (blocking). Uses client timeout.""" + if not log_path: + return None + try: + with httpx.Client(timeout=self._timeout) as client: + q_path = quote_plus(log_path) + url = f"{self._base_url}/logs?log_path={q_path}" + logger.debug("AttributionServiceClient GET: %s (log_path=%s)", url, log_path) + resp = client.get(url, headers={"accept": "application/json"}) + if resp.status_code == 200: + payload = resp.json() if resp.text else {} + result = payload.get("result", payload) + if isinstance(result, dict): + return result + return {"result": result} if result is not None else None + logger.warning( + "AttributionServiceClient GET for %s returned %d", log_path, resp.status_code + ) + return None + except Exception as e: + logger.warning( + "AttributionServiceClient GET %s failed: %s: %s", log_path, type(e).__name__, e + ) + return None + + +class LogAnalysisClient: + """Run cycle attribution across one or more backends (MCP + HTTP URL(s)).""" + + def __init__(self, config: AttributionRunConfig) -> None: + self._config = config + self._timeout = max(1, config.timeout_seconds) + self._user = job_user_from_env() + self._job_id = job_id_from_env() + self._fetchers: List[Callable[[str], Optional[Dict[str, Any]]]] = [] + self._path_notify_fns: List[Callable[[str], None]] = [] + self._init_backends() + + def _init_backends(self) -> None: + from nvidia_resiliency_ext.attribution.log_analyzer.runner import ( + ensure_analyzer_ready, + notify_log_path_sync, + run_log_analysis_sync, + ) + from nvidia_resiliency_ext.attribution.postprocessing import config as pp_config + from nvidia_resiliency_ext.attribution.postprocessing import configure_from_env + + slack_cfg = self._config.slack + if slack_cfg is None: + slack_token_arg: Optional[str] = None + slack_channel_arg: Optional[str] = None + else: + slack_token_arg = slack_cfg.bot_token + if slack_token_arg is not None: + slack_token_arg = slack_token_arg.strip() + slack_channel_arg = slack_cfg.channel + if slack_channel_arg is not None: + slack_channel_arg = slack_channel_arg.strip() + df_idx = (self._config.dataflow_index or "").strip() + cluster = (os.getenv("SLURM_CLUSTER_NAME") or "").strip() + configure_from_env( + slack_token=slack_token_arg, + slack_channel=slack_channel_arg, + dataflow_index=df_idx, + cluster_name=cluster, + ) + llm_key_file = (self._config.llm_api_key_file or "").strip() + if llm_key_file: + os.environ["LLM_API_KEY_FILE"] = llm_key_file + logger.debug("FT attribution: set LLM_API_KEY_FILE from fault_tolerance config") + if pp_config.slack_bot_token: + logger.info( + "Slack notifications enabled for FT attribution (channel=%s)", + pp_config.slack_channel or "(none)", + ) + if df_idx: + logger.info( + "FT attribution: dataflow posting enabled (index=%s, cluster=%s)", + df_idx, + cluster or "(unset)", + ) + + mcp_initialized = False + for b in self._config.backends: + if b == "mcp": + if mcp_initialized: + continue + if not ensure_analyzer_ready( + timeout_seconds=float(self._timeout), use_lib_log_analysis=False + ): + logger.warning("FT attribution: MCP analyzer not ready; skipping mcp backend") + continue + mcp_initialized = True + logger.info( + "FT attribution: MCP backend — nvrx-mcp-analysis (log + FR when discoverable)" + ) + user = self._user + job_id = self._job_id + + def _fetch_mcp(log_path: str, u=user, j=job_id) -> Optional[Dict[str, Any]]: + return run_log_analysis_sync( + log_path, + user=u, + job_id=j, + timeout_seconds=float(self._timeout), + use_lib_log_analysis=False, + ) + + self._fetchers.append(_fetch_mcp) + + def _path_notify_mcp( + log_path: str, u=user, j=job_id, to=float(self._timeout) + ) -> None: + def _run() -> None: + try: + notify_log_path_sync( + log_path, + user=u, + job_id=j, + timeout_seconds=to, + use_lib_log_analysis=False, + ) + except Exception as e: + logger.warning( + "FT attribution: MCP path_notify failed: %s: %s", + type(e).__name__, + e, + ) + + threading.Thread( + target=_run, + daemon=True, + name="ft-attr-mcp-path-notify", + ).start() + + self._path_notify_fns.append(_path_notify_mcp) + else: + svc = AttributionServiceClient(base_url=b, timeout_seconds=float(self._timeout)) + + def _fetch_url( + log_path: str, client: AttributionServiceClient = svc + ) -> Optional[Dict[str, Any]]: + return client.get_result_sync(log_path) + + self._fetchers.append(_fetch_url) + self._path_notify_fns.append(svc.path_notify) + + if not self._fetchers: + logger.warning("FT attribution: no usable backends; attribution disabled") + + def fetch_result(self, log_path: str) -> Optional[Dict[str, Any]]: + """Return the first non-None result across backends (for debugging); prefer :meth:`should_stop`.""" + for fetch in self._fetchers: + r = fetch(log_path) + if r is not None: + return r + return None + + def should_stop(self, log_path: str) -> bool: + """True if **any** backend recommends do not restart. + + Backends are queried **in parallel** so wall-clock time is roughly the slowest + fetch (each fetch already applies its own timeout), not the sum of all backends. + """ + fetchers = self._fetchers + if not fetchers: + return False + + def _safe_fetch( + fetch: Callable[[str], Optional[Dict[str, Any]]] + ) -> Optional[Dict[str, Any]]: + try: + return fetch(log_path) + except Exception as e: + logger.debug( + "FT attribution: should_stop backend fetch failed: %s: %s", + type(e).__name__, + e, + ) + return None + + if len(fetchers) == 1: + return attribution_no_restart(_safe_fetch(fetchers[0])) + + ex = ThreadPoolExecutor( + max_workers=len(fetchers), + thread_name_prefix="ft-attr-should-stop", + ) + try: + futures = [ex.submit(_safe_fetch, f) for f in fetchers] + for fut in as_completed(futures): + try: + r = fut.result() + except Exception: + r = None + if attribution_no_restart(r): + return True + return False + finally: + ex.shutdown(wait=False, cancel_futures=True) + + @property + def path_notify(self) -> Optional[Callable[[str], None]]: + """Chain backends' early path notify (fire-and-forget): MCP submit-only, then HTTP POST /logs.""" + if not self._path_notify_fns: + return None + + def _notify(log_path: str) -> None: + for fn in self._path_notify_fns: + fn(log_path) + + return _notify diff --git a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py index d4d1c39c..63a7aefb 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py @@ -51,11 +51,14 @@ RendezvousInfo = None RendezvousStoreInfo = None +from nvidia_resiliency_ext.fault_tolerance.ft_attribution import ( + AttributionRunConfig, + LogAnalysisClient, +) from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig from ..inprocess.utils import format_rank_set_verbose from ..shared_utils.health_check import ( - AttributionService, DistributedStorageHealthCheck, GPUHealthCheck, NicLinkStateHealthCheck, @@ -709,6 +712,20 @@ def _increment_peer_aborted_count(self) -> int: new_count = self.store.add(self.peer_aborted_count_key, 1) return new_count + def _undo_peer_abort_notify(self) -> None: + """Best-effort undo of one :meth:`_increment_peer_aborted_count` (e.g. attribution veto). + + Used when the launcher incremented to wake peers, then decided not to restart so + healthy nodes should not stay in a raised ``peer_aborted_count`` state. + """ + try: + cur = self._get_peer_aborted_count() + if cur <= 0: + return + self.store.add(self.peer_aborted_count_key, -1) + except Exception as e: + log.warning("peer_aborted_count rollback failed: %s", e) + def _get_peer_aborted_count(self) -> int: """Get the current peer aborted count. @@ -1499,8 +1516,7 @@ def from_backend( enable_dist_storage_healthcheck: bool = False, link_state_path_template: Optional[str] = None, storage_healthcheck_paths: Optional[list] = None, - attrsvc_host: Optional[str] = None, - attrsvc_port: Optional[int] = None, + attribution_config: Optional[AttributionRunConfig] = None, ): """Create a new :py:class:`FtRendezvousBarrierHandler`. @@ -1531,10 +1547,8 @@ def from_backend( Template path for NIC link state files. storage_healthcheck_paths: List of storage paths to check for health. - attrsvc_host: - Hostname or IP address of the attribution service. - attrsvc_port: - Port number of the attribution service. + attribution_config: + Multi-backend attribution config (:class:`~nvidia_resiliency_ext.fault_tolerance.ft_attribution.AttributionRunConfig`). """ # We associate each handler instance with a unique node descriptor. node = cls._node_desc_generator.generate(local_addr) @@ -1560,8 +1574,7 @@ def from_backend( enable_dist_storage_healthcheck=enable_dist_storage_healthcheck, link_state_path_template=link_state_path_template, storage_healthcheck_paths=storage_healthcheck_paths, - attrsvc_host=attrsvc_host, - attrsvc_port=attrsvc_port, + attribution_config=attribution_config, ) def __init__( @@ -1575,8 +1588,7 @@ def __init__( enable_dist_storage_healthcheck: bool = False, link_state_path_template: Optional[str] = None, storage_healthcheck_paths: Optional[list] = None, - attrsvc_host: Optional[str] = None, - attrsvc_port: Optional[int] = None, + attribution_config: Optional[AttributionRunConfig] = None, ) -> None: if not settings.run_id: raise ValueError("The run id must be a non-empty string.") @@ -1637,14 +1649,15 @@ def __init__( StoragePathHealthCheck(storage_healthcheck_paths) if storage_healthcheck_paths else None ) - # Attribution service client (optional, only on master node) - if is_store_host and attrsvc_host and attrsvc_port is not None: - self._attr_service = AttributionService( - host=attrsvc_host, - port=int(attrsvc_port), - ) - else: - self._attr_service = None + # Attribution: log analysis client (optional, only when config enabled) + self._log_analysis_client = None + if is_store_host and attribution_config is not None: + self._log_analysis_client = LogAnalysisClient(attribution_config) + + @property + def log_analysis_client(self) -> Optional[LogAnalysisClient]: + """Log analysis client for attribution, or None if not configured.""" + return self._log_analysis_client @property def _rendezvous_round(self) -> int: @@ -1795,11 +1808,6 @@ def ensure_node_is_healthy(self) -> None: f"Node {self._this_node} has invalid or unreadable paths.", ) - # Perform optional log analysis (non-fatal) - # Note: _submit_log() was already called from launcher before workers started - if self._attr_service is not None: - self._attr_service() - # Perform Node health check (external service if available) _nodehealth_checker = get_node_health_check() if _nodehealth_checker is not None: @@ -2153,8 +2161,10 @@ def create_handler( ) storage_healthcheck_paths = params.config.get('storage_healthcheck_paths', None) link_state_path_template = params.config.get('link_state_path_template', None) - attrsvc_host = params.config.get('attrsvc_host', None) - attrsvc_port = params.config.get('attrsvc_port', None) + attribution_cfg_dict = params.config.get('attribution_config', None) + attribution_config = None + if attribution_cfg_dict: + attribution_config = AttributionRunConfig.from_dict(attribution_cfg_dict) return FtRendezvousBarrierHandler.from_backend( params.run_id, @@ -2171,8 +2181,7 @@ def create_handler( enable_dist_storage_healthcheck=enable_dist_storage_healthcheck, link_state_path_template=link_state_path_template, storage_healthcheck_paths=storage_healthcheck_paths, - attrsvc_host=attrsvc_host, - attrsvc_port=attrsvc_port, + attribution_config=attribution_config, ) except Exception as e: construct_and_record_rdzv_event( diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index 24926cfe..677ca17b 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -77,6 +77,7 @@ FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR, UpdateConfigMsg, ) +from nvidia_resiliency_ext.fault_tolerance.ft_attribution import AttributionRunConfig from nvidia_resiliency_ext.fault_tolerance.per_cycle_logs import PipeBasedLogsSpecs from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer @@ -85,6 +86,7 @@ get_processes_by_pgids, hostnames_to_slurm_nodelist, is_slurm_job_array, + job_id_from_env, patched_method, terminate_mp_processes, write_obj_to_ipc_stream, @@ -510,7 +512,7 @@ def _on_cycle_end(self) -> None: if self._cycle_info_writer is None: return current_cycle = self._get_global_restart_count() - job_id = os.environ.get("SLURM_ARRAY_JOB_ID") or os.environ.get("SLURM_JOB_ID", "") + job_id = job_id_from_env() attempt_index = int(os.environ.get("SLURM_RESTART_CNT", "0")) self._cycle_info_writer.update_cycle_end( job_id=job_id, @@ -523,7 +525,7 @@ def _write_cycle_start_info(self, current_cycle: int) -> Optional[str]: """Write NVRx cycle info at cycle start. Returns path to current cycle info file, or None.""" if self._cycle_info_writer is None: return None - job_id = os.environ.get("SLURM_ARRAY_JOB_ID") or os.environ.get("SLURM_JOB_ID", "") + job_id = job_id_from_env() attempt_index = int(os.environ.get("SLURM_RESTART_CNT", "0")) cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) # Legacy FtRendezvousHandler does not define these; barrier handler does. @@ -594,6 +596,11 @@ def _open_rendezvous_for_restart(self): logger.error(f"Failed to open rendezvous: {e}") # For legacy rendezvous, no action needed - it uses different mechanism + def _restart_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: + """Override to pass will_restart and time_consumed_before_reclaim to _stop_workers.""" + self._stop_workers(worker_group, *args, will_restart=True, **kwargs) + self._start_workers(worker_group) + def _handle_restart_decision( self, role: str, @@ -602,19 +609,52 @@ def _handle_restart_decision( open_rendezvous: bool = False, notify_peer: bool = False, ) -> bool: - """Handle restart decision logic based on progress tracking and remaining restarts. + """Decide whether to restart based on attribution, progress tracking, and remaining restarts. + + If restart: calls _restart_workers and returns True. + If stop: returns False; caller must call _stop_workers. Args: role: The role name for logging spec: Worker specification log_msg: Custom log message for restart - open_rendezvous: Whether to open rendezvous before restart (for barrier-based rendezvous) - notify_peer: Whether to notify peers to abort the workers in current cycle. + open_rendezvous: If True, open rendezvous for restart only after attribution and progress + checks pass (barrier handler only). Opening before then can let peers join a round this + node might never complete if attribution vetoes restart. + notify_peer: If True, increment ``peer_aborted_count`` immediately so healthy peers can + observe failure without waiting for local attribution. If attribution, progress, or + remaining-restart checks then veto restart, the increment is rolled back. Returns: - True if restart was initiated (caller should continue monitoring loop) - False if no restart (caller should stop workers and return failure) + True if restart was initiated, False if no restart (caller should call _stop_workers). """ + peer_abort_incremented = False + if notify_peer and hasattr(self._rdzv_handler, '_barrier_state'): + self._rdzv_handler._barrier_state._increment_peer_aborted_count() + peer_abort_incremented = True + + def _rollback_peer_abort_notify() -> None: + nonlocal peer_abort_incremented + if not peer_abort_incremented: + return + if hasattr(self._rdzv_handler, '_barrier_state'): + self._rdzv_handler._barrier_state._undo_peer_abort_notify() + peer_abort_incremented = False + + start = time.time() + should_terminate_early = self._run_attribution() + if should_terminate_early: + if self._ft_cfg.attribution_dry_run: + logger.info( + "[%s] Attribution dry run: would NOT restart (attribution says stop), " + "but proceeding as configured (action not applied).", + role, + ) + else: + logger.error("[%s] Attribution says do not restart; will not restart.", role) + _rollback_peer_abort_notify() + return False + self._progress_tracker.analyze_previous_cycle() should_terminate_early = self._progress_tracker.should_terminate_early() @@ -624,19 +664,22 @@ def _handle_restart_decision( "No more restarts will be attempted.", role ) + _rollback_peer_abort_notify() return False elif self._remaining_restarts > 0: logger.info(log_msg, role) self._remaining_restarts -= 1 - # Increment peer_aborted_count to notify other nodes (for barrier-based rendezvous) - if notify_peer and hasattr(self._rdzv_handler, '_barrier_state'): - self._rdzv_handler._barrier_state._increment_peer_aborted_count() if open_rendezvous: self._open_rendezvous_for_restart() - self._restart_workers(self._worker_group) + time_consumed = time.time() - start + self._restart_workers( + self._worker_group, + time_consumed_before_reclaim=time_consumed, + ) return True else: # No more restarts available + _rollback_peer_abort_notify() return False def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: @@ -689,13 +732,10 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes ) should_restart = self._handle_restart_decision( role, spec, log_msg, open_rendezvous=True, - notify_peer=True + notify_peer=True, ) - if should_restart: - continue # Continue monitoring after restart - - # No more restarts (either exhausted or early termination) + continue self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED return RunResult(state=WorkerState.FAILED) @@ -722,17 +762,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes f"(nodes_waiting={num_nodes_waiting}, peer_aborted={peer_aborted_count}); " f"will restart worker group" ) - # Note: The node that triggered the change (unhealthy or new) already opened - # the rendezvous, so we don't need to open it again here. + # Note: The node that triggered the change already opened the rendezvous. should_restart = self._handle_restart_decision( role, spec, log_msg, open_rendezvous=False, - notify_peer=False + notify_peer=False, ) - - if not should_restart: - self._stop_workers(self._worker_group) - self._worker_group.state = WorkerState.FAILED - return RunResult(state=WorkerState.FAILED) + if should_restart: + continue + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + return RunResult(state=WorkerState.FAILED) else: raise Exception(f"[{role}] Worker group in {state.name} state") @@ -950,15 +989,41 @@ def _log_watchdog_event( event = events.Event(name=name, source=events.EventSource.AGENT, metadata=metadata) events.record(event) + @property + def _log_analysis_client(self): + """Log analysis client from rdzv handler, or None if not configured.""" + return getattr(self._rdzv_handler, "log_analysis_client", None) + + def _run_attribution(self) -> bool: + """Run attribution if configured. Returns True if attribution says do not restart, else False.""" + if not self._is_store_host or self._log_analysis_client is None: + return False + cycle_log_file = None + if hasattr(self._logs_specs, "get_cycle_log_file"): + current_cycle = self._get_global_restart_count() + cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) + if cycle_log_file is None: + return False + return self._log_analysis_client.should_stop(cycle_log_file) + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof - def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: + def _stop_workers( + self, worker_group: WorkerGroup, *args, **kwargs + ) -> Optional[Any]: # Support both old and new SimpleElasticAgent._stop_workers signatures: # - Before 2.5.1: _stop_workers(self, worker_group: WorkerGroup) -> None # - 2.5.1: _stop_workers(self, worker_group: WorkerGroup, is_restarter: bool = False) -> None # - 2.7.1+: _stop_workers(self, worker_group: WorkerGroup) -> None (reverted back) # We use *args and **kwargs to handle both cases transparently + # + # Optional: will_restart [bool] - if True, wait for GPU reclaim before next cycle. + # Optional: time_consumed_before_reclaim [float] - deducted from reclaim budget when will_restart. + will_restart: bool = kwargs.pop("will_restart", False) + time_consumed_before_reclaim: float = kwargs.pop( + "time_consumed_before_reclaim", 0.0 + ) logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.") # Rank monitors will detect worker shutdown when worker processes disconnect @@ -983,16 +1048,22 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: else: logger.debug("All worker processes and descendants terminated successfully") - # Wait for GPU memory to be reclaimed BEFORE returning control - # This ensures the node doesn't proceed to the next rendezvous cycle while memory is still tied up - if self._ft_cfg.gpu_memory_reclaim_timeout > 0: - logger.debug( - "Waiting for GPU memory to be reclaimed (timeout: %ds, tolerance: %d MB, poll interval: %ds)...", - int(self._ft_cfg.gpu_memory_reclaim_timeout), - int(self._ft_cfg.gpu_memory_tolerance_mb), - int(self._ft_cfg.gpu_memory_poll_interval), - ) - self._wait_for_gpu_memory_reclaim(worker_group.spec.local_world_size) + # Wait for GPU memory to be reclaimed only when restarting (shutdown case skips). + reclaim_timeout = self._ft_cfg.gpu_memory_reclaim_timeout + if will_restart and reclaim_timeout > 0: + remaining_reclaim = max(0.0, reclaim_timeout - time_consumed_before_reclaim) + if remaining_reclaim > 0: + logger.debug( + "Waiting for GPU memory to be reclaimed (timeout: %.1fs, " + "tolerance: %d MB, poll interval: %ds)...", + remaining_reclaim, + int(self._ft_cfg.gpu_memory_tolerance_mb), + int(self._ft_cfg.gpu_memory_poll_interval), + ) + self._wait_for_gpu_memory_reclaim( + worker_group.spec.local_world_size, + timeout_override=remaining_reclaim, + ) # Wait for reader thread to drain pipes (polls every 100ms, wait 3 cycles) # then close pipe file objects to prevent FD reuse bugs @@ -1023,6 +1094,7 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: node_id=self._rdzv_handler._this_node, rank=worker_group.group_rank, ) + return None # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @@ -1070,14 +1142,12 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: f"MASTER_ADDR={master_addr}, MASTER_PORT={master_port}" ) - # Submit current cycle's log to attribution service (master node only, before workers start) - if ( - self._is_store_host - and self._rdzv_handler._attr_service is not None - and hasattr(self._logs_specs, 'get_cycle_log_file') - ): - cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) - self._rdzv_handler._attr_service._submit_log(cycle_log_file) + # Early notify: HTTP backends (POST /logs) and MCP (Analyzer.submit only), before workers start + if self._is_store_host and hasattr(self._logs_specs, "get_cycle_log_file"): + client = self._log_analysis_client + if client is not None and client.path_notify is not None: + cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) + client.path_notify(cycle_log_file) # Write NVRx cycle info and set env for workload current_cycle_info_path = self._write_cycle_start_info(current_cycle) @@ -1187,7 +1257,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: return self._pcontext.pids() - def _wait_for_gpu_memory_reclaim(self, num_gpus: int) -> None: + def _wait_for_gpu_memory_reclaim( + self, num_gpus: int, timeout_override: Optional[float] = None + ) -> None: """ Wait for GPU memory to be reclaimed below the tolerance threshold before starting new workers. This is called on restarts (not on initial start) to ensure memory has been cleaned up. @@ -1196,6 +1268,7 @@ def _wait_for_gpu_memory_reclaim(self, num_gpus: int) -> None: Args: num_gpus: Number of GPUs on this node + timeout_override: If set, use this instead of gpu_memory_reclaim_timeout (for time accounting). """ def log_memory_stats(memory_stats, num_gpus, log_func, message_template, *args): """Helper to log GPU memory statistics.""" @@ -1219,7 +1292,13 @@ def log_memory_stats(memory_stats, num_gpus, log_func, message_template, *args): ) memory_logger = GPUMemoryLogger() - timeout = self._ft_cfg.gpu_memory_reclaim_timeout + timeout = ( + timeout_override + if timeout_override is not None + else self._ft_cfg.gpu_memory_reclaim_timeout + ) + if timeout <= 0: + return tolerance_mb = self._ft_cfg.gpu_memory_tolerance_mb poll_interval = self._ft_cfg.gpu_memory_poll_interval @@ -1762,9 +1841,7 @@ def launch_agent( # unhealthy_count) before the store goes away. shutdown_rdzv only controls explicit # permanent-close signaling; it cannot keep the store alive after process exit. if is_store_host: - # Trigger attribution service analysis for final cycle - if agent._rdzv_handler._attr_service is not None: - agent._rdzv_handler._attr_service() + # Attribution is invoked on the Restart & progress state path (inside _handle_restart_decision), not at exit. # No ordering required between cycle_info_writer and rendezvous: the writer # is independent I/O. Run grace-period wait and writer shutdown in parallel @@ -2827,22 +2904,67 @@ def get_args_parser() -> ArgumentParser: "format and log the traceback, and use os._exit() to exit the process reliably. Default: False.", ) - # Attribution service configuration (optional) + # Attribution: repeat --ft-attribution-backend for multiple backends parser.add_argument( - "--ft-attrsvc-host", - "--ft_attrsvc_host", - type=str, + "--ft-attribution-backend", + "--ft_attribution_backend", + action="append", default=None, - dest="ft_attrsvc_host", - help="Hostname or IP for the attribution service (e.g., 127.0.0.1).", + dest="ft_attribution_backends", + metavar="BACKEND", + help="Attribution backend (repeatable): mcp or HTTP URL (e.g. http://127.0.0.1:8000). " + "Combined with YAML attribution_backends. Stop/restart if any backend says do not restart.", ) parser.add_argument( - "--ft-attrsvc-port", - "--ft_attrsvc_port", + "--ft-attribution-timeout", + "--ft_attribution_timeout", type=int, + default=60, + dest="ft_attribution_timeout_seconds", + help="Attribution wait/timeout in seconds; skip result if exceeded (default: 60).", + ) + parser.add_argument( + "--ft-attribution-dry-run", + "--ft_attribution_dry_run", + action="store_true", + default=None, + dest="ft_attribution_dry_run", + help="Attribution dry run: run full attribution chain (log analysis, Slack, dataflow) " + "but do not apply the restart/stop decision. Log what would happen instead. " + "Useful for validating the chain without affecting behavior.", + ) + parser.add_argument( + "--ft-llm-api-key-file", + "--ft_llm_api_key_file", + type=str, default=None, - dest="ft_attrsvc_port", - help="Port for the attribution service (e.g., 8000).", + dest="ft_llm_api_key_file", + help="Path to file containing LLM API key for MCP attribution. Sets LLM_API_KEY_FILE " + "before the analyzer starts. Combined with YAML llm_api_key_file; CLI wins when both set.", + ) + parser.add_argument( + "--ft-slack-channel", + "--ft_slack_channel", + type=str, + default=None, + dest="ft_slack_channel", + help="Slack channel for FT alerts (attribution, etc.).", + ) + parser.add_argument( + "--ft-slack-token-file", + "--ft_slack_token_file", + type=str, + default=None, + dest="ft_slack_bot_token_file", + help="Path to file containing Slack bot token. Else uses SLACK_BOT_TOKEN/SLACK_BOT_TOKEN_FILE env.", + ) + parser.add_argument( + "--ft-dataflow-index", + "--ft_dataflow_index", + type=str, + default=None, + dest="ft_dataflow_index", + help="Dataflow/Elasticsearch index for attribution posting (mcp/url). Requires nvdataflow.", ) parser.add_argument( @@ -3075,6 +3197,7 @@ def _validate_slurm_single_launcher_per_node() -> None: " NVRX_ENABLE_MULTI_LAUNCHERS_PER_NODE=1" ) + def _validate_args(args: Any) -> None: """Centralized validation of CLI args (cross-flag consistency). Raises ValueError if invalid.""" n_log_agg = int(getattr(args, "ft_log_aggregator_count", 2)) @@ -3097,6 +3220,7 @@ def _validate_args(args: Any) -> None: "--ft-nvrx-logfile cannot be used when NVRX_NODE_LOCAL_TMPDIR is set." ) + def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` _validate_args(args) @@ -3169,18 +3293,27 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - # Pass segment-related configs to rendezvous config rdzv_configs['segment'] = fault_tol_cfg.segment - # Pass NIC health check configs to rendezvous config - rdzv_configs['enable_nic_healthcheck'] = fault_tol_cfg.enable_nic_healthcheck - rdzv_configs['link_state_path_template'] = fault_tol_cfg.link_state_path_template - # Pass enable_nic_healthcheck and link_state_path_template from fault tolerance config to rendezvous config rdzv_configs['enable_nic_healthcheck'] = fault_tol_cfg.enable_nic_healthcheck rdzv_configs['link_state_path_template'] = fault_tol_cfg.link_state_path_template - # Pass attribution service configuration if provided - if getattr(fault_tol_cfg, 'attrsvc_host', None): - rdzv_configs['attrsvc_host'] = fault_tol_cfg.attrsvc_host - if getattr(fault_tol_cfg, 'attrsvc_port', None) is not None: - rdzv_configs['attrsvc_port'] = int(fault_tol_cfg.attrsvc_port) + + # Attribution: merged backends (YAML + --ft-attribution-backend) + attribution_backends = getattr(fault_tol_cfg, "attribution_backends", None) or [] + attribution_timeout = int(getattr(fault_tol_cfg, "attribution_timeout_seconds", 60)) + ft_slack = getattr(fault_tol_cfg, "slack", None) + ft_dataflow_index = getattr(fault_tol_cfg, "dataflow_index", None) + if attribution_backends: + timeout_sec = max(1, attribution_timeout) + llm_key_file = getattr(fault_tol_cfg, "llm_api_key_file", None) + attribution_cfg = AttributionRunConfig.from_backend_strings( + attribution_backends, + timeout_seconds=timeout_sec, + slack=ft_slack, + dataflow_index=ft_dataflow_index, + llm_api_key_file=llm_key_file, + ) + rdzv_configs["attribution_config"] = attribution_cfg.to_dict() + # Pass distributed storage health check configuration cli_dist_storage = getattr(args, 'ft_enable_dist_storage_healthcheck', None) if cli_dist_storage is not None: @@ -3631,6 +3764,7 @@ def _wait_grpc_subprocess_after_terminate(p: subprocess.Popen, wait_timeout: flo p.wait() + def run(args): # Configure logger based on whether launcher logs should be redirected base_log_file = getattr(args, 'ft_per_cycle_applog_prefix', None) diff --git a/src/nvidia_resiliency_ext/fault_tolerance/utils.py b/src/nvidia_resiliency_ext/fault_tolerance/utils.py index 585770c6..cc217f56 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/utils.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/utils.py @@ -234,6 +234,16 @@ def is_slurm_job_array() -> bool: return os.getenv('SLURM_ARRAY_TASK_ID') is not None +def job_user_from_env() -> str: + """Read job user from SLURM_JOB_USER or USER env.""" + return os.environ.get("SLURM_JOB_USER") or os.environ.get("USER", "") or "" + + +def job_id_from_env() -> str: + """Read job id from SLURM_ARRAY_JOB_ID or SLURM_JOB_ID env.""" + return os.environ.get("SLURM_ARRAY_JOB_ID") or os.environ.get("SLURM_JOB_ID", "") or "" + + def get_log_aggregator_shard_index(num_aggregators: int) -> int: """Shard index in ``[0, num_aggregators)`` for first-level log aggregator selection. @@ -538,7 +548,9 @@ def exception_handler(exc_type, exc_value, exc_traceback): f"{rank_str}: Process will exit with code 1\n" ) - app_logger = logging.getLogger(__name__) + # Get logger from application context (not module-level logger) + # Use root logger to ensure we capture in application's logging context + app_logger = logging.getLogger() app_logger.error(error_msg) # Also print to stderr to ensure visibility diff --git a/tests/attribution/unit/test_api_keys.py b/tests/attribution/unit/test_api_keys.py index 32fdb59f..1604dd99 100644 --- a/tests/attribution/unit/test_api_keys.py +++ b/tests/attribution/unit/test_api_keys.py @@ -5,6 +5,7 @@ import asyncio import sys +import tempfile import unittest from unittest.mock import patch @@ -13,18 +14,18 @@ "Attribution package requires Python 3.10+ (e.g. dataclass(slots=True) in log_analyzer.job)." ) -from nvidia_resiliency_ext.attribution.api_keys import load_nvidia_api_key +from nvidia_resiliency_ext.attribution.api_keys import load_llm_api_key from nvidia_resiliency_ext.attribution.combined_log_fr.llm_merge import merge_log_fr_llm -class TestLoadNvidiaApiKey(unittest.TestCase): - def test_reads_and_strips_from_env(self): - with patch.dict("os.environ", {"NVIDIA_API_KEY": " sk-test "}): - self.assertEqual(load_nvidia_api_key(), "sk-test") +class TestLoadLlmApiKey(unittest.TestCase): + def test_reads_and_strips_llm_env(self): + with patch.dict("os.environ", {"LLM_API_KEY": " sk-test "}): + self.assertEqual(load_llm_api_key(), "sk-test") def test_returns_empty_when_unset_and_no_key_files(self): def getenv_side_effect(key: str, default=None): - if key in ("NVIDIA_API_KEY", "NVIDIA_API_KEY_FILE"): + if key in ("LLM_API_KEY", "LLM_API_KEY_FILE"): return None return default @@ -35,15 +36,27 @@ def getenv_side_effect(key: str, default=None): ), patch("nvidia_resiliency_ext.attribution.api_keys.os.path.isfile", return_value=False), ): - self.assertEqual(load_nvidia_api_key(), "") + self.assertEqual(load_llm_api_key(), "") + + def test_llm_api_key_file_used_when_set(self): + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".key") as f: + f.write("key-from-file\n") + path = f.name + try: + with patch.dict("os.environ", {"LLM_API_KEY_FILE": path}, clear=True): + self.assertEqual(load_llm_api_key(), "key-from-file") + finally: + import os as _os + + _os.unlink(path) class TestMergeLogFrLlm(unittest.TestCase): def test_raises_when_api_key_empty(self): async def run(): with self.assertRaises(ValueError) as ctx: - await merge_log_fr_llm("log", "fr", nvidia_api_key="", model="dummy-model") - self.assertIn("NVIDIA API key", str(ctx.exception)) + await merge_log_fr_llm("log", "fr", llm_api_key="", model="dummy-model") + self.assertIn("LLM API key", str(ctx.exception)) asyncio.run(run()) From 62d2d934e12bc2ce505766a2fdab9c6ee862fdb0 Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Fri, 17 Apr 2026 13:30:58 -0700 Subject: [PATCH 2/6] timing perf --- .../attribution/ARCHITECTURE.md | 45 ++++++++ .../fault_tolerance/ft_rendezvous_barrier.py | 78 +++++++++++++ .../fault_tolerance/launcher.py | 105 +++++++++++++++++- 3 files changed, 226 insertions(+), 2 deletions(-) diff --git a/src/nvidia_resiliency_ext/attribution/ARCHITECTURE.md b/src/nvidia_resiliency_ext/attribution/ARCHITECTURE.md index f69e2960..0a8c967d 100644 --- a/src/nvidia_resiliency_ext/attribution/ARCHITECTURE.md +++ b/src/nvidia_resiliency_ext/attribution/ARCHITECTURE.md @@ -116,6 +116,51 @@ flowchart TB AP --> CLF ``` +### 2.4 Fault tolerance: rendezvous health checks vs monitor / `_handle_restart_decision` + +When analyzing **FT launcher** logs, it helps to keep two **separate control-flow trees** in mind. They share configuration (for example `rdzv_configs` / barrier flags) and can influence each other only indirectly (for example cluster counters visible to the launcher, or workers exiting after a rendezvous failure). They are **not** a single pipeline like “run all health checks, then call `_handle_restart_decision`.” + +**What they share:** both are anchored to the same **fault-tolerance cycle / round** model. A **workload cycle** ends when local worker processes exit or the job moves to the next coordinated round; **Tree A** runs (or re-runs) as nodes **enter or advance rendezvous** for that next step—so health checks often appear in logs **around cycle boundaries** when the stack is re-syncing. **Tree B** is a **continuous** monitor loop, but **`_handle_restart_decision`** is reached when that loop observes a **post-cycle** local outcome (**FAILED** / **UNHEALTHY** after workers stop) or **HEALTHY** workers reacting to **cluster signals** that also stem from other nodes finishing or aborting a cycle. Same “turn of the crank” in the job; different code paths implementing different responsibilities. + +- **Rendezvous / barrier tree** (`fault_tolerance/ft_rendezvous_barrier.py`): optional **NIC**, **distributed storage**, **path storage**, and **node health** checks run in the rendezvous / round-join path. Outcomes affect whether nodes participate in rounds, `unhealthy_count`, early termination, and related store keys. +- **Launcher monitor tree** (`fault_tolerance/launcher.py`): `time.sleep(monitor_interval)` → **`_monitor_workers`** (local worker / subprocess state via `PContext.wait`, not those rendezvous checks) → on **FAILED / UNHEALTHY** or on **HEALTHY** with **`num_nodes_waiting` / `peer_aborted_count`**, **`_handle_restart_decision`**. That path runs **`_run_attribution()`** (this package’s LogSage flow), progress checks, restart budget, then optionally **`_open_rendezvous_for_restart`** and **`_restart_workers`** (which may wait for **GPU memory reclaim** after a restart—not the same as NIC/storage checks above). + +```mermaid +flowchart TB + SHARED["Shared lifecycle anchor: FT cycle / round boundary\nlocal workers end or job advances to next coordinated round"] + + SHARED --> tree_rdzv + SHARED --> tree_launcher + + subgraph tree_rdzv [Tree A — rendezvous / infra checks] + CFG1["FT / rdzv config\nenable_nic_healthcheck, storage flags, …"] + BAR["Rendezvous + barrier\nft_rendezvous_barrier"] + NIC["NIC link check"] + DS["Distributed storage check"] + SP["Path storage checks"] + NH["Node health check\nget_node_health_check"] + CFG1 --> BAR + BAR --> NIC + BAR --> DS + BAR --> SP + BAR --> NH + BAR --> ROUNDS["Round join / store keys\nunhealthy_count, peer signals, …"] + end + + subgraph tree_launcher [Tree B — monitor loop and restart decision] + CFG2["WorkerSpec.monitor_interval,\nrdzv_handler for cluster reads"] + LOOP["sleep → _monitor_workers\nlocal worker state"] + HRD["_handle_restart_decision\n_run_attribution, progress, restarts"] + RESTART["_restart_workers / GPU reclaim wait\nwhen restarting"] + CFG2 --> LOOP + LOOP -->|FAILED or UNHEALTHY| HRD + LOOP -->|HEALTHY + nodes_waiting\nor peer_aborted| HRD + HRD --> RESTART + end + + ROUNDS -.->|cluster counters / rendezvous state| LOOP +``` + --- ## 3. Major subsystems diff --git a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py index 63a7aefb..e20a989a 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py @@ -31,6 +31,7 @@ from torch.distributed import PrefixStore, Store from torch.distributed.elastic.events import NodeState, construct_and_record_rdzv_event +from torch.distributed.elastic.metrics.api import put_metric from torch.distributed.elastic.multiprocessing import SignalException from torch.distributed.elastic.rendezvous.api import ( RendezvousClosedError, @@ -77,6 +78,17 @@ log = logging.getLogger(LogConfig.name) +# When NVRX_FT_RDZV_NEXT_RENDEZVOUS_TIMING=1, log and emit torchelastic metrics for +# next_rendezvous phases (health, control IPC, barrier wait-open vs after-open). Independent +# of NVRX_FT_RESTART_DECISION_TIMING (launcher._handle_restart_decision only). +_FT_RDZV_NEXT_RENDEZVOUS_TIMING_ENV = "NVRX_FT_RDZV_NEXT_RENDEZVOUS_TIMING" + + +def _ft_rdzv_next_rendezvous_timing_enabled() -> bool: + v = os.environ.get(_FT_RDZV_NEXT_RENDEZVOUS_TIMING_ENV, "").strip().lower() + return v in ("1", "true", "yes", "on") + + # Sentinel domain_id written to a participant's slot when they withdraw. Any use of # participant data (_can_meet_segment_constraint, _assign_group_ranks) must exclude # participants with this domain_id. @@ -466,6 +478,9 @@ def __init__( # Only populated if segment is configured self._cached_domain_id: Optional[str] = None + # Populated when NVRX_FT_RDZV_NEXT_RENDEZVOUS_TIMING=1 after a successful perform_rendezvous + self._last_perform_rdzv_phases: Optional[Dict[str, float]] = None + # Key prefixes for the barrier self.prefix = f"ft_rendezvous_barrier:{run_id}" self.arrived_count_key = f"{self.prefix}:arrived_count" @@ -973,10 +988,17 @@ def perform_rendezvous( Returns: Tuple of (group_rank, total_participants) """ + rdzv_timing = _ft_rdzv_next_rendezvous_timing_enabled() + self._last_perform_rdzv_phases = None + # Step 0: Wait if rendezvous is closed (training in progress) # Hot spares arriving late will wait here until a failure opens a new round # Note: This also checks for permanent close (is_permanently_closed()), no need to check again + t_wait_open = time.perf_counter() self._wait_for_rendezvous_open(node_desc) + wait_open_s = time.perf_counter() - t_wait_open + + t_after_open = time.perf_counter() # Record start time for timeout monitoring # Start timing AFTER Step 0 completes, since hot spares may wait indefinitely at Step 0 @@ -1231,6 +1253,11 @@ def perform_rendezvous( log.debug( f"[{node_desc}] [Step 4] Received group rank {rank}, total participants {total_participants}" ) + if rdzv_timing: + self._last_perform_rdzv_phases = { + "rdzv_wait_open_s": wait_open_s, + "rdzv_after_open_s": time.perf_counter() - t_after_open, + } return rank, total_participants # Delay before next check @@ -1888,6 +1915,33 @@ def _perform_rendezvous(self) -> None: else: log.info(f"Node {self._this_node} assigned group rank {group_rank}") + def _emit_ft_rdzv_next_rendezvous_timing( + self, + phases: Dict[str, float], + total_s: float, + ) -> None: + """Log and export next_rendezvous phase timings (NVRX_FT_RDZV_NEXT_RENDEZVOUS_TIMING=1).""" + if not _ft_rdzv_next_rendezvous_timing_enabled(): + return + phase_str = " ".join(f"{k}={v:.4f}" for k, v in sorted(phases.items())) + log.info( + "[ft_rdzv_timing] node=%s run_id=%s rendezvous_round=%s assigned_rank=%s " + "is_store_host=%s total_s=%.4f %s", + self._this_node, + self._settings.run_id, + getattr(self, "_rendezvous_round", None), + self._assigned_rank, + self._barrier_state.is_store_host, + total_s, + phase_str, + ) + try: + put_metric("ft.rdzv.next_rendezvous.total_s", total_s) + for name, sec in phases.items(): + put_metric(f"ft.rdzv.next_rendezvous.{name}", sec) + except Exception: + log.debug("ft_rdzv_timing put_metric failed", exc_info=True) + def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]: """See base class. @@ -1904,20 +1958,36 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]: log.info(msg) prev_signal_handlers = _install_rdzv_signal_handlers() + timing_on = _ft_rdzv_next_rendezvous_timing_enabled() + t0_perf = time.perf_counter() + rdzv_phases: Dict[str, float] = {} + next_rdzv_success = False try: # Check node health and control requests before starting rendezvous health_check_start = time.monotonic() + t_h = time.perf_counter() self.ensure_node_is_healthy() health_check_elapsed = time.monotonic() - health_check_start + if timing_on: + rdzv_phases["health_s"] = time.perf_counter() - t_h log.debug( f"[{self._this_node}] Node health check completed in {health_check_elapsed:.3f}s" ) + t_c = time.perf_counter() self.handle_control_requests_from_rank() + if timing_on: + rdzv_phases["control_ipc_s"] = time.perf_counter() - t_c # Perform the complete rendezvous process # Stale round detection and sync happens automatically in _wait_for_rendezvous_open() + t_p = time.perf_counter() self._perform_rendezvous() + if timing_on: + rdzv_phases["perform_rdzv_wall_s"] = time.perf_counter() - t_p + inner = getattr(self._barrier_state, "_last_perform_rdzv_phases", None) + if inner: + rdzv_phases.update(inner) # Use stored rank and world size rank = self._assigned_rank @@ -1937,6 +2007,8 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]: # Restore from the configured value in settings self._worker_group.spec.local_world_size = self._settings.nproc_per_node + next_rdzv_success = True + except Exception as e: self._record( message=f"{type(e).__name__}: {str(e)}", @@ -1950,6 +2022,12 @@ def next_rendezvous(self) -> Union[RendezvousInfo, Tuple[Store, int, int]]: self._barrier_state._maybe_withdraw_on_unwind() _restore_rdzv_signal_handlers(prev_signal_handlers) + if next_rdzv_success and timing_on: + self._emit_ft_rdzv_next_rendezvous_timing( + rdzv_phases, + time.perf_counter() - t0_perf, + ) + msg = ( f"The node '{self._this_node}' has joined the rendezvous " f"'{self._settings.run_id}' as rank {rank} in a world of size " diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index 677ca17b..266d61b5 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -132,6 +132,19 @@ # Note: Must call run() before using logger to ensure proper configuration logger = logging.getLogger(LogConfig.name) +# Set NVRX_FT_RESTART_DECISION_TIMING=1 to log and emit torchelastic metrics for phase +# durations inside _handle_restart_decision (attribution, progress, open_rendezvous, +# restart_workers). Used to quantify overlap opportunity before changing ordering. +# Does not touch next_rendezvous / barrier (see NVRX_FT_RDZV_NEXT_RENDEZVOUS_TIMING in +# ft_rendezvous_barrier.py). +_FT_RESTART_DECISION_TIMING_ENV = "NVRX_FT_RESTART_DECISION_TIMING" + + +def _ft_restart_decision_timing_enabled() -> bool: + v = os.environ.get(_FT_RESTART_DECISION_TIMING_ENV, "").strip().lower() + return v in ("1", "true", "yes", "on") + + _NODE_HEALTH_CHECK_INSTANCE: Optional[NodeHealthCheck] = None # Populated on TCP store host when gRPC log aggregation is enabled: root-only [Popen] or # [root, leaf0, ..., leaf_{N-1}] when ft_log_aggregator_count > 1. @@ -576,6 +589,42 @@ def run(self, role: str = DEFAULT_ROLE) -> RunResult: # record the execution time in case there were any exceptions during run. self._total_execution_time = int(time.monotonic() - start_time) + def _emit_ft_restart_decision_timing( + self, + role: str, + outcome: str, + phases: Dict[str, float], + t0_perf: float, + open_rendezvous_config: bool, + notify_peer_config: bool, + ) -> None: + """Log and export phase timings when NVRX_FT_RESTART_DECISION_TIMING is enabled.""" + if not _ft_restart_decision_timing_enabled(): + return + total_s = time.perf_counter() - t0_perf + gr: Any = None + if self._worker_group is not None: + gr = self._worker_group.group_rank + phase_str = " ".join(f"{k}={v:.4f}" for k, v in sorted(phases.items())) + logger.info( + "[ft_restart_timing] role=%s group_rank=%s is_store_host=%s outcome=%s " + "open_rendezvous=%s notify_peer=%s total_s=%.4f %s", + role, + gr, + getattr(self, "_is_store_host", False), + outcome, + open_rendezvous_config, + notify_peer_config, + total_s, + phase_str, + ) + try: + put_metric(f"ft.restart_decision.{role}.total_s", total_s) + for name, sec in phases.items(): + put_metric(f"ft.restart_decision.{role}.{name}", sec) + except Exception: + logger.debug("ft_restart_timing put_metric failed", exc_info=True) + def _open_rendezvous_for_restart(self): """Open rendezvous for restart when using barrier-based rendezvous. @@ -628,10 +677,17 @@ def _handle_restart_decision( Returns: True if restart was initiated, False if no restart (caller should call _stop_workers). """ + timing_on = _ft_restart_decision_timing_enabled() + t0_perf = time.perf_counter() + phases: Dict[str, float] = {} + peer_abort_incremented = False + t_pn = time.perf_counter() if notify_peer and hasattr(self._rdzv_handler, '_barrier_state'): self._rdzv_handler._barrier_state._increment_peer_aborted_count() peer_abort_incremented = True + if timing_on: + phases["peer_notify_s"] = time.perf_counter() - t_pn def _rollback_peer_abort_notify() -> None: nonlocal peer_abort_incremented @@ -641,8 +697,12 @@ def _rollback_peer_abort_notify() -> None: self._rdzv_handler._barrier_state._undo_peer_abort_notify() peer_abort_incremented = False - start = time.time() + start_wall = time.time() + t_attr = time.perf_counter() should_terminate_early = self._run_attribution() + if timing_on: + phases["attribution_s"] = time.perf_counter() - t_attr + if should_terminate_early: if self._ft_cfg.attribution_dry_run: logger.info( @@ -653,10 +713,21 @@ def _rollback_peer_abort_notify() -> None: else: logger.error("[%s] Attribution says do not restart; will not restart.", role) _rollback_peer_abort_notify() + self._emit_ft_restart_decision_timing( + role, + "veto_attribution", + phases, + t0_perf, + open_rendezvous, + notify_peer, + ) return False + t_prog = time.perf_counter() self._progress_tracker.analyze_previous_cycle() should_terminate_early = self._progress_tracker.should_terminate_early() + if timing_on: + phases["progress_s"] = time.perf_counter() - t_prog if should_terminate_early: logger.error( @@ -665,21 +736,51 @@ def _rollback_peer_abort_notify() -> None: role ) _rollback_peer_abort_notify() + self._emit_ft_restart_decision_timing( + role, + "veto_progress", + phases, + t0_perf, + open_rendezvous, + notify_peer, + ) return False elif self._remaining_restarts > 0: logger.info(log_msg, role) self._remaining_restarts -= 1 if open_rendezvous: + t_open = time.perf_counter() self._open_rendezvous_for_restart() - time_consumed = time.time() - start + if timing_on: + phases["open_rendezvous_s"] = time.perf_counter() - t_open + time_consumed = time.time() - start_wall + t_rw = time.perf_counter() self._restart_workers( self._worker_group, time_consumed_before_reclaim=time_consumed, ) + if timing_on: + phases["restart_workers_s"] = time.perf_counter() - t_rw + self._emit_ft_restart_decision_timing( + role, + "restart", + phases, + t0_perf, + open_rendezvous, + notify_peer, + ) return True else: # No more restarts available _rollback_peer_abort_notify() + self._emit_ft_restart_decision_timing( + role, + "veto_no_restarts", + phases, + t0_perf, + open_rendezvous, + notify_peer, + ) return False def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: From 93314e1cd52a27faca5c1f9720e1f74e9ab85d8e Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Sat, 18 Apr 2026 09:23:16 -0700 Subject: [PATCH 3/6] fixes n bugs --- .../train_ddp_heartbeats_api.py | 58 +++++++++++++++---- .../fault_tolerance/launcher.py | 45 ++++++++++++-- 2 files changed, 88 insertions(+), 15 deletions(-) diff --git a/examples/fault_tolerance/train_ddp_heartbeats_api.py b/examples/fault_tolerance/train_ddp_heartbeats_api.py index cbb85c74..3af2ee44 100644 --- a/examples/fault_tolerance/train_ddp_heartbeats_api.py +++ b/examples/fault_tolerance/train_ddp_heartbeats_api.py @@ -21,6 +21,9 @@ `ft_launcher --nproc-per-node=2 --ft-cfg-path=./examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml examples/fault_tolerance/train_ddp_heartbeats_api.py --device=cpu` +For a **fast** fault-injection smoke test (seconds instead of a full pass over the train set), shrink the dataset and use the quick path, for example: +`... train_ddp_heartbeats_api.py --train_dataset_size=512 --epochs=4 --simulated_fault=rank_killed,2 --quick-simulated-fault --simulated-fault-jitter-sec=0` + Fault tolerance features demonstrated: 1. Heartbeat sending during training 2. Timeout calculation and setting @@ -130,6 +133,19 @@ def fault_desc(strings): parser.add_argument('--simulated_fault', type=fault_desc, help='Description of a fault to be simulated') + parser.add_argument( + '--quick-simulated-fault', + action='store_true', + help='With --simulated_fault: compute HB timeouts after a few iterations of epoch 0 ' + 'instead of waiting until epoch 1 / iter 1 (which implies one full training epoch first).', + ) + parser.add_argument( + '--simulated-fault-jitter-sec', + type=float, + default=4.0, + help='Extra random delay in [0, value) seconds added on top of the delay in --simulated_fault. ' + 'Set to 0 for a deterministic fault time.', + ) # fmt: on args = parser.parse_args() @@ -200,11 +216,22 @@ def training_loop( last_log_time = time.monotonic() for iter_idx, x in enumerate(dataloader, start=progress['iter_idx']): - if ft_client.hb_timeouts.are_valid is False and epoch_idx == 1 and iter_idx == 1: - # after 0th epoch is completed and we've done 0th iteration of the 1st epoch, - # we can calculate and set timeouts. this is a good moment to do so, - # because now we've seen the possibly long interval where checkpoint was saved. - ft_client.calculate_and_set_hb_timeouts() + if not ft_client.hb_timeouts.are_valid: + if ( + args.quick_simulated_fault + and args.simulated_fault + and epoch_idx == 0 + and iter_idx == 2 + ): + # Enough heartbeats for timeout calc without finishing a full epoch first. + # Requires >= 3 batches in this epoch (iter 0, 1, 2); use --train_dataset_size accordingly. + ft_client.calculate_and_set_hb_timeouts(skip_if_not_ready=True) + elif epoch_idx == 1 and iter_idx == 1: + # after 0th epoch is completed and we've done 0th iteration of the 1st epoch, + # we can calculate and set timeouts. this is a good moment to do so, + # because now we've seen the possibly long interval where checkpoint was saved. + ft_client.calculate_and_set_hb_timeouts() + _maybe_setup_simulated_fault(ft_client, args, device) optimizer.zero_grad() x = x.to(device) @@ -273,7 +300,20 @@ def _cancel_simulated_fault(): _sim_fault_canceled = True -def _setup_simulated_fault(ft_client, fault_desc, device): +def _maybe_setup_simulated_fault(ft_client, args, device) -> None: + """Arm simulated fault once heartbeat timeouts are valid (see training_loop).""" + if not args.simulated_fault or _sim_fault_is_set: + return + if ft_client.hb_timeouts.are_valid: + _setup_simulated_fault( + ft_client, + args.simulated_fault, + device, + jitter_sec=args.simulated_fault_jitter_sec, + ) + + +def _setup_simulated_fault(ft_client, fault_desc, device, jitter_sec: float = 4.0): # FIXME: hanging rank with SIGTSTP results in rank monitor # blocked when trying to receive the data in _on_ipc_data_from_rank @@ -306,7 +346,7 @@ def _setup_simulated_fault(ft_client, fault_desc, device): else: raise Exception(f"Unknown fault type {fault_type}") - delay = fault_desc['delay'] + 4.0 * rng.random() + delay = fault_desc['delay'] + jitter_sec * rng.random() logging.info( f"Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}", @@ -514,10 +554,6 @@ def main(): logging.info('Leaving the main loop, due to SIGTERM') break - # Setup simulated fault as soon as we have valid timeouts - if args.simulated_fault and not _sim_fault_is_set and ft_client.hb_timeouts.are_valid: - _setup_simulated_fault(ft_client, args.simulated_fault, device) - _cancel_simulated_fault() ft_client.shutdown_workload_monitoring() torch.distributed.destroy_process_group() diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index 266d61b5..e0dcd851 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -436,6 +436,11 @@ def __init__( self._children_pgids: Set[int] = set() self._restart_policy = restart_policy self._node_id = self._get_fq_hostname() + # Last peer_aborted_count value we already reacted to (restart path). Without this, + # notify_peer's store counter can stay >0 until the next full barrier clear, and a + # single-node agent would treat its own stale signal as a new cluster change every + # monitor tick after a local restart. + self._last_peer_aborted_observed: int = 0 DEFAULT_ROLE = "default" # FIXME @@ -836,6 +841,7 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes notify_peer=True, ) if should_restart: + self._last_peer_aborted_observed = self._check_cluster_peer_aborted_count() continue self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED @@ -850,7 +856,11 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes peer_aborted_count = self._check_cluster_peer_aborted_count() group_rank = self._worker_group.group_rank - if num_nodes_waiting > 0 or peer_aborted_count > 0: + if peer_aborted_count < self._last_peer_aborted_observed: + self._last_peer_aborted_observed = peer_aborted_count + peer_aborted_increased = peer_aborted_count > self._last_peer_aborted_observed + + if num_nodes_waiting > 0 or peer_aborted_increased: # Record failure detection event record_profiling_event( ProfilingEvent.FAILURE_DETECTED, @@ -869,6 +879,7 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes notify_peer=False, ) if should_restart: + self._last_peer_aborted_observed = self._check_cluster_peer_aborted_count() continue self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED @@ -1127,6 +1138,13 @@ def _stop_workers( ) logger.info(f"Stopping workers... Timeout = {self._workers_stop_timeout} sec.") + # Snapshot handlers before _shutdown clears _pcontext (avoids double-close when run() + # finally calls _shutdown again; handlers are still valid for pipe cleanup). + pipe_handlers: List[Any] = [] + if isinstance(self._logs_specs, PipeBasedLogsSpecs) and self._pcontext is not None: + if hasattr(self._pcontext, "subprocess_handlers"): + pipe_handlers = list(self._pcontext.subprocess_handlers.values()) + # Rank monitors will detect worker shutdown when worker processes disconnect self._shutdown(timeout=self._workers_stop_timeout) @@ -1170,7 +1188,7 @@ def _stop_workers( # then close pipe file objects to prevent FD reuse bugs if isinstance(self._logs_specs, PipeBasedLogsSpecs): time.sleep(0.3) - for handler in self._pcontext.subprocess_handlers.values(): + for handler in pipe_handlers: for stream in (handler.proc.stdout, handler.proc.stderr): if stream: try: @@ -1333,6 +1351,14 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: start_method=self._start_method, ) + # _restart_workers calls _start_workers without _initialize_workers; Worker.id must + # match SubprocessContext/MultiprocessContext pids or the next _monitor_workers fails. + pids_by_local_rank = self._pcontext.pids() + for worker in worker_group.workers: + new_pid = pids_by_local_rank.get(worker.local_rank) + if new_pid is not None: + worker.id = new_pid + self._children_pgids = {os.getpgid(p) for p in self._pcontext.pids().values()} # Start reader thread for pipe-based logging if using PipeBasedLogsSpecs @@ -1461,7 +1487,12 @@ def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 3 self._worker_watchdog.stop() self._worker_watchdog = None if self._pcontext: - self._pcontext.close(death_sig, timeout=timeout) + try: + self._pcontext.close(death_sig, timeout=timeout) + finally: + # Idempotent second _shutdown from run() finally; avoids AssertionError from + # closing the same SubprocessContext twice. + self._pcontext = None # Best-effort cleanup for orphan descendants in worker process groups. # PID=1 can become parent when original worker parent exits. terminate_mp_processes(allowed_pgids=self._children_pgids) @@ -1472,7 +1503,13 @@ def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 3 def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: role = worker_group.spec.role worker_pids = {w.id for w in worker_group.workers} - assert self._pcontext is not None + if self._pcontext is None: + logger.warning( + "[%s] _monitor_workers: process context is gone (workers already stopped); " + "returning FAILED.", + role, + ) + return RunResult(state=WorkerState.FAILED) pc_pids = set(self._pcontext.pids().values()) if worker_pids != pc_pids: logger.error( From 170aff86e5b706ec0544807307991e476ef4dbfb Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Sat, 18 Apr 2026 10:03:02 -0700 Subject: [PATCH 4/6] fix cycle-id for local launch --- .../train_ddp_heartbeats_api.py | 28 ++++++++++++++- .../fault_tolerance/launcher.py | 36 ++++++++++++++----- 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/examples/fault_tolerance/train_ddp_heartbeats_api.py b/examples/fault_tolerance/train_ddp_heartbeats_api.py index 3af2ee44..15c32606 100644 --- a/examples/fault_tolerance/train_ddp_heartbeats_api.py +++ b/examples/fault_tolerance/train_ddp_heartbeats_api.py @@ -28,7 +28,8 @@ 1. Heartbeat sending during training 2. Timeout calculation and setting 3. State persistence through checkpoints -4. Simulated fault injection +4. Simulated fault injection (by default only on the first worker start; see + ``--simulated-fault-on-every-restart``). """ import argparse @@ -146,6 +147,13 @@ def fault_desc(strings): help='Extra random delay in [0, value) seconds added on top of the delay in --simulated_fault. ' 'Set to 0 for a deterministic fault time.', ) + parser.add_argument( + '--simulated-fault-on-every-restart', + action='store_true', + help='With --simulated_fault: re-arm the fault after each ft_launcher worker restart ' + '(new processes). Default is off so fault injection runs only on the initial attempt ' + '(TORCHELASTIC_RESTART_COUNT==0) and later cycles continue training.', + ) # fmt: on args = parser.parse_args() @@ -293,6 +301,7 @@ def validation_loop(ft_client, model, val_dataloader, epoch_idx, device): _sim_fault_canceled = False _sim_fault_is_set = False +_logged_sim_fault_skip_on_restart = False def _cancel_simulated_fault(): @@ -302,8 +311,25 @@ def _cancel_simulated_fault(): def _maybe_setup_simulated_fault(ft_client, args, device) -> None: """Arm simulated fault once heartbeat timeouts are valid (see training_loop).""" + global _logged_sim_fault_skip_on_restart if not args.simulated_fault or _sim_fault_is_set: return + # After ft_launcher restarts workers, each rank is a new Python process: module globals + # like _sim_fault_is_set reset to False. Without this guard, --simulated_fault would + # re-arm every cycle and kill again. TORCHELASTIC_RESTART_COUNT is set by the launcher + # (0 = first attempt, >0 after a fault-tolerance restart). + if not args.simulated_fault_on_every_restart: + restart_cnt = int(os.environ.get('TORCHELASTIC_RESTART_COUNT', '0')) + if restart_cnt > 0: + if not _logged_sim_fault_skip_on_restart: + logging.info( + 'TORCHELASTIC_RESTART_COUNT=%s: not re-arming --simulated_fault so training ' + 'can continue after this restart (use --simulated-fault-on-every-restart ' + 'to inject on every new worker process).', + restart_cnt, + ) + _logged_sim_fault_skip_on_restart = True + return if ft_client.hb_timeouts.are_valid: _setup_simulated_fault( ft_client, diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index e0dcd851..b0747af4 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -525,6 +525,16 @@ def _get_global_restart_count(self) -> int: """ return max(0, self._rdzv_handler.round() - 1) + def _elastic_worker_attempt_index(self, spec: WorkerSpec) -> int: + """Torchelastic-style worker attempt index (matches stock ``TORCHELASTIC_RESTART_COUNT``). + + Increments each time the agent consumes a restart slot and starts a new worker group. + Used for per-attempt log paths (``*_cycle{N}.log``) and worker env. Unlike + :meth:`_get_global_restart_count` (derived from barrier rendezvous ``round()``), this + advances on local-only FT restarts so cycle log files do not all map to ``_cycle0``. + """ + return spec.max_restarts - self._remaining_restarts + def _on_cycle_end(self) -> None: """Record cycle end time in cycle info file.""" if self._cycle_info_writer is None: @@ -1112,8 +1122,10 @@ def _run_attribution(self) -> bool: return False cycle_log_file = None if hasattr(self._logs_specs, "get_cycle_log_file"): - current_cycle = self._get_global_restart_count() - cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) + spec = self._worker_group.spec + cycle_log_file = self._logs_specs.get_cycle_log_file( + self._elastic_worker_attempt_index(spec) + ) if cycle_log_file is None: return False return self._log_analysis_client.should_stop(cycle_log_file) @@ -1223,12 +1235,11 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: store = worker_group.store assert store is not None - # Get the current cycle number from the rendezvous handler - # At this point, rendezvous has completed and we're about to start workers. - # The cycle number is used for profiling and environment variable setting. - # Note: We use _get_global_restart_count() because in the context of worker startup, - # the "restart_count" variable name is used to mean "cycle number" (for historical reasons). - current_cycle = restart_count = self._get_global_restart_count() # Actually cycle number + # Per-attempt index for workers, rank monitors, and PipeBasedLogsSpecs (``*_cycle{N}.log``). + # Must match stock torchelastic ``TORCHELASTIC_RESTART_COUNT = max_restarts - remaining`` so + # ``reify()`` sees a new suffix after each local restart; do not use _get_global_restart_count() + # here (that follows barrier ``round()`` and stays flat across worker-only restarts). + current_cycle = restart_count = self._elastic_worker_attempt_index(spec) # Send current cycle number to rank monitors for logging self._send_cycle_to_rank_monitors(restart_count) @@ -1919,6 +1930,15 @@ def launch_agent( return None if result.is_failed(): + # ChildFailedError.__init__ in PyTorch asserts failures is non-empty; the agent can + # return WorkerState.FAILED with an empty failures dict (e.g. max restarts exhausted + # after _stop_workers, no ProcessFailure payload). Do not raise ChildFailedError then. + if not result.failures: + raise RuntimeError( + f"{entrypoint_name}: worker group ended in {result.state.name} with no " + "per-rank failure records (typical when restarts are exhausted or the agent " + "stopped workers without a torchelastic error file)." + ) # ChildFailedError is treated specially by @record # if the error files for the failed children exist # @record will copy the first error (root cause) From 2bd52a5f48b15a97af008d22b2ee3aa3c3ea6986 Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Sun, 19 Apr 2026 19:17:20 -0700 Subject: [PATCH 5/6] graceful_stop_requested --- .../fault_tolerance/launcher.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index b0747af4..fcdcc572 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -441,6 +441,11 @@ def __init__( # single-node agent would treat its own stale signal as a new cluster change every # monitor tick after a local restart. self._last_peer_aborted_observed: int = 0 + # Set to True when attribution or progress tracker decides the job should stop + # permanently (not a node failure). Checked by callers of _handle_restart_decision + # to raise RendezvousGracefulExitError (exit 0) instead of RunResult(FAILED) (exit 1), + # preventing SLURM from requeueing a job that has a job-level stop signal. + self._graceful_stop_requested: bool = False DEFAULT_ROLE = "default" # FIXME @@ -728,6 +733,11 @@ def _rollback_peer_abort_notify() -> None: else: logger.error("[%s] Attribution says do not restart; will not restart.", role) _rollback_peer_abort_notify() + # Attribution is a job-level stop: broadcast tombstone so all nodes exit + # cleanly rather than re-entering the restart loop. + if hasattr(self._rdzv_handler, '_barrier_state'): + self._rdzv_handler._barrier_state.set_permanently_closed() + self._graceful_stop_requested = True self._emit_ft_restart_decision_timing( role, "veto_attribution", @@ -751,6 +761,11 @@ def _rollback_peer_abort_notify() -> None: role ) _rollback_peer_abort_notify() + # Progress tracker is a job-level stop: broadcast tombstone so all nodes exit + # cleanly rather than re-entering the restart loop. + if hasattr(self._rdzv_handler, '_barrier_state'): + self._rdzv_handler._barrier_state.set_permanently_closed() + self._graceful_stop_requested = True self._emit_ft_restart_decision_timing( role, "veto_progress", @@ -854,6 +869,11 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes self._last_peer_aborted_observed = self._check_cluster_peer_aborted_count() continue self._stop_workers(self._worker_group) + if self._graceful_stop_requested: + self._graceful_stop_requested = False + raise RendezvousGracefulExitError( + "Job-level stop (attribution or progress tracker): job will not restart." + ) self._worker_group.state = WorkerState.FAILED return RunResult(state=WorkerState.FAILED) elif state == WorkerState.HEALTHY: @@ -892,6 +912,11 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes self._last_peer_aborted_observed = self._check_cluster_peer_aborted_count() continue self._stop_workers(self._worker_group) + if self._graceful_stop_requested: + self._graceful_stop_requested = False + raise RendezvousGracefulExitError( + "Job-level stop (attribution or progress tracker): job will not restart." + ) self._worker_group.state = WorkerState.FAILED return RunResult(state=WorkerState.FAILED) else: From f4e3a45f329eb14e1a38049c65ef5b2081707ae8 Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Tue, 21 Apr 2026 14:21:12 -0700 Subject: [PATCH 6/6] base llm model cfg support --- .../attribution/log_analyzer/config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py index 8902a212..bacfce46 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py @@ -17,12 +17,13 @@ (see ``DEFAULT_COMPUTE_TIMEOUT_SECONDS`` in ``attribution.coalescing``); :class:`~nvidia_resiliency_ext.attribution.analyzer.engine.Analyzer` accepts ``compute_timeout`` / ``grace_period_seconds`` for the coalescer. """ +import os from dataclasses import dataclass from enum import Enum -# LLM defaults -DEFAULT_LLM_MODEL = "nvidia/qwen/qwen-235b" -DEFAULT_LLM_BASE_URL = "https://inference-api.nvidia.com/v1" +# LLM defaults — override with NVRX_LLM_MODEL / NVRX_LLM_BASE_URL env vars +DEFAULT_LLM_MODEL = os.environ.get("NVRX_LLM_MODEL", "nvidia/qwen/qwen-235b") +DEFAULT_LLM_BASE_URL = os.environ.get("NVRX_LLM_BASE_URL", "https://inference-api.nvidia.com/v1") # TTL constants (see spec Section 3.2) TTL_PENDING_SECONDS = 7 * 24 * 60 * 60 # 1 week - pending job expiry