From abc9fb560593724cbd10371447857519ca3997a8 Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Thu, 26 Feb 2026 09:14:49 -0800 Subject: [PATCH] Fault tolerance integration with Log_analysis attribution module --- docs/source/fault_tolerance/usage_guide.rst | 76 ++++- pyproject.toml | 7 + services/nvrx_attrsvc/NVRX_ATTRSVC_SPEC.md | 10 +- services/nvrx_attrsvc/README.md | 3 +- services/nvrx_attrsvc/config.py | 29 +- services/nvrx_attrsvc/dataflow.py | 66 ---- .../attribution/log_analyzer/config.py | 1 + .../attribution/log_analyzer/runner.py | 169 ++++++++++ .../attribution/log_analyzer/utils.py | 69 +++- .../attribution/postprocessing/__init__.py | 14 +- .../attribution/postprocessing/base.py | 6 +- .../attribution/postprocessing/config.py | 102 +++++- .../attribution/postprocessing/dataflow.py | 72 +++++ .../attribution/postprocessing/slack.py | 23 +- .../attribution/utils.py | 10 +- .../fault_tolerance/__init__.py | 2 + .../fault_tolerance/config.py | 85 ++++- .../fault_tolerance/ft_attribution.py | 298 ++++++++++++++++++ .../fault_tolerance/ft_rendezvous_barrier.py | 51 ++- .../fault_tolerance/launcher.py | 237 ++++++++++---- .../fault_tolerance/utils.py | 10 + .../shared_utils/health_check.py | 102 +----- 22 files changed, 1112 insertions(+), 330 deletions(-) delete mode 100644 services/nvrx_attrsvc/dataflow.py create mode 100644 src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py create mode 100644 src/nvidia_resiliency_ext/attribution/postprocessing/dataflow.py create mode 100644 src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py diff --git a/docs/source/fault_tolerance/usage_guide.rst b/docs/source/fault_tolerance/usage_guide.rst index 68889da7..bc1751d9 100644 --- a/docs/source/fault_tolerance/usage_guide.rst +++ b/docs/source/fault_tolerance/usage_guide.rst @@ -117,33 +117,79 @@ Validation behavior: - Other existing types (e.g., devices/symlinks): performs ``stat`` access -Attribution service integration -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
+Attribution integration +^^^^^^^^^^^^^^^^^^^^^^^ -Enable artifact analysis (e.g., logs) during rendezvous health checks by pointing to a running attribution service. -The feature is enabled by specifying both host and port. +Enable artifact analysis (e.g., logs) during rendezvous to make RESTART/STOP decisions. +Use ``--ft-attribution-loganalysis [lib|mcp|url]`` (default: lib) for integration flexibility: -* CLI: +* ``lib`` (default): Direct calling via API in-process. +* ``mcp``: Log analysis in a separate MCP subprocess. +* ``url``: HTTP attribution service (host:port or http(s)://host:port). - - ``--ft-attrsvc-host `` (alias: ``--ft_attrsvc_host``) - - ``--ft-attrsvc-port `` (alias: ``--ft_attrsvc_port``) +* CLI: - Example: + - ``--ft-attribution-loganalysis`` (alias: ``--ft_attribution_loganalysis``): Enable log analysis attribution. + Accepts ``lib``, ``mcp``, or a URL string. No value = lib (default). + - ``--ft-attribution-timeout`` (alias: ``--ft_attribution_timeout``): Wait/timeout in seconds; + skip result if exceeded (default: 60). + - ``--ft-attribution-dry-run`` (alias: ``--ft_attribution_dry_run``): Dry run. Run the full + attribution chain (log analysis, Slack, dataflow) but do not apply the restart/stop decision. + Log what would happen instead. Useful for validating the pipeline without affecting behavior. + - ``--ft-slack-token-file`` (alias: ``--ft_slack_token_file``): Path to file containing Slack bot token. + When not set, uses ``SLACK_BOT_TOKEN`` or ``SLACK_BOT_TOKEN_FILE`` env vars. + - ``--ft-slack-channel`` (alias: ``--ft_slack_channel``): Slack channel for alerts. + When not set, uses ``SLACK_CHANNEL`` env var. + - ``--ft-dataflow-index`` (alias: ``--ft_dataflow_index``): Elasticsearch/dataflow index for posting + attribution results (lib/mcp only). Requires ``nvdataflow`` (install via ``pip install nvidia-resiliency-ext[dataflow]``). + When not set, dataflow posting is disabled. + + Examples: ..
code-block:: bash - ft_launcher \ - --ft-attrsvc-host 127.0.0.1 \ - --ft-attrsvc-port 8000 \ - train.py + # Lib mode (in-process); default + ft_launcher --ft-attribution-loganalysis train.py + ft_launcher --ft-attribution-loganalysis lib train.py + + # MCP mode (log analysis in separate subprocess) + ft_launcher --ft-attribution-loganalysis mcp train.py + + # URL mode (HTTP attribution service) + ft_launcher --ft-attribution-loganalysis http://127.0.0.1:8000 train.py + + # Service with custom timeout + ft_launcher --ft-attribution-loganalysis http://127.0.0.1:8000 --ft-attribution-timeout 90 train.py + + # Lib mode with Slack (token from file; channel from SLACK_CHANNEL env var) + ft_launcher --ft-attribution-loganalysis lib --ft-slack-token-file /etc/secrets/slack-token train.py + + # Lib mode with explicit Slack channel and dataflow index + ft_launcher --ft-attribution-loganalysis lib \ + --ft-slack-token-file /etc/secrets/slack-token --ft-slack-channel "#alerts" \ + --ft-dataflow-index my-attribution-index train.py + + # Dry run: exercise full attribution chain without applying restart/stop decision + ft_launcher --ft-attribution-loganalysis lib --ft-attribution-dry-run train.py -* YAML: under the ``fault_tolerance`` section +* YAML: under the ``fault_tolerance`` section use ``attribution_loganalysis``, ``attribution_timeout_seconds``, + ``slack``, and ``dataflow_index``: ..
code-block:: yaml fault_tolerance: - attrsvc_host: "127.0.0.1" - attrsvc_port: 8000 + attribution_loganalysis: "lib" # or "mcp", or "http://127.0.0.1:8000" for service + attribution_timeout_seconds: 60 + attribution_dry_run: false # true = run chain but don't apply action; log only + slack: + bot_token_file: "/etc/secrets/slack-token" # or bot_token for inline (less secure) + channel: "#alerts" + dataflow_index: "my-attribution-index" # optional; requires nvdataflow + +* Environment (fallback when CLI/YAML not set): + + - ``SLACK_BOT_TOKEN`` or ``SLACK_BOT_TOKEN_FILE``: Slack bot token for lib/mcp alerts. + - ``SLACK_CHANNEL``: Slack channel for alerts. GPU Memory Reclaim ^^^^^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 738dc804..7bb30c37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,15 @@ defusedxml = "*" langchain-nvidia-ai-endpoints = ">=0.3.15" mcp = ">=1.15.0" logsage = ">=0.1.5" +slack-bolt = ">=1.23.0" +slack-sdk = ">=3.35.0" +nvdataflow = {version = "*", optional = true} grpcio = "^1.76.0" grpcio-tools = "^1.76.0" +httpx = ">=0.24.0" + +[tool.poetry.extras] +dataflow = ["nvdataflow"] [tool.poetry.scripts] ft_launcher = "nvidia_resiliency_ext.fault_tolerance.launcher:main" diff --git a/services/nvrx_attrsvc/NVRX_ATTRSVC_SPEC.md b/services/nvrx_attrsvc/NVRX_ATTRSVC_SPEC.md index 43d75e6a..f5e2ddd2 100644 --- a/services/nvrx_attrsvc/NVRX_ATTRSVC_SPEC.md +++ b/services/nvrx_attrsvc/NVRX_ATTRSVC_SPEC.md @@ -278,6 +278,7 @@ src/nvidia_resiliency_ext/attribution/ │ ├── __init__.py # Exports config, configure(), ResultPoster, post_results, Slack │ ├── config.py # PostprocessingConfig singleton and configure() │ ├── base.py # ResultPoster, post_results (generic framework) +│ ├── dataflow.py # nvdataflow posting (post, get_nvdataflow_post_fn) │ └── slack.py # Slack notifications for terminal failures │ └── mcp_integration/ # MCP client/server for LLM communication @@ -306,7 +307,7 @@ File organization by layer: app.py 
POSTPROCESSING: - config.setup() wires lib postprocessing (ResultPoster(dataflow.post), Slack); dataflow.py + config.setup() wires lib postprocessing (ResultPoster(post_fn=postprocessing.dataflow.post), Slack) PYTHON API REFERENCE: @@ -2477,14 +2478,15 @@ This is optional and proprietary - implemented in separate module for easy exclusion or replacement. Files (see section 2 PROJECT STRUCTURE): - - nvrx_attrsvc/dataflow.py # Elasticsearch posting via nvdataflow + - nvidia_resiliency_ext.attribution.postprocessing.dataflow # nvdataflow posting (post, get_nvdataflow_post_fn) - nvrx_attrsvc/config.py setup() # Wires lib postprocessing via configure(poster, cluster_name, dataflow_index, slack_*) - nvidia_resiliency_ext/attribution/postprocessing/ # config, configure(), ResultPoster, post_results, Slack Configuration: - CLUSTER_NAME, DATAFLOW_INDEX: env prefix NVRX_ATTRSVC_ (e.g. NVRX_ATTRSVC_CLUSTER_NAME) - - SLACK_BOT_TOKEN, SLACK_CHANNEL: no prefix (env vars SLACK_BOT_TOKEN, SLACK_CHANNEL) - - If DATAFLOW_INDEX empty, dataflow posting disabled; if SLACK_BOT_TOKEN empty, Slack disabled + - SLACK_BOT_TOKEN, SLACK_BOT_TOKEN_FILE, SLACK_CHANNEL: no prefix + - SLACK_BOT_TOKEN is read first; SLACK_BOT_TOKEN_FILE (path to file containing token) is used only when SLACK_BOT_TOKEN is empty + - If DATAFLOW_INDEX empty, dataflow posting disabled; if Slack token empty, Slack disabled - Slack notifications sent for auto_resume = "STOP - DONT RESTART IMMEDIATE" When triggered: diff --git a/services/nvrx_attrsvc/README.md b/services/nvrx_attrsvc/README.md index 15ceb15e..1370c401 100644 --- a/services/nvrx_attrsvc/README.md +++ b/services/nvrx_attrsvc/README.md @@ -53,6 +53,7 @@ Environment variables (prefix: `NVRX_ATTRSVC_`): | Variable | Default | Description | |----------|---------|-------------| | `SLACK_BOT_TOKEN` | `""` | Slack bot OAuth token (empty = disabled) | +| `SLACK_BOT_TOKEN_FILE` | `""` | Path to file containing token (used when `SLACK_BOT_TOKEN` is empty) | | `SLACK_CHANNEL` | `""` | Slack channel for terminal failure alerts |
When configured, sends alerts to Slack for jobs with `auto_resume = "STOP - DONT RESTART IMMEDIATE"`. @@ -271,7 +272,7 @@ asyncio.run(main()) | `app.py` | FastAPI routes and middleware | | `service.py` | `AttributionService` - wraps LogAnalyzer | | `config.py` | `Settings` (pydantic), `setup()` wires postprocessing (poster + Slack) from cfg | -| `dataflow.py` | NVIDIA-proprietary Elasticsearch posting | +| lib `postprocessing.dataflow` | nvdataflow Elasticsearch posting (used via ResultPoster) | | `deploy/run_attrsvc.sh` | Run service with logging (background) | | `deploy/snapshot_attrsvc.sh` | Periodic endpoint snapshot for debugging | | `deploy/Dockerfile` | Docker build instructions | diff --git a/services/nvrx_attrsvc/config.py b/services/nvrx_attrsvc/config.py index 75ace385..418aa197 100644 --- a/services/nvrx_attrsvc/config.py +++ b/services/nvrx_attrsvc/config.py @@ -71,12 +71,17 @@ class Settings(BaseSettings): default="", description="Dataflow/elasticsearch index for posting results" ) - # Slack integration (optional - set SLACK_BOT_TOKEN to enable; env vars have no NVRX_ATTRSVC_ prefix) + # Slack integration (optional - set SLACK_BOT_TOKEN or SLACK_BOT_TOKEN_FILE to enable; no NVRX_ATTRSVC_ prefix) SLACK_BOT_TOKEN: str = Field( default="", description="Slack bot token (empty = disabled)", validation_alias="SLACK_BOT_TOKEN", ) + SLACK_BOT_TOKEN_FILE: str = Field( + default="", + description="Path to file containing Slack bot token (preferred over SLACK_BOT_TOKEN)", + validation_alias="SLACK_BOT_TOKEN_FILE", + ) SLACK_CHANNEL: str = Field( default="", description="Slack channel for alerts", @@ -217,19 +222,19 @@ def setup() -> Settings: logging.getLogger("nvidia_resiliency_ext.attribution.mcp_integration").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) - # Wire postprocessing config (lib singleton) - from nvidia_resiliency_ext.attribution.postprocessing import ResultPoster, configure - - from . 
import dataflow + # Wire postprocessing config (lib singleton); slack resolved from env (SLACK_BOT_TOKEN/SLACK_BOT_TOKEN_FILE) + from nvidia_resiliency_ext.attribution.postprocessing import ( + ResultPoster, + configure_postprocessing_resolved, + ) + from nvidia_resiliency_ext.attribution.postprocessing.dataflow import post as dataflow_post - configure( - default_poster=ResultPoster(post_fn=dataflow.post), + configure_postprocessing_resolved( + default_poster=ResultPoster(post_fn=dataflow_post), cluster_name=cfg.CLUSTER_NAME or "", dataflow_index=cfg.DATAFLOW_INDEX or "", - slack_bot_token=cfg.SLACK_BOT_TOKEN or "", - slack_channel=cfg.SLACK_CHANNEL or "", + slack_token=None, + slack_channel=cfg.SLACK_CHANNEL or None, + cluster_name_env=None, ) - if cfg.SLACK_BOT_TOKEN: - logger.info(f"Slack notifications enabled for channel: {cfg.SLACK_CHANNEL}") - return cfg diff --git a/services/nvrx_attrsvc/dataflow.py b/services/nvrx_attrsvc/dataflow.py deleted file mode 100644 index 42daade6..00000000 --- a/services/nvrx_attrsvc/dataflow.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. 
- -"""nvdataflow integration for nvrx_attrsvc.""" - -import logging -import time -from typing import Any - -try: - from nvdataflow import post as nv_post - - # Silence verbose nvdataflow INFO logs (they're debug-level messages) - logging.getLogger("nvdataflow").setLevel(logging.WARNING) - logging.getLogger("nvdataflow.post").setLevel(logging.WARNING) - logging.getLogger("nvdataflow.nvdataflowlog").setLevel(logging.WARNING) -except ImportError: - nv_post = None - -logger = logging.getLogger(__name__) - -# Retry configuration -MAX_RETRIES = 3 -INITIAL_BACKOFF_SECONDS = 0.5 # 0.5s, 1s, 2s - - -def post(data: dict[str, Any], index: str) -> bool: - """ - Post data to nvdataflow/elasticsearch with retry logic. - - Uses exponential backoff: 0.5s, 1s, 2s between retries. - - Args: - data: Data dictionary to post - index: Dataflow/elasticsearch index name - - Returns: - True if posted successfully, False otherwise - """ - if nv_post is None: - logger.error("can't import nvdataflow") - return False - - last_error: Exception | None = None - for attempt in range(MAX_RETRIES): - try: - nv_post(data=data, project=index) - if attempt > 0: - logger.info(f"dataflow post succeeded on attempt {attempt + 1}") - return True - except Exception as e: - last_error = e - if attempt < MAX_RETRIES - 1: - backoff = INITIAL_BACKOFF_SECONDS * (2**attempt) - logger.warning( - f"dataflow post failed (attempt {attempt + 1}/{MAX_RETRIES}), retrying in {backoff}s: {e}" - ) - time.sleep(backoff) - - logger.error(f"failed to post to dataflow after {MAX_RETRIES} attempts: {last_error}") - return False diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py index 0ca3e2c1..426f66c1 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/config.py @@ -22,6 +22,7 @@ from enum import Enum +# --- Library constants --- # TTL constants (see spec Section 
3.2) TTL_PENDING_SECONDS = 7 * 24 * 60 * 60 # 1 week - pending job expiry TTL_TERMINATED_SECONDS = 60 * 60 # 1 hour - terminated job expiry (after GET) diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py new file mode 100644 index 00000000..fbbcb944 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/runner.py @@ -0,0 +1,169 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Run log analysis in-process (lib) or via MCP; sync entry points for FT launcher. + +Runs log analysis with a timeout (from config); blocks then returns result or None (skip). + +Uses one long-lived LogAnalyzer (and its RequestCoalescer) per process so that +results are cached per file path—same mapping as the HTTP service. We run a +dedicated thread with an event loop and submit work to it from sync code. +""" + +import asyncio +import logging +import threading +from concurrent.futures import TimeoutError as FuturesTimeoutError +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + +# Long-lived loop, config, and analyzer for the library path so the coalescer is reused. 
+_lib_loop: Optional[asyncio.AbstractEventLoop] = None +_lib_config: Any = None +_lib_analyzer: Any = None +_lib_loop_ready = threading.Event() +_lib_loop_starting = False +_lib_lock = threading.Lock() + + +def _ensure_analyzer_event_loop() -> None: + """Start the dedicated thread and event loop if not already running""" + global _lib_loop, _lib_loop_starting + with _lib_lock: + if _lib_loop is not None: + return + if _lib_loop_starting: + pass # another caller already started the thread; wait below + else: + _lib_loop_starting = True + + def _run_loop() -> None: + global _lib_loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + _lib_loop = loop + _lib_loop_ready.set() + loop.run_forever() + + threading.Thread(target=_run_loop, daemon=True).start() + _lib_loop_ready.wait(timeout=5.0) + if _lib_loop is None: + with _lib_lock: + _lib_loop_starting = False + logger.warning("log analysis lib: event loop thread did not start in time") + + +def _get_or_create_analyzer( + timeout_seconds: float = 60.0, + use_lib_log_analysis: Optional[bool] = None, +) -> bool: + """Ensure the long-lived analyzer exists; sets _lib_loop and _lib_analyzer globals. + + use_lib_log_analysis is only used when creating; pass None to reuse existing (set at init). + Returns True if ready, False otherwise. 
+ """ + global _lib_config, _lib_analyzer + _ensure_analyzer_event_loop() + if not _lib_loop_ready.is_set() or _lib_loop is None: + return False + with _lib_lock: + if _lib_analyzer is not None and _lib_config is not None: + return True + use_lib = use_lib_log_analysis if use_lib_log_analysis is not None else True + try: + from .analyzer import AnalyzerConfig, LogAnalyzer + + _lib_config = AnalyzerConfig( + allowed_root="/", + compute_timeout=timeout_seconds, + use_lib_log_analysis=use_lib, + ) + _lib_analyzer = LogAnalyzer(config=_lib_config) + _lib_analyzer.set_event_loop(_lib_loop) + if not use_lib: + future = asyncio.run_coroutine_threadsafe(_lib_analyzer.connect_mcp(), _lib_loop) + future.result(timeout=30) + except Exception as e: + _lib_analyzer = None + _lib_config = None + logger.warning("log analysis lib: failed to create analyzer: %s", e) + return False + return True + + +def ensure_analyzer_ready( + timeout_seconds: float = 60.0, + use_lib_log_analysis: bool = True, +) -> bool: + """Eagerly create the analyzer (event loop, AnalyzerConfig, set_event_loop). + Call at client init for fail-fast. Returns True if ready.""" + return _get_or_create_analyzer(timeout_seconds, use_lib_log_analysis) + + +def _raw_to_result_dict(raw: Any) -> Optional[Dict[str, Any]]: + """Convert analyzer result to the dict shape used by attribution_no_restart.""" + if hasattr(raw, "result"): + r = getattr(raw, "result", None) + return r if isinstance(r, dict) else None + return None + + +def run_log_analysis_sync( + log_path: str, + wl_restart: Optional[int] = None, + user: str = "", + job_id: str = "", +) -> Optional[Dict[str, Any]]: + """Run log analysis synchronously with a timeout. + + If the analysis does not complete within the configured timeout, the result is skipped + (returns None). Timeout comes from _lib_config (set at init). + + Uses the analyzer's RequestCoalescer: results are cached per file path (same + as the HTTP service). 
Repeat calls for the same path return the cached result; + wl_restart selects the cycle when one file has multiple cycles. + + Args: + log_path: Path to the cycle log file to analyze. + wl_restart: Workload restart index within file (None = first or all). + When a file contains multiple cycles, use this to select which cycle's result. + + Returns: + Result dict from the analyzer on success, or None on timeout/error/skip. + """ + from .analyzer import AnalyzerError + + if not _get_or_create_analyzer(): + return None + + validated = _lib_analyzer.validate_path(log_path, require_regular_file=True, reject_empty=False) + if isinstance(validated, AnalyzerError): + logger.debug("log analysis lib: skip (path validation): %s", validated.message) + return None + + timeout = _lib_config.compute_timeout + + async def _run() -> Any: + await _lib_analyzer.submit(validated, user=user, job_id=job_id) + return await _lib_analyzer.analyze(validated, wl_restart=wl_restart) + + try: + future = asyncio.run_coroutine_threadsafe(_run(), _lib_loop) + raw = future.result(timeout=timeout) + except FuturesTimeoutError: + logger.info("log analysis lib: skipped (timeout after %.0fs): %s", timeout, log_path) + return None + except Exception as e: + logger.warning("log analysis lib: skip (exception): %s: %s", type(e).__name__, e) + return None + + if isinstance(raw, AnalyzerError): + logger.debug("log analysis lib: analysis error for %s: %s", log_path, raw.message) + return None + return _raw_to_result_dict(raw) diff --git a/src/nvidia_resiliency_ext/attribution/log_analyzer/utils.py b/src/nvidia_resiliency_ext/attribution/log_analyzer/utils.py index 54bce83f..78297d83 100644 --- a/src/nvidia_resiliency_ext/attribution/log_analyzer/utils.py +++ b/src/nvidia_resiliency_ext/attribution/log_analyzer/utils.py @@ -12,10 +12,77 @@ import re from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, Optional logger = 
logging.getLogger(__name__) + +def attribution_no_restart(attr_result: Optional[Dict[str, Any]]) -> bool: + """Whether attribution recommends do not restart (stop). + + Call this on the raw result from log analysis or an attribution service (or None if unavailable). + True = attribution recommends stop; False = recommends restart or no result (skip). + + Handles result shapes: state STOP/CONTINUE/RESTART, or strings containing + 'STOP - DONT RESTART' / 'RESTART IMMEDIATE'. + """ + if attr_result is None or not isinstance(attr_result, dict): + return False + state = attr_result.get("state") + if state == "STOP": + return True + if state in ("CONTINUE", "RESTART"): + return False + nested = attr_result.get("result") + if isinstance(nested, dict): + nested_state = nested.get("state") + if nested_state == "STOP": + return True + if nested_state in ("CONTINUE", "RESTART"): + return False + if isinstance(nested, (list, tuple)) and nested: + first = nested[0] + s = first if isinstance(first, str) else str(first) + if "STOP" in s and "RESTART" not in s.split("STOP")[0]: + return True + if "RESTART" in s: + return False + # Fallback: string matching on stringified result (fragile) + s = str(attr_result) + logger.warning( + "attribution_no_restart: falling through to string matching on result: %s", + s[:200] + ("..." 
if len(s) > 200 else ""), + ) + if "STOP" in s and "DONT RESTART" in s: + return True + if "RESTART" in s and "IMMEDIATE" in s: + return False + return False + + +def log_attribution_result(attr_result: Any) -> None: + """Log attribution result in a readable format""" + if attr_result is None: + logger.info("Attribution result: None") + elif isinstance(attr_result, dict): + att = attr_result.get("s_attribution") or attr_result.get("attribution") + expl = attr_result.get("s_auto_resume_explanation") or attr_result.get( + "auto_resume_explanation" + ) + job = attr_result.get("s_job_id") or attr_result.get("job_id", "?") + if att or expl: + logger.info( + "Attribution result (job=%s): attribution=%r explanation=%r", + job, + (att or "")[:200], + (expl or "")[:200], + ) + else: + logger.info("Attribution result: %s", str(attr_result)[:300]) + else: + logger.info("Attribution result: %s", str(attr_result)[:300]) + + # Regex patterns for log file path parsing and splitlog file discovery # Per-cycle log files (e.g., foo_cycle3.log); raw string for re.search(), compiled for .search() diff --git a/src/nvidia_resiliency_ext/attribution/postprocessing/__init__.py b/src/nvidia_resiliency_ext/attribution/postprocessing/__init__.py index a3d70c24..74ff4a5f 100644 --- a/src/nvidia_resiliency_ext/attribution/postprocessing/__init__.py +++ b/src/nvidia_resiliency_ext/attribution/postprocessing/__init__.py @@ -3,7 +3,7 @@ """Postprocessing for attribution results. -- config: Singleton PostprocessingConfig; set config.default_poster, config.cluster_name, config.dataflow_index, config.slack_bot_token, config.slack_channel at startup. +- config: Singleton PostprocessingConfig; set via configure() at startup. - base: ResultPoster, post_results, get_default_poster. - slack: API and maybe_send_slack_notification (used by post_results). 
@@ -26,9 +26,14 @@ get_default_poster, post_results, ) -from .config import PostprocessingConfig, config, configure +from .config import ( + PostprocessingConfig, + config, + configure, + configure_postprocessing_resolved, + load_slack_from_env, +) from .slack import ( - HAS_SLACK, SlackStats, get_slack_stats, get_slack_user_id, @@ -42,6 +47,8 @@ "PostprocessingConfig", "config", "configure", + "configure_postprocessing_resolved", + "load_slack_from_env", # Base "DataflowStats", "PostFunction", @@ -50,7 +57,6 @@ "get_default_poster", "post_results", # Slack - "HAS_SLACK", "SlackStats", "get_slack_stats", "get_slack_user_id", diff --git a/src/nvidia_resiliency_ext/attribution/postprocessing/base.py b/src/nvidia_resiliency_ext/attribution/postprocessing/base.py index cc186980..e9dc54a1 100644 --- a/src/nvidia_resiliency_ext/attribution/postprocessing/base.py +++ b/src/nvidia_resiliency_ext/attribution/postprocessing/base.py @@ -22,6 +22,7 @@ JobMetadata, ParsedLLMResponse, build_dataflow_record, + log_attribution_result, ) from .config import config @@ -107,11 +108,8 @@ def post_results( user=user, ) - logger.info("jobid: %s", metadata.job_id) + log_attribution_result(data) logger.info("log_path: %s", log_path) - logger.info("auto_resume: %s", parsed.auto_resume) - logger.info("auto_resume_explanation: %s", parsed.auto_resume_explanation) - logger.info("attribution_text: %s", parsed.attribution_text) poster = get_default_poster() success = True diff --git a/src/nvidia_resiliency_ext/attribution/postprocessing/config.py b/src/nvidia_resiliency_ext/attribution/postprocessing/config.py index aeabaa6d..e8635e95 100644 --- a/src/nvidia_resiliency_ext/attribution/postprocessing/config.py +++ b/src/nvidia_resiliency_ext/attribution/postprocessing/config.py @@ -4,12 +4,41 @@ """Shared config for postprocessing (poster, dataflow, Slack). 
Assign attributes at startup.""" import logging +import os from dataclasses import dataclass -from typing import Any +from typing import Any, Optional, Tuple logger = logging.getLogger(__name__) +def load_slack_from_env() -> Tuple[str, str]: + """Load Slack token and channel from environment (mirrors load_nvidia_api_key). + + Token checks in order: + 1. SLACK_BOT_TOKEN environment variable + 2. SLACK_BOT_TOKEN_FILE environment variable (path to token file) + + Channel: SLACK_CHANNEL environment variable. + + Returns: + (token, channel) tuple; empty strings if not found. + """ + token = "" + token_val = os.getenv("SLACK_BOT_TOKEN") + if token_val: + token = token_val.strip() + if not token: + key_file = os.getenv("SLACK_BOT_TOKEN_FILE") + if key_file and os.path.isfile(key_file): + try: + with open(key_file) as f: + token = f.read().strip() + except OSError: + pass + channel = (os.getenv("SLACK_CHANNEL") or "").strip() + return (token, channel) + + @dataclass class PostprocessingConfig: """Single place for postprocessing state. Callers set attributes directly (e.g. config.slack_bot_token = ...).""" @@ -51,3 +80,74 @@ def configure( "postprocessing: slack_channel is set but slack_bot_token is empty; " "Slack notifications will not be sent" ) + + +def configure_postprocessing_resolved( + *, + default_poster: Any = None, + cluster_name: str = "", + dataflow_index: str = "", + slack_token: Optional[str] = None, + slack_channel: Optional[str] = None, + cluster_name_env: Optional[str] = "SLURM_CLUSTER_NAME", + create_dataflow_poster_if_needed: bool = False, +) -> None: + """Configure postprocessing singleton. Resolves from env when params are None/empty. + + Centralizes logic used by nvrx_attrsvc and FT lib/mcp. Call this instead of + manually resolving slack + calling configure(). + + Args: + default_poster: ResultPoster to use (or None). + cluster_name: Override; if empty and cluster_name_env set, uses that env var. 
+ dataflow_index: Elasticsearch index for dataflow posting. + slack_token: Override; None = resolve from SLACK_BOT_TOKEN/SLACK_BOT_TOKEN_FILE. + slack_channel: Override; None = resolve from SLACK_CHANNEL. + cluster_name_env: Env var for cluster_name when cluster_name empty (e.g. SLURM_CLUSTER_NAME). + create_dataflow_poster_if_needed: If True, dataflow_index set, and default_poster None, + creates ResultPoster from nvdataflow when available. + """ + if not (slack_token or "").strip() or not (slack_channel or "").strip(): + env_tok, env_ch = load_slack_from_env() + slack_token = (slack_token or env_tok).strip() if slack_token is not None else env_tok + slack_channel = (slack_channel or env_ch).strip() if slack_channel is not None else env_ch + else: + slack_token = (slack_token or "").strip() + slack_channel = (slack_channel or "").strip() + + if not cluster_name and cluster_name_env: + cluster_name = os.environ.get(cluster_name_env, "") + + poster = default_poster + if poster is None and create_dataflow_poster_if_needed and dataflow_index: + from .base import ResultPoster + from .dataflow import get_nvdataflow_post_fn + + post_fn = get_nvdataflow_post_fn() + if post_fn: + poster = ResultPoster(post_fn=post_fn) + + configure( + default_poster=poster, + cluster_name=cluster_name, + dataflow_index=dataflow_index, + slack_bot_token=slack_token, + slack_channel=slack_channel, + ) + + # Log status of optional integrations + if config.slack_bot_token: + logger.info( + "Slack notifications enabled for channel: %s", + config.slack_channel or "(none)", + ) + if dataflow_index: + if config.default_poster is not None: + logger.info( + "Dataflow posting enabled for attribution (index=%s)", + dataflow_index, + ) + else: + logger.warning( + "dataflow_index set but nvdataflow not installed; dataflow posting disabled" + ) diff --git a/src/nvidia_resiliency_ext/attribution/postprocessing/dataflow.py b/src/nvidia_resiliency_ext/attribution/postprocessing/dataflow.py new file mode 
100644 index 00000000..3b42d4cd --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/postprocessing/dataflow.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Optional nvdataflow integration for lib/mcp attribution posting. + +When nvdataflow is installed, provides a post function for ResultPoster. +Otherwise get_nvdataflow_post_fn returns None and dataflow posting is skipped. +""" + +import logging +import time +from typing import Any, Callable, Dict, Optional + +try: + from nvdataflow import post as nv_post + + logging.getLogger("nvdataflow").setLevel(logging.WARNING) + logging.getLogger("nvdataflow.post").setLevel(logging.WARNING) + logging.getLogger("nvdataflow.nvdataflowlog").setLevel(logging.WARNING) + HAS_NVDATAFLOW = True +except ImportError: + nv_post = None + HAS_NVDATAFLOW = False + +logger = logging.getLogger(__name__) + +MAX_RETRIES = 3 +INITIAL_BACKOFF_SECONDS = 0.5 + + +def _post_with_retry(data: Dict[str, Any], index: str) -> bool: + """Post to nvdataflow with retry. Requires nvdataflow to be installed.""" + if nv_post is None: + logger.error("nvdataflow not installed, cannot post") + return False + last_error: Optional[Exception] = None + for attempt in range(MAX_RETRIES): + try: + nv_post(data=data, project=index) + if attempt > 0: + logger.info("dataflow post succeeded on attempt %d", attempt + 1) + return True + except Exception as e: + last_error = e + if attempt < MAX_RETRIES - 1: + backoff = INITIAL_BACKOFF_SECONDS * (2**attempt) + logger.warning( + "dataflow post failed (attempt %d/%d), retrying in %.1fs: %s", + attempt + 1, + MAX_RETRIES, + backoff, + e, + ) + time.sleep(backoff) + logger.error("failed to post to dataflow after %d attempts: %s", MAX_RETRIES, last_error) + return False + + +def post(data: Dict[str, Any], index: str) -> bool: + """ + Post data to nvdataflow/elasticsearch with retry logic. 
+ + Callable directly for ResultPoster(post_fn=post). Returns False if nvdataflow not installed. + """ + return _post_with_retry(data, index) + + +def get_nvdataflow_post_fn() -> Optional[Callable[[dict, str], bool]]: + """Return post function for nvdataflow, or None if nvdataflow is not installed.""" + if not HAS_NVDATAFLOW: + return None + return _post_with_retry diff --git a/src/nvidia_resiliency_ext/attribution/postprocessing/slack.py b/src/nvidia_resiliency_ext/attribution/postprocessing/slack.py index 9a2b60a8..17c55c76 100644 --- a/src/nvidia_resiliency_ext/attribution/postprocessing/slack.py +++ b/src/nvidia_resiliency_ext/attribution/postprocessing/slack.py @@ -9,13 +9,14 @@ Usage: config.slack_bot_token = token; config.slack_channel = channel # at startup # One-off: send_slack_notification(data, token, channel) when should_notify_slack(auto_resume) - -Requires slack-sdk (optional). When not installed, HAS_SLACK is False and send no-ops. """ import logging from dataclasses import dataclass +from slack_sdk import WebClient +from slack_sdk.errors import SlackApiError + from .config import config logger = logging.getLogger(__name__) @@ -38,16 +39,6 @@ class SlackStats: # Global stats instance _slack_stats = SlackStats() -try: - from slack_sdk import WebClient - from slack_sdk.errors import SlackApiError - - HAS_SLACK = True -except ImportError: - HAS_SLACK = False - WebClient = None # type: ignore - SlackApiError = Exception # type: ignore - def get_slack_stats() -> SlackStats: """Get current Slack statistics.""" @@ -64,10 +55,6 @@ def get_slack_user_id(user_id: str, token: str) -> str | None: Returns: Slack user ID if found, None otherwise """ - if not HAS_SLACK: - logger.warning("slack-sdk not installed, cannot look up user") - return None - _slack_stats.user_lookups += 1 client = WebClient(token=token) @@ -99,10 +86,6 @@ def send_slack_notification( Returns: True if notification sent successfully, False otherwise """ - if not HAS_SLACK: - 
logger.warning("slack-sdk not installed, cannot send notification") - return False - if not slack_bot_token: logger.debug("Slack notification skipped: no bot token configured") return False diff --git a/src/nvidia_resiliency_ext/attribution/utils.py b/src/nvidia_resiliency_ext/attribution/utils.py index 8f41d37a..2018d677 100644 --- a/src/nvidia_resiliency_ext/attribution/utils.py +++ b/src/nvidia_resiliency_ext/attribution/utils.py @@ -4,6 +4,8 @@ from contextlib import contextmanager from io import StringIO +logger = logging.getLogger(__name__) + def load_nvidia_api_key() -> str: """Load NVIDIA API key from environment or file. @@ -44,19 +46,19 @@ def load_nvidia_api_key() -> str: @contextmanager def capture_logs(logger_name=None): - logger = logging.getLogger(logger_name) + target_logger = logging.getLogger(logger_name) # Save original handlers - original_handlers = logger.handlers.copy() + original_handlers = target_logger.handlers.copy() # Create capture handler log_capture = StringIO() capture_handler = logging.StreamHandler(log_capture) - logger.handlers = [capture_handler] + target_logger.handlers = [capture_handler] try: yield log_capture finally: # Restore original handlers - logger.handlers = original_handlers + target_logger.handlers = original_handlers @contextmanager diff --git a/src/nvidia_resiliency_ext/fault_tolerance/__init__.py b/src/nvidia_resiliency_ext/fault_tolerance/__init__.py index c3a9cc28..e688dfe1 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/__init__.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from nvidia_resiliency_ext.fault_tolerance.ft_attribution import LogAnalysisConfig # noqa: F401 + from .config import FaultToleranceConfig # noqa: F401 from .data import WorkloadAction # noqa: F401 from .data import WorkloadControlRequest # noqa: F401 diff --git a/src/nvidia_resiliency_ext/fault_tolerance/config.py b/src/nvidia_resiliency_ext/fault_tolerance/config.py index 77657694..62a040d1 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/config.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/config.py @@ -24,6 +24,41 @@ import yaml +def _read_token_from_file(path: str) -> Optional[str]: + """Read token from file path. Returns stripped content or None on error.""" + if not path or not path.strip(): + return None + try: + with open(path.strip(), "r") as f: + return f.read().strip() or None + except OSError: + return None + + +@dataclass(frozen=True) +class SlackConfig: + """Slack notification config. Reusable by attribution and other FT modules.""" + + bot_token: Optional[str] = None + channel: Optional[str] = None + + def to_dict(self) -> dict: + return {"bot_token": self.bot_token, "channel": self.channel} + + @classmethod + def from_dict(cls, d: Optional[dict]) -> Optional["SlackConfig"]: + if not d: + return None + tok = d.get("bot_token") + token_file = d.get("bot_token_file") + if token_file: + tok = _read_token_from_file(token_file) or tok + ch = d.get("channel") + if tok is None and ch is None: + return None + return cls(bot_token=tok, channel=ch) + + @dataclass class FaultToleranceConfig: """ @@ -95,9 +130,15 @@ class FaultToleranceConfig: out-of-section timeouts. The first N iterations (relative to cycle start) are excluded from timeout monitoring as they can be significantly slower than steady-state iterations. Default: 5. Can be overridden by workload (e.g., Megatron-LM via init_workload_monitoring). 
- * Attribution service (optional): - - `attrsvc_host` [str] hostname/IP of the attribution service - - `attrsvc_port` [int] port of the attribution service + * Attribution (optional): `attribution_loganalysis` [str|None] enables log analysis. None = disabled; + ``"lib"`` = in-process log analysis (default); ``"mcp"`` = log analysis in MCP subprocess; + URL string = HTTP attribution service. + `attribution_timeout_seconds` [int] = wait/timeout in seconds (default 60). + `attribution_dry_run` [bool] = if True, run attribution chain but do not apply the action + (log what would happen; useful for validation). Default: False. + * Slack (shared by attribution and other FT modules): `slack` [SlackConfig|None]. + Token via `bot_token_file` (CLI/yaml) or SLACK_BOT_TOKEN/SLACK_BOT_TOKEN_FILE env. + * `dataflow_index` [str|None] = Elasticsearch/dataflow index for attribution posting (lib/mcp). None = disabled. * `cycle_info_dir` [str|None] Full path to the NVRx cycle info directory (e.g. /nvrx/). 
If set, the TCPStore host writes cycle info JSON files and the @@ -144,9 +185,13 @@ class FaultToleranceConfig: num_warmup_iterations: int = ( 5 # Number of warmup iterations before monitoring step section and out-of-section timeouts ) - # Attribution service configuration (optional) - attrsvc_host: Optional[str] = None - attrsvc_port: Optional[int] = None + # Attribution: None = off; "lib" = in-process; "mcp" = MCP subprocess; URL = HTTP service + attribution_loganalysis: Optional[str] = None + attribution_timeout_seconds: int = 60 + attribution_dry_run: bool = False # Run attribution chain but don't apply action; log only + # Slack (shared by attribution and other FT modules) + slack: Optional["SlackConfig"] = None + dataflow_index: Optional[str] = None # NVRx cycle info: base directory for cycle_info JSON files cycle_info_dir: Optional[str] = None @@ -171,6 +216,25 @@ def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultTolerance Raises: ValueError: If there are unrecognized arguments and ignore_not_recognized is False. 
""" + # Preprocess slack: build from nested slack: {...} or flat slack_bot_token/slack_channel/slack_bot_token_file + kwargs = dict(kwargs) + if "slack" not in kwargs and ( + "slack_bot_token" in kwargs + or "slack_bot_token_file" in kwargs + or "slack_channel" in kwargs + ): + tok = kwargs.pop("slack_bot_token", None) + token_file = kwargs.pop("slack_bot_token_file", None) + if token_file: + tok = _read_token_from_file(token_file) or tok + kwargs["slack"] = SlackConfig( + bot_token=tok, + channel=kwargs.pop("slack_channel", None), + ) + slack_val = kwargs.get("slack") + if isinstance(slack_val, dict): + kwargs["slack"] = SlackConfig.from_dict(slack_val) + fields_set = {f.name for f in fields(FaultToleranceConfig) if f.init} matching_args = {k: v for k, v in kwargs.items() if k in fields_set} extra_args = {k: v for k, v in kwargs.items() if k not in fields_set} @@ -289,6 +353,8 @@ def from_args(args: argparse.Namespace): 'gpu_memory_poll_interval', ] for field in fields(FaultToleranceConfig): + if field.name == "slack": + continue # Handled below from ft_slack_bot_token / ft_slack_channel cli_field_name = f"ft_{field.name}" val = getattr(args, cli_field_name, None) if val is not None: @@ -298,6 +364,13 @@ def from_args(args: argparse.Namespace): val = FaultToleranceConfig._parse_timeout_arg(val) cli_ft_args[field.name] = val + # Slack from --ft-slack-token-file / --ft-slack-channel (token from file only; env fallback in ft_attribution) + slack_token_file = getattr(args, "ft_slack_bot_token_file", None) + slack_tok = _read_token_from_file(slack_token_file) if slack_token_file else None + slack_ch = getattr(args, "ft_slack_channel", None) + if slack_tok is not None or slack_ch is not None: + cli_ft_args["slack"] = SlackConfig(bot_token=slack_tok, channel=slack_ch) + # Update config with CLI args for arg_name, arg_val in cli_ft_args.items(): setattr(ft_cfg, arg_name, arg_val) diff --git a/src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py 
b/src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py new file mode 100644 index 00000000..526cbb06 --- /dev/null +++ b/src/nvidia_resiliency_ext/fault_tolerance/ft_attribution.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fault-tolerance integration with attribution. + +Attribution has multiple analyzer backends; this module integrates the LogAnalysis +analyzer (nvidia_resiliency_ext.attribution.log_analyzer) with the FT launcher. + +Provides LogAnalysisConfig, LogAnalysisClient, and AttributionServiceClient for +invoking log analysis on the Restart & progress path (lib, mcp, or url mode). 
+""" + +import logging +import threading +from dataclasses import dataclass +from typing import Any, Callable, Dict, Literal, Optional +from urllib.parse import quote_plus + +import httpx + +from nvidia_resiliency_ext.attribution.log_analyzer.utils import attribution_no_restart +from nvidia_resiliency_ext.fault_tolerance.config import SlackConfig +from nvidia_resiliency_ext.fault_tolerance.utils import job_id_from_env, job_user_from_env +from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig + +logger = logging.getLogger(LogConfig.name) + +# Re-export for launcher (parse attribution result → restart decision) +__all__ = [ + "LogAnalysisConfig", + "LogAnalysisClient", + "LogAnalysisMode", + "AttributionServiceClient", + "SlackConfig", + "attribution_no_restart", +] + +# --- Config --- +LogAnalysisMode = Literal["lib", "mcp", "url"] + + +def _validate_attribution_url(url: str) -> str: + """Validate attribution URL and return normalized form (with scheme if missing).""" + if not url or not url.strip(): + raise ValueError("--ft-attribution-loganalysis URL must be non-empty") + s = url.strip() + if "://" in s: + if not s.startswith(("http://", "https://")): + raise ValueError(f"--ft-attribution-loganalysis: expected http(s) URL, got: {url!r}") + return s + if ":" in s: + return f"http://{s}" + raise ValueError( + f"--ft-attribution-loganalysis: expected host:port or http(s)://host:port, got: {url!r}" + ) + + +@dataclass(frozen=True) +class LogAnalysisConfig: + """Configuration for log analysis invocation on the Restart & progress path. + + Use mode ``lib`` (in-process), ``mcp`` (MCP subprocess), or ``url`` (HTTP service). + When mode is ``url``, attribution_service_url must be set (e.g. http://host:8000). + user and job_id are read from env by LogAnalysisClient (SLURM_JOB_USER, SLURM_*_JOB_ID). + slack: SlackConfig for lib/mcp alerts; reuses FaultToleranceConfig.slack when provided. 
+ dataflow_index: Elasticsearch index for lib/mcp posting; reuses FaultToleranceConfig.dataflow_index. + """ + + mode: LogAnalysisMode + attribution_service_url: Optional[str] = None + timeout_seconds: int = 60 + slack: Optional[SlackConfig] = None + dataflow_index: Optional[str] = None + + @property + def use_lib(self) -> bool: + return self.mode == "lib" + + @property + def use_mcp(self) -> bool: + return self.mode == "mcp" + + @property + def use_url(self) -> bool: + return self.mode == "url" + + def to_dict(self) -> Dict[str, Any]: + d = { + "mode": self.mode, + "attribution_service_url": self.attribution_service_url, + "timeout_seconds": self.timeout_seconds, + } + if self.slack is not None: + d["slack"] = self.slack.to_dict() + if self.dataflow_index is not None: + d["dataflow_index"] = self.dataflow_index + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "LogAnalysisConfig": + if not d: + return cls(mode="lib", timeout_seconds=60) + return cls( + mode=d.get("mode", "lib"), + attribution_service_url=d.get("attribution_service_url"), + timeout_seconds=int(d.get("timeout_seconds", 60)), + slack=SlackConfig.from_dict(d.get("slack")), + dataflow_index=d.get("dataflow_index"), + ) + + @classmethod + def from_ft_cli_value( + cls, + val: str, + timeout_seconds: int = 60, + slack: Optional[SlackConfig] = None, + dataflow_index: Optional[str] = None, + ) -> "LogAnalysisConfig": + """Build from CLI string (lib, mcp, or URL). 
user/job_id read by LogAnalysisClient from env.""" + v = val.strip().lower() + if v == "lib": + return cls( + mode="lib", + timeout_seconds=timeout_seconds, + slack=slack, + dataflow_index=dataflow_index, + ) + if v == "mcp": + return cls( + mode="mcp", + timeout_seconds=timeout_seconds, + slack=slack, + dataflow_index=dataflow_index, + ) + url = _validate_attribution_url(v) + return cls( + mode="url", + attribution_service_url=url, + timeout_seconds=timeout_seconds, + slack=slack, + dataflow_index=dataflow_index, + ) + + +# --- HTTP client (URL mode) --- +class AttributionServiceClient: + """ + HTTP client for the attribution service (URL mode). + Talks to nvrx_attrsvc AttributionService over HTTP. + """ + + def __init__(self, base_url: str, timeout_seconds: float = 60.0): + self._base_url = base_url.rstrip("/") + self._timeout = max(1.0, float(timeout_seconds)) + + def path_notify(self, log_path: str) -> None: + """Notify path before workers start (fire-and-forget POST).""" + threading.Thread( + target=self._do_submit_log, + args=(log_path,), + daemon=True, + ).start() + + def _do_submit_log(self, log_path: str) -> None: + try: + with httpx.Client(timeout=10.0) as client: + url = f"{self._base_url}/logs" + logger.debug("AttributionServiceClient POST: %s (log_path=%s)", url, log_path) + client.post( + url, + json={"log_path": log_path}, + headers={"accept": "application/json"}, + ) + except Exception as e: + logger.warning( + "AttributionServiceClient POST %s failed: %s: %s", log_path, type(e).__name__, e + ) + + def get_result_sync(self, log_path: str) -> Optional[Dict[str, Any]]: + """Get analysis results via GET (blocking). 
Uses client timeout.""" + if not log_path: + return None + try: + with httpx.Client(timeout=self._timeout) as client: + q_path = quote_plus(log_path) + url = f"{self._base_url}/logs?log_path={q_path}" + logger.debug("AttributionServiceClient GET: %s (log_path=%s)", url, log_path) + resp = client.get(url, headers={"accept": "application/json"}) + if resp.status_code == 200: + payload = resp.json() if resp.text else {} + result = payload.get("result", payload) + if isinstance(result, dict): + return result + return {"result": result} if result is not None else None + logger.warning( + "AttributionServiceClient GET for %s returned %d", log_path, resp.status_code + ) + return None + except Exception as e: + logger.warning( + "AttributionServiceClient GET %s failed: %s: %s", log_path, type(e).__name__, e + ) + return None + + +# --- Client (selects lib / mcp / url backend) --- +class LogAnalysisClient: + """Client for log analysis attribution. Chooses backend from config.""" + + def __init__(self, config: LogAnalysisConfig) -> None: + self._config = config + self._timeout = max(1, config.timeout_seconds) + self._user = job_user_from_env() + self._job_id = job_id_from_env() + self._fetch_result: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None + self._path_notify: Optional[Callable[[str], None]] = None + self._init_backend() + + def _init_backend(self) -> None: + if self._config.use_lib or self._config.use_mcp: + from nvidia_resiliency_ext.attribution.log_analyzer.runner import ( + ensure_analyzer_ready, + run_log_analysis_sync, + ) + from nvidia_resiliency_ext.attribution.postprocessing import ( + configure_postprocessing_resolved, + ) + + # Postprocessing: centralized config (Slack from token file or env; dataflow; cluster from SLURM) + slack_cfg = self._config.slack + slack_token = slack_cfg.bot_token if slack_cfg else None + slack_channel = slack_cfg.channel if slack_cfg else None + dataflow_index = (self._config.dataflow_index or "").strip() + + 
configure_postprocessing_resolved( + cluster_name="", + dataflow_index=dataflow_index, + slack_token=slack_token, + slack_channel=slack_channel, + cluster_name_env="SLURM_CLUSTER_NAME", + create_dataflow_poster_if_needed=True, + ) + + if not ensure_analyzer_ready( + timeout_seconds=self._timeout, use_lib_log_analysis=self._config.use_lib + ): + self._fetch_result = None + return + + user = self._user + job_id = self._job_id + + def fetch(log_path: str) -> Optional[Dict[str, Any]]: + return run_log_analysis_sync(log_path, user=user, job_id=job_id) + + self._fetch_result = fetch + elif self._config.use_url and self._config.attribution_service_url: + attr_svc = AttributionServiceClient( + base_url=self._config.attribution_service_url, + timeout_seconds=self._timeout, + ) + + def fetch(log_path: str) -> Optional[Dict[str, Any]]: + return attr_svc.get_result_sync(log_path) + + self._fetch_result = fetch + self._path_notify = attr_svc.path_notify + + def fetch_result(self, log_path: str) -> Optional[Dict[str, Any]]: + """Run log analysis and return result. None on skip/timeout/error. + Timeout from config (set at init).""" + if self._fetch_result is None: + return None + return self._fetch_result(log_path) + + def should_stop(self, log_path: str) -> bool: + """Return True if attribution recommends stop (no restart), False to restart. 
+ Wraps fetch_result + attribution_no_restart.""" + attr_result = self.fetch_result(log_path) + return attribution_no_restart(attr_result) + + @property + def path_notify(self) -> Optional[Callable[[str], None]]: + """Notify path before workers start (URL mode only; None for lib/mcp).""" + return self._path_notify diff --git a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py index fa4425dc..ac011758 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py @@ -47,11 +47,14 @@ RendezvousInfo = None RendezvousStoreInfo = None +from nvidia_resiliency_ext.fault_tolerance.ft_attribution import ( + LogAnalysisClient, + LogAnalysisConfig, +) from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig from ..inprocess.utils import format_rank_set_verbose from ..shared_utils.health_check import ( - AttributionService, DistributedStorageHealthCheck, GPUHealthCheck, NicLinkStateHealthCheck, @@ -1491,8 +1494,7 @@ def from_backend( enable_dist_storage_healthcheck: bool = False, link_state_path_template: Optional[str] = None, storage_healthcheck_paths: Optional[list] = None, - attrsvc_host: Optional[str] = None, - attrsvc_port: Optional[int] = None, + log_analysis_config: Optional[LogAnalysisConfig] = None, ): """Create a new :py:class:`FtRendezvousBarrierHandler`. @@ -1523,10 +1525,8 @@ def from_backend( Template path for NIC link state files. storage_healthcheck_paths: List of storage paths to check for health. - attrsvc_host: - Hostname or IP address of the attribution service. - attrsvc_port: - Port number of the attribution service. + log_analysis_config: + Consolidated config for log analysis attribution (mode, attribution_service_url, timeout). """ # We associate each handler instance with a unique node descriptor. 
node = cls._node_desc_generator.generate(local_addr) @@ -1552,8 +1552,7 @@ def from_backend( enable_dist_storage_healthcheck=enable_dist_storage_healthcheck, link_state_path_template=link_state_path_template, storage_healthcheck_paths=storage_healthcheck_paths, - attrsvc_host=attrsvc_host, - attrsvc_port=attrsvc_port, + log_analysis_config=log_analysis_config, ) def __init__( @@ -1567,8 +1566,7 @@ def __init__( enable_dist_storage_healthcheck: bool = False, link_state_path_template: Optional[str] = None, storage_healthcheck_paths: Optional[list] = None, - attrsvc_host: Optional[str] = None, - attrsvc_port: Optional[int] = None, + log_analysis_config: Optional[LogAnalysisConfig] = None, ) -> None: if not settings.run_id: raise ValueError("The run id must be a non-empty string.") @@ -1629,14 +1627,15 @@ def __init__( StoragePathHealthCheck(storage_healthcheck_paths) if storage_healthcheck_paths else None ) - # Attribution service client (optional, only on master node) - if is_store_host and attrsvc_host and attrsvc_port is not None: - self._attr_service = AttributionService( - host=attrsvc_host, - port=int(attrsvc_port), - ) - else: - self._attr_service = None + # Attribution: log analysis client (optional, only when config enabled) + self._log_analysis_client = None + if is_store_host and log_analysis_config is not None: + self._log_analysis_client = LogAnalysisClient(log_analysis_config) + + @property + def log_analysis_client(self) -> Optional[LogAnalysisClient]: + """Log analysis client for attribution, or None if not configured.""" + return self._log_analysis_client @property def _rendezvous_round(self) -> int: @@ -1787,11 +1786,6 @@ def ensure_node_is_healthy(self) -> None: f"Node {self._this_node} has invalid or unreadable paths.", ) - # Perform optional log analysis (non-fatal) - # Note: _submit_log() was already called from launcher before workers started - if self._attr_service is not None: - self._attr_service() - # Perform Node health check (external 
service if available) _nodehealth_checker = get_node_health_check() if _nodehealth_checker is not None: @@ -2145,8 +2139,10 @@ def create_handler( ) storage_healthcheck_paths = params.config.get('storage_healthcheck_paths', None) link_state_path_template = params.config.get('link_state_path_template', None) - attrsvc_host = params.config.get('attrsvc_host', None) - attrsvc_port = params.config.get('attrsvc_port', None) + log_analysis_cfg_dict = params.config.get('log_analysis_config', None) + log_analysis_config = None + if log_analysis_cfg_dict: + log_analysis_config = LogAnalysisConfig.from_dict(log_analysis_cfg_dict) return FtRendezvousBarrierHandler.from_backend( params.run_id, @@ -2163,8 +2159,7 @@ def create_handler( enable_dist_storage_healthcheck=enable_dist_storage_healthcheck, link_state_path_template=link_state_path_template, storage_healthcheck_paths=storage_healthcheck_paths, - attrsvc_host=attrsvc_host, - attrsvc_port=attrsvc_port, + log_analysis_config=log_analysis_config, ) except Exception as e: construct_and_record_rdzv_event( diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index 21c61f63..b88a50ec 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -74,6 +74,7 @@ FT_RANK_MONITOR_IPC_SOCKET_ENV_VAR, UpdateConfigMsg, ) +from nvidia_resiliency_ext.fault_tolerance.ft_attribution import LogAnalysisConfig from nvidia_resiliency_ext.fault_tolerance.per_cycle_logs import PipeBasedLogsSpecs from nvidia_resiliency_ext.fault_tolerance.progress_tracker import TrainingProgressTracker from nvidia_resiliency_ext.fault_tolerance.rank_monitor_server import RankMonitorServer @@ -81,6 +82,7 @@ get_processes_by_pgids, hostnames_to_slurm_nodelist, is_slurm_job_array, + job_id_from_env, patched_method, terminate_mp_processes, write_obj_to_ipc_stream, @@ -491,7 +493,7 @@ def _on_cycle_end(self) -> None: if 
self._cycle_info_writer is None: return current_cycle = self._get_global_restart_count() - job_id = os.environ.get("SLURM_ARRAY_JOB_ID") or os.environ.get("SLURM_JOB_ID", "") + job_id = job_id_from_env() attempt_index = int(os.environ.get("SLURM_RESTART_CNT", "0")) self._cycle_info_writer.update_cycle_end( job_id=job_id, @@ -504,7 +506,7 @@ def _write_cycle_start_info(self, current_cycle: int) -> Optional[str]: """Write NVRx cycle info at cycle start. Returns path to current cycle info file, or None.""" if self._cycle_info_writer is None: return None - job_id = os.environ.get("SLURM_ARRAY_JOB_ID") or os.environ.get("SLURM_JOB_ID", "") + job_id = job_id_from_env() attempt_index = int(os.environ.get("SLURM_RESTART_CNT", "0")) cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) # Legacy FtRendezvousHandler does not define these; barrier handler does. @@ -575,6 +577,11 @@ def _open_rendezvous_for_restart(self): logger.error(f"Failed to open rendezvous: {e}") # For legacy rendezvous, no action needed - it uses different mechanism + def _restart_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: + """Override to pass will_restart and time_consumed_before_reclaim to _stop_workers.""" + self._stop_workers(worker_group, *args, will_restart=True, **kwargs) + self._start_workers(worker_group) + def _handle_restart_decision( self, role: str, @@ -583,7 +590,10 @@ def _handle_restart_decision( open_rendezvous: bool = False, notify_peer: bool = False, ) -> bool: - """Handle restart decision logic based on progress tracking and remaining restarts. + """Decide whether to restart based on attribution, progress tracking, and remaining restarts. + + If restart: calls _restart_workers and returns True. + If stop: returns False; caller must call _stop_workers. Args: role: The role name for logging @@ -593,9 +603,27 @@ def _handle_restart_decision( notify_peer: Whether to notify peers to abort the workers in current cycle. 
Returns: - True if restart was initiated (caller should continue monitoring loop) - False if no restart (caller should stop workers and return failure) + True if restart was initiated, False if no restart (caller should call _stop_workers). """ + # Notify peers immediately so they can proceed with their own rendezvous + if notify_peer and hasattr(self._rdzv_handler, '_barrier_state'): + self._rdzv_handler._barrier_state._increment_peer_aborted_count() + if open_rendezvous: + self._open_rendezvous_for_restart() + + start = time.time() + should_terminate_early = self._run_attribution() + if should_terminate_early: + if self._ft_cfg.attribution_dry_run: + logger.info( + "[%s] Attribution dry run: would NOT restart (attribution says stop), " + "but proceeding as configured (action not applied).", + role, + ) + else: + logger.error("[%s] Attribution says do not restart; will not restart.", role) + return False + self._progress_tracker.analyze_previous_cycle() should_terminate_early = self._progress_tracker.should_terminate_early() @@ -609,12 +637,11 @@ def _handle_restart_decision( elif self._remaining_restarts > 0: logger.info(log_msg, role) self._remaining_restarts -= 1 - # Increment peer_aborted_count to notify other nodes (for barrier-based rendezvous) - if notify_peer and hasattr(self._rdzv_handler, '_barrier_state'): - self._rdzv_handler._barrier_state._increment_peer_aborted_count() - if open_rendezvous: - self._open_rendezvous_for_restart() - self._restart_workers(self._worker_group) + time_consumed = time.time() - start + self._restart_workers( + self._worker_group, + time_consumed_before_reclaim=time_consumed, + ) return True else: # No more restarts available @@ -678,13 +705,10 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes ) should_restart = self._handle_restart_decision( role, spec, log_msg, open_rendezvous=True, - notify_peer=True + notify_peer=True, ) - if should_restart: - continue # Continue monitoring after restart 
- - # No more restarts (either exhausted or early termination) + continue self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED return RunResult(state=WorkerState.FAILED) @@ -711,17 +735,16 @@ def _invoke_run_with_any_failed_policy(self, role: str = DEFAULT_ROLE) -> RunRes f"(nodes_waiting={num_nodes_waiting}, peer_aborted={peer_aborted_count}); " f"will restart worker group" ) - # Note: The node that triggered the change (unhealthy or new) already opened - # the rendezvous, so we don't need to open it again here. + # Note: The node that triggered the change already opened the rendezvous. should_restart = self._handle_restart_decision( role, spec, log_msg, open_rendezvous=False, - notify_peer=False + notify_peer=False, ) - - if not should_restart: - self._stop_workers(self._worker_group) - self._worker_group.state = WorkerState.FAILED - return RunResult(state=WorkerState.FAILED) + if should_restart: + continue + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED + return RunResult(state=WorkerState.FAILED) else: raise Exception(f"[{role}] Worker group in {state.name} state") @@ -1043,15 +1066,41 @@ def _log_watchdog_event( event = events.Event(name=name, source=events.EventSource.AGENT, metadata=metadata) events.record(event) + @property + def _log_analysis_client(self): + """Log analysis client from rdzv handler, or None if not configured.""" + return getattr(self._rdzv_handler, "log_analysis_client", None) + + def _run_attribution(self) -> bool: + """Run attribution if configured. 
Returns True if attribution says do not restart, else False.""" + if not self._is_store_host or self._log_analysis_client is None: + return False + cycle_log_file = None + if hasattr(self._logs_specs, "get_cycle_log_file"): + current_cycle = self._get_global_restart_count() + cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) + if cycle_log_file is None: + return False + return self._log_analysis_client.should_stop(cycle_log_file) + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof - def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: + def _stop_workers( + self, worker_group: WorkerGroup, *args, **kwargs + ) -> Optional[Any]: # Support both old and new SimpleElasticAgent._stop_workers signatures: # - Before 2.5.1: _stop_workers(self, worker_group: WorkerGroup) -> None # - 2.5.1: _stop_workers(self, worker_group: WorkerGroup, is_restarter: bool = False) -> None # - 2.7.1+: _stop_workers(self, worker_group: WorkerGroup) -> None (reverted back) # We use *args and **kwargs to handle both cases transparently + # + # Optional: will_restart [bool] - if True, wait for GPU reclaim before next cycle. + # Optional: time_consumed_before_reclaim [float] - deducted from reclaim budget when will_restart. + will_restart: bool = kwargs.pop("will_restart", False) + time_consumed_before_reclaim: float = kwargs.pop( + "time_consumed_before_reclaim", 0.0 + ) logger.info(f"Stopping workers... 
Timeout = {self._workers_stop_timeout} sec.") # Rank monitors will detect worker shutdown when worker processes disconnect @@ -1076,16 +1125,22 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: else: logger.debug("All worker processes and descendants terminated successfully") - # Wait for GPU memory to be reclaimed BEFORE returning control - # This ensures the node doesn't proceed to the next rendezvous cycle while memory is still tied up - if self._ft_cfg.gpu_memory_reclaim_timeout > 0: - logger.debug( - "Waiting for GPU memory to be reclaimed (timeout: %ds, tolerance: %d MB, poll interval: %ds)...", - int(self._ft_cfg.gpu_memory_reclaim_timeout), - int(self._ft_cfg.gpu_memory_tolerance_mb), - int(self._ft_cfg.gpu_memory_poll_interval), - ) - self._wait_for_gpu_memory_reclaim(worker_group.spec.local_world_size) + # Wait for GPU memory to be reclaimed only when restarting (shutdown case skips). + reclaim_timeout = self._ft_cfg.gpu_memory_reclaim_timeout + if will_restart and reclaim_timeout > 0: + remaining_reclaim = max(0.0, reclaim_timeout - time_consumed_before_reclaim) + if remaining_reclaim > 0: + logger.debug( + "Waiting for GPU memory to be reclaimed (timeout: %.1fs, " + "tolerance: %d MB, poll interval: %ds)...", + remaining_reclaim, + int(self._ft_cfg.gpu_memory_tolerance_mb), + int(self._ft_cfg.gpu_memory_poll_interval), + ) + self._wait_for_gpu_memory_reclaim( + worker_group.spec.local_world_size, + timeout_override=remaining_reclaim, + ) # Wait for reader thread to drain pipes (polls every 100ms, wait 3 cycles) # then close pipe file objects to prevent FD reuse bugs @@ -1116,6 +1171,7 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: node_id=self._rdzv_handler._this_node, rank=worker_group.group_rank, ) + return None # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. 
@@ -1163,14 +1219,12 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: f"MASTER_ADDR={master_addr}, MASTER_PORT={master_port}" ) - # Submit current cycle's log to attribution service (master node only, before workers start) - if ( - self._is_store_host - and self._rdzv_handler._attr_service is not None - and hasattr(self._logs_specs, 'get_cycle_log_file') - ): - cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) - self._rdzv_handler._attr_service._submit_log(cycle_log_file) + # Submit current cycle's log to attribution service (master node only, URL mode, before workers start) + if self._is_store_host and hasattr(self._logs_specs, "get_cycle_log_file"): + client = self._log_analysis_client + if client is not None and client.path_notify is not None: + cycle_log_file = self._logs_specs.get_cycle_log_file(current_cycle) + client.path_notify(cycle_log_file) # Write NVRx cycle info and set env for workload current_cycle_info_path = self._write_cycle_start_info(current_cycle) @@ -1280,7 +1334,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: return self._pcontext.pids() - def _wait_for_gpu_memory_reclaim(self, num_gpus: int) -> None: + def _wait_for_gpu_memory_reclaim( + self, num_gpus: int, timeout_override: Optional[float] = None + ) -> None: """ Wait for GPU memory to be reclaimed below the tolerance threshold before starting new workers. This is called on restarts (not on initial start) to ensure memory has been cleaned up. @@ -1289,6 +1345,7 @@ def _wait_for_gpu_memory_reclaim(self, num_gpus: int) -> None: Args: num_gpus: Number of GPUs on this node + timeout_override: If set, use this instead of gpu_memory_reclaim_timeout (for time accounting). 
""" def log_memory_stats(memory_stats, num_gpus, log_func, message_template, *args): """Helper to log GPU memory statistics.""" @@ -1312,7 +1369,13 @@ def log_memory_stats(memory_stats, num_gpus, log_func, message_template, *args): ) memory_logger = GPUMemoryLogger() - timeout = self._ft_cfg.gpu_memory_reclaim_timeout + timeout = ( + timeout_override + if timeout_override is not None + else self._ft_cfg.gpu_memory_reclaim_timeout + ) + if timeout <= 0: + return tolerance_mb = self._ft_cfg.gpu_memory_tolerance_mb poll_interval = self._ft_cfg.gpu_memory_poll_interval @@ -1857,9 +1920,7 @@ def launch_agent( # unhealthy_count) before the store goes away. shutdown_rdzv only controls explicit # permanent-close signaling; it cannot keep the store alive after process exit. if is_store_host: - # Trigger attribution service analysis for final cycle - if agent._rdzv_handler._attr_service is not None: - agent._rdzv_handler._attr_service() + # Attribution is invoked on the Restart & progress state path (inside _handle_restart_decision), not at exit. # No ordering required between cycle_info_writer and rendezvous: the writer # is independent I/O. Run grace-period wait and writer shutdown in parallel @@ -2892,22 +2953,60 @@ def get_args_parser() -> ArgumentParser: "format and log the traceback, and use os._exit() to exit the process reliably. Default: False.", ) - # Attribution service configuration (optional) + # Attribution: --ft-attribution-loganalysis [lib|mcp|url]; default lib parser.add_argument( - "--ft-attrsvc-host", - "--ft_attrsvc_host", - type=str, + "--ft-attribution-loganalysis", + "--ft_attribution_loganalysis", + nargs="?", + const="lib", default=None, - dest="ft_attrsvc_host", - help="Hostname or IP for the attribution service (e.g., 127.0.0.1).", + dest="ft_attribution_loganalysis", + metavar="MODE", + help="Enable log analysis attribution on the Restart & progress state path. 
" + "lib (default)= in-process; mcp= LogSage in MCP subprocess; " + "url= HTTP service (e.g. http://127.0.0.1:8000).", ) parser.add_argument( - "--ft-attrsvc-port", - "--ft_attrsvc_port", + "--ft-attribution-timeout", + "--ft_attribution_timeout", type=int, + default=60, + dest="ft_attribution_timeout_seconds", + help="Attribution wait/timeout in seconds; skip result if exceeded (default: 60).", + ) + parser.add_argument( + "--ft-attribution-dry-run", + "--ft_attribution_dry_run", + action="store_true", + default=None, + dest="ft_attribution_dry_run", + help="Attribution dry run: run full attribution chain (log analysis, Slack, dataflow) " + "but do not apply the restart/stop decision. Log what would happen instead. " + "Useful for validating the chain without affecting behavior.", + ) + parser.add_argument( + "--ft-slack-channel", + "--ft_slack_channel", + type=str, default=None, - dest="ft_attrsvc_port", - help="Port for the attribution service (e.g., 8000).", + dest="ft_slack_channel", + help="Slack channel for FT alerts (attribution, etc.).", + ) + parser.add_argument( + "--ft-slack-token-file", + "--ft_slack_token_file", + type=str, + default=None, + dest="ft_slack_bot_token_file", + help="Path to file containing Slack bot token. Else uses SLACK_BOT_TOKEN/SLACK_BOT_TOKEN_FILE env.", + ) + parser.add_argument( + "--ft-dataflow-index", + "--ft_dataflow_index", + type=str, + default=None, + dest="ft_dataflow_index", + help="Dataflow/Elasticsearch index for attribution posting (lib/mcp). Requires nvdataflow.", ) parser.add_argument( @@ -3091,6 +3190,7 @@ def _validate_args(args: Any) -> None: "Cycle info needs per-cycle log file path from the applog prefix." 
) + def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` _validate_args(args) @@ -3164,11 +3264,22 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - # Pass enable_nic_healthcheck and link_state_path_template from fault tolerance config to rendezvous config rdzv_configs['enable_nic_healthcheck'] = fault_tol_cfg.enable_nic_healthcheck rdzv_configs['link_state_path_template'] = fault_tol_cfg.link_state_path_template - # Pass attribution service configuration if provided - if getattr(fault_tol_cfg, 'attrsvc_host', None): - rdzv_configs['attrsvc_host'] = fault_tol_cfg.attrsvc_host - if getattr(fault_tol_cfg, 'attrsvc_port', None) is not None: - rdzv_configs['attrsvc_port'] = int(fault_tol_cfg.attrsvc_port) + + # Attribution: --ft-attribution-loganalysis [lib|mcp|url]; default lib + attribution_loganalysis = getattr(fault_tol_cfg, "attribution_loganalysis", None) + attribution_timeout = int(getattr(fault_tol_cfg, "attribution_timeout_seconds", 60)) + ft_slack = getattr(fault_tol_cfg, "slack", None) + ft_dataflow_index = getattr(fault_tol_cfg, "dataflow_index", None) + if attribution_loganalysis: + timeout_sec = max(1, attribution_timeout) + log_analysis_cfg = LogAnalysisConfig.from_ft_cli_value( + attribution_loganalysis, + timeout_seconds=timeout_sec, + slack=ft_slack, + dataflow_index=ft_dataflow_index, + ) + rdzv_configs["log_analysis_config"] = log_analysis_cfg.to_dict() + # Pass distributed storage health check configuration cli_dist_storage = getattr(args, 'ft_enable_dist_storage_healthcheck', None) if cli_dist_storage is not None: diff --git a/src/nvidia_resiliency_ext/fault_tolerance/utils.py b/src/nvidia_resiliency_ext/fault_tolerance/utils.py index 38de84d6..20af1e3e 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/utils.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/utils.py @@ 
-232,6 +232,16 @@ def is_slurm_job_array() -> bool: return os.getenv('SLURM_ARRAY_TASK_ID') is not None +def job_user_from_env() -> str: + """Read job user from SLURM_JOB_USER or USER env.""" + return os.environ.get("SLURM_JOB_USER") or os.environ.get("USER", "") or "" + + +def job_id_from_env() -> str: + """Read job id from SLURM_ARRAY_JOB_ID or SLURM_JOB_ID env.""" + return os.environ.get("SLURM_ARRAY_JOB_ID") or os.environ.get("SLURM_JOB_ID", "") or "" + + def is_process_alive(pid): try: process = psutil.Process(pid) diff --git a/src/nvidia_resiliency_ext/shared_utils/health_check.py b/src/nvidia_resiliency_ext/shared_utils/health_check.py index 886805f5..50d5a99e 100644 --- a/src/nvidia_resiliency_ext/shared_utils/health_check.py +++ b/src/nvidia_resiliency_ext/shared_utils/health_check.py @@ -25,12 +25,9 @@ import traceback from collections import defaultdict from functools import wraps -from typing import Any, Callable, Dict, Optional, Union -from urllib.parse import quote_plus +from typing import Callable, Dict, Optional, Union import defusedxml.ElementTree as ET -import httpx -from pydantic import BaseModel from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig @@ -1321,100 +1318,3 @@ def _perform_health_check(self) -> bool: if self.paths: logger.debug("all storage paths accessible:\n" + "\n".join(self.paths)) return True - - -class AttrSvcResult(BaseModel): - result: Any - status: str = "completed" - - -class AttributionService: - """ - Client that queries an external attribution service to analyze artifacts (e.g., logs). - Behavior: - - POSTs to submit for log analysis - - GETs results by the last submitted log_path - """ - - def __init__( - self, - host: str, - port: int, - ): - self.host = host - self.port = port - # Track the most recent log_path we submitted - self._last_submitted: Optional[str] = None - - def __call__(self) -> None: - """ - Fire-and-forget entrypoint. GET results for the previously submitted log. 
- Runs in a background daemon thread. - - Note: _submit_log() should be called first (from launcher) to set _last_submitted. - """ - log_path = self._last_submitted - if log_path: - threading.Thread( - target=self._get_results, - args=(log_path,), - daemon=True, - ).start() - - def _submit_log(self, log_path: str) -> None: - """ - Submit a log file for analysis via POST. - Runs in a background daemon thread (fire-and-forget). - """ - self._last_submitted = log_path - threading.Thread( - target=self._do_submit_log, - args=(log_path,), - daemon=True, - ).start() - - def _do_submit_log(self, log_path: str) -> None: - """Perform the actual POST request (runs in background thread).""" - try: - with httpx.Client(timeout=10.0) as client: - url = f"http://{self.host}:{self.port}/logs" - logger.debug("AttributionService POST: %s (log_path=%s)", url, log_path) - client.post( - url, - json={"log_path": log_path}, - headers={"accept": "application/json"}, - ) - except Exception as e: - logger.warning( - "AttributionService POST %s failed: %s: %s", log_path, type(e).__name__, e - ) - - def _get_results(self, log_path: str) -> None: - """ - Get analysis results for a previously submitted log file via GET. 
- """ - try: - with httpx.Client(timeout=60.0) as client: - q_path = quote_plus(log_path) - url = f"http://{self.host}:{self.port}/logs?log_path={q_path}" - logger.debug("AttributionService GET: %s (log_path=%s)", url, log_path) - resp = client.get(url, headers={"accept": "application/json"}) - if resp.status_code == 200: - payload = resp.json() if resp.text else {} - result = payload.get("result", payload) - status = payload.get("status", "completed") - attrsvc_result = AttrSvcResult(result=result, status=status) - logger.info("AttrSvcResult for %s: status=%s", log_path, attrsvc_result.status) - logger.info( - "AttrSvcResult for %s: result preview: %s", - log_path, - str(attrsvc_result.result)[:200], - ) - else: - logger.warning( - "AttributionService GET for %s returned %d", log_path, resp.status_code - ) - except Exception as e: - logger.warning( - "AttributionService GET %s failed: %s: %s", log_path, type(e).__name__, e - )