From 35e5365822ec81b9302a4a2c0a134b4db46650ed Mon Sep 17 00:00:00 2001 From: namitdhameja Date: Sun, 24 May 2026 13:56:51 -0700 Subject: [PATCH] Add FACT dmesg attribution integration --- .../fault_tolerance/fact_node_attribution.rst | 171 +++ docs/source/fault_tolerance/index.rst | 3 +- docs/source/fault_tolerance/usage_guide.rst | 60 +- pyproject.toml | 1 + .../attribution/fact/__init__.py | 4 + .../attribution/fact/agent.py | 1015 +++++++++++++++++ .../attribution/fact/client.py | 439 +++++++ .../fact/fact_integration_design.md | 641 +++++++++++ .../attribution/fact/history_client.py | 200 ++++ .../attribution/fact/hot_cache.py | 69 ++ .../attribution/fact/manager.py | 311 +++++ .../attribution/fact/models.py | 38 + .../fact/repeat_offender_policy.py | 82 ++ .../attribution/fact/rpc.py | 66 ++ .../attribution/policy/__init__.py | 4 + .../fault_tolerance/__init__.py | 2 + .../fault_tolerance/cli_args.py | 131 ++- .../fault_tolerance/config.py | 125 +- .../fault_tolerance/ft_rendezvous_barrier.py | 65 +- .../fault_tolerance/launcher.py | 378 +++++- .../fault_tolerance/per_cycle_logs.py | 14 + .../shared_utils/health_check.py | 12 +- .../shared_utils/log_paths.py | 19 + tests/attribution/unit/test_fact_agent.py | 962 ++++++++++++++++ tests/attribution/unit/test_fact_client.py | 200 ++++ tests/attribution/unit/test_fact_manager.py | 199 ++++ tests/fault_tolerance/unit/test_config.py | 163 +++ tests/fault_tolerance/unit/test_launcher.py | 420 +++++++ .../unit/test_per_cycle_logs.py | 6 + tests/shared_utils/test_health_check.py | 23 +- 30 files changed, 5776 insertions(+), 47 deletions(-) create mode 100644 docs/source/fault_tolerance/fact_node_attribution.rst create mode 100644 src/nvidia_resiliency_ext/attribution/fact/__init__.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/agent.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/client.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/fact_integration_design.md create mode 100644 src/nvidia_resiliency_ext/attribution/fact/history_client.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/hot_cache.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/manager.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/models.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/repeat_offender_policy.py create mode 100644 src/nvidia_resiliency_ext/attribution/fact/rpc.py create mode 100644 src/nvidia_resiliency_ext/attribution/policy/__init__.py create mode 100644 src/nvidia_resiliency_ext/shared_utils/log_paths.py create mode 100644 tests/attribution/unit/test_fact_agent.py create mode 100644 tests/attribution/unit/test_fact_client.py create mode 100644 tests/attribution/unit/test_fact_manager.py diff --git a/docs/source/fault_tolerance/fact_node_attribution.rst b/docs/source/fault_tolerance/fact_node_attribution.rst new file mode 100644 index 00000000..3e8970fe --- /dev/null +++ b/docs/source/fault_tolerance/fact_node_attribution.rst @@ -0,0 +1,171 @@ +FACT dmesg evidence collection +============================== + +Use this feature to let FACT inspect recent host ``dmesg`` after a failed +fault-tolerance cycle. FACT filters the logs for node-level symptoms such as +XIDs, NVIDIA driver messages, and kernel faults, then returns suspect or faulty +nodes. + +``ft_launcher`` starts one local ``nvrx-fact-agent`` per node when +``--ft-fact-url`` is set. On a failed cycle, the launcher sends a local UDS +notification to the agent and continues after ACK. FACT output is node-level +evidence; it does not by itself stop the job or change placement. + +``--ft-attribution-endpoint`` is the separate application-log path for +job-level restart recommendations such as ``STOP`` or ``RESTART``. It is not +used for FACT dmesg submission. + +Configuration +------------- + +The current-cycle FACT path is enabled by ``--ft-fact-url``. Artifact output is +optional: + +* ``health_logging.prefix`` or ``--ft-health-log-prefix`` sets the absolute + output prefix. +* ``health_logging.dmesg.enabled`` or ``--ft-enable-health-log-dmesg`` queues + the collected dmesg window to a shared per-cycle dmesg file. +* ``health_logging.fact_result.enabled`` or + ``--ft-enable-fact-result-artifact`` queues per-node FACT submission records + and the store-host FACT result record to a shared JSONL file. + +When either artifact is enabled, the launcher gRPC log funnel must also be +enabled with ``--ft-per-cycle-applog-prefix`` and +``--ft-enable-log-server true``. The root log server is the only writer for the +shared FACT artifacts. + +YAML configuration: + +.. code-block:: yaml + + fault_tolerance: + fact_url: http://fact.example.internal:8001/latest + health_logging: + prefix: /lustre/logs/job_health.log + dmesg: + enabled: true + fact_result: + enabled: true + +Equivalent launcher flags: + +.. code-block:: bash + + ft_launcher \ + --ft-per-cycle-applog-prefix /lustre/logs/train.log \ + --ft-enable-log-server true \ + --ft-fact-url http://fact.example.internal:8001/latest \ + --ft-health-log-prefix /lustre/logs/job_health.log \ + --ft-enable-health-log-dmesg true \ + --ft-enable-fact-result-artifact true \ + ... + +``--ft-fact-url`` accepts either the FACT service root or the FACT API root +(``/latest``). + +Cycle Flow +---------- + +One failed-cycle notification maps to one evidence collection attempt: + +* After workers are stopped, the launcher sends ``cycle_failed`` to the local + ``nvrx-fact-agent`` over UDS with the failed cycle id and cycle start + timestamp. +* The agent ACKs immediately, collects a bounded recent dmesg window, queues + the optional dmesg artifact, and POSTs the collected text to FACT. +* The FACT workload ``job_start_time`` uses the actual cycle start timestamp. + The dmesg observation window remains the recent collection window. +* TCPStore carries only ``attributor_id`` and completion count. +* Completion count means a node reached any terminal local outcome: successful + FACT submission, empty dmesg, collection failure, or FACT POST failure. It + does not mean FACT accepted evidence from that node. +* The store-host agent waits for completion count or a deadline, performs the + FACT GET, and queues the optional result artifact. + +The default dmesg window is 12 minutes so NCCL timeout cases, where the +interesting kernel event may be roughly 10 minutes old, are still in scope. + +The UDS ACK only means the local agent accepted the request. Dmesg collection, +FACT POST/GET, TCPStore completion, and gRPC artifact drain are best-effort. +They may be missing or partial if the launcher, agent, store-host, or gRPC log +funnel exits before they finish. When FACT ingestion succeeds, FACT / +Elasticsearch is the durable observability source; local artifacts are +postmortem evidence. + +Artifacts +--------- + +For cycle ``N``, artifact paths are derived from the health-log prefix: + +.. code-block:: text + + /lustre/logs/job_health.log -> /lustre/logs/job_health_dmesg_cycleN.log + /lustre/logs/job_health.log -> /lustre/logs/job_health_fact_cycleN.log + +The dmesg artifact is one shared file per failed cycle. Production collection +prefixes each dmesg line with the source node name, so per-node inspection can +filter the shared file by node. + +The result artifact is JSONL. Leaf agents queue one record for their local FACT +submission result, including ``observation_id`` when FACT returns one. The +store-host agent queues a record containing the full ``FactAttributionResult`` +plus ``run_id``, ``cycle``, ``job_id``, expected/completed node counts, and the +history-policy ``avoid_nodes`` decision when available. + +All artifact records and dmesg chunks go through the launcher gRPC log funnel. +Record/chunk order is not guaranteed, but one queued chunk should not interleave +with another. FACT submission does not read these files; the service receives +the collected contents directly. + +FACT History and Node Reuse +--------------------------- + +The current-cycle FACT result answers: which nodes look suspect for this failed +cycle? FACT history answers: has the same node appeared as a suspect node in +recent FACT records? + +The controls above enable current-cycle dmesg collection, FACT submission, and +optional artifacts. They do not, by themselves, make a job-level ``STOP`` or +``RESTART`` decision. For node reuse, NVRx only evaluates nodes that FACT marks +suspect in the current cycle. For those nodes, it combines a short-lived +in-memory NVRx hot cache with optional durable FACT history. For example, if the +current failed cycle implicates ``node-a`` and an earlier cycle in the same +NVRx run also implicated ``node-a``, NVRx may avoid assigning active ranks to +``node-a`` on the next retry when enough other joined nodes are available. + +Durable FACT history is optional and extends the same decision with failures +that happened before this NVRx process started: + +.. code-block:: text + + --ft-fact-history-es-url + --ft-fact-history-es-auth-file + +The behavioral defaults should normally be left unchanged. The concrete FACT +history index or backend collection is deployment-specific and should be +provided by the FACT deployment. + +.. code-block:: text + + fact_history_lookback = 14d + fact_history_max_candidate_nodes = 16 + fact_history_query_timeout = 30s + fact_policy_ready_timeout = 60s + min_repeat_count_for_avoid = 2 + max_attribution_avoids_per_cycle = 1 + +This is not a hard exclusion and it is not a job-level ``STOP`` decision. The +policy must fail open: if FACT attribution, FACT history, or the local policy +answer is unavailable, late, or ambiguous, rendezvous proceeds without avoiding +nodes based on FACT repeat history. Concrete health-check failures and nodes +that cannot rejoin remain the immediate hard exclusion inputs. + +See Also +-------- + +* :doc:`usage_guide` for the rest of the launcher workflow. +* :ref:`fault-tolerance-attribution-service` for application-log restart + recommendations. +* :doc:`api/config` for the ``FaultToleranceConfig`` schema. +* ``src/nvidia_resiliency_ext/attribution/fact/fact_integration_design.md`` for + the internal FACT agent, history, and avoid-node design. diff --git a/docs/source/fault_tolerance/index.rst b/docs/source/fault_tolerance/index.rst index 93e520d7..62d909dd 100644 --- a/docs/source/fault_tolerance/index.rst +++ b/docs/source/fault_tolerance/index.rst @@ -17,6 +17,7 @@ The ``nvidia-resiliency-ext`` package also includes the PTL callback ``FaultTole :caption: Contents: usage_guide + fact_node_attribution integration api - examples \ No newline at end of file + examples diff --git a/docs/source/fault_tolerance/usage_guide.rst b/docs/source/fault_tolerance/usage_guide.rst index b2dc8585..8774dd61 100644 --- a/docs/source/fault_tolerance/usage_guide.rst +++ b/docs/source/fault_tolerance/usage_guide.rst @@ -199,8 +199,10 @@ Validation behavior: - Other existing types (e.g., devices/symlinks): performs ``stat`` access -Attribution service integration -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. _fault-tolerance-attribution-service: + +Application-log attribution service integration +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Per-cycle application logs do not enable attribution by themselves. To enable attribution, set ``--ft-attribution-endpoint``. The endpoint value ``localhost`` makes ``ft_launcher`` run the @@ -271,6 +273,58 @@ service dependencies. --ft-attribution-endpoint http://attribution.service.internal:8000 \ train.py +FACT dmesg evidence collection +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +FACT host-evidence collection is separate from the application-log service +above. ``--ft-fact-url`` enables a launcher-managed local ``nvrx-fact-agent`` on +each node. On failed cycles, the launcher notifies the agent over a private UDS; +the agent collects a bounded recent ``dmesg`` window, submits it to FACT, and +continues best-effort if FACT is unavailable. + +To also write postmortem artifacts, set an absolute health-log prefix and +enable the desired artifacts. The FACT result JSONL artifact requires launcher +gRPC log aggregation because the root log server is the single writer for the +shared result file. + +YAML configuration: + +.. code-block:: yaml + + fault_tolerance: + fact_url: http://fact.example.internal:8001/latest + health_logging: + prefix: /lustre/logs/job_health.log + dmesg: + enabled: true + fact_result: + enabled: true + +When either artifact is enabled, also provide the launcher log-funnel flags or +equivalent env args: ``--ft-per-cycle-applog-prefix`` and +``--ft-enable-log-server true``. + +Equivalent launcher flags: + +.. code-block:: bash + + ft_launcher \ + --ft-per-cycle-applog-prefix /lustre/logs/train.log \ + --ft-enable-log-server true \ + --ft-fact-url http://fact.example.internal:8001/latest \ + --ft-health-log-prefix /lustre/logs/job_health.log \ + --ft-enable-health-log-dmesg true \ + --ft-enable-fact-result-artifact true \ + train.py + +``--ft-fact-url`` accepts either the FACT service root or API root +(``/latest``). The artifact flags only control optional postmortem artifacts; +FACT consumes POSTed dmesg contents, not the written log path. + +See :doc:`fact_node_attribution` for output file naming, FACT result behavior, +and the distinction between FACT node findings and job-level restart +recommendations. + GPU Memory Reclaim ^^^^^^^^^^^^^^^^^^ @@ -294,6 +348,8 @@ the tolerance threshold or the timeout is reached. Memory statistics for each GP and logged after the reclaim process completes. If the timeout is reached, an error is logged but the restart proceeds as a best effort. +.. _fault-tolerance-per-cycle-logging: + Per-cycle logging and gRPC log aggregation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 0c5bc325..8215d8f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ nvrx-control = "nvidia_resiliency_ext.fault_tolerance.control_plane:main" nvrx-mcp-analysis = "nvidia_resiliency_ext.attribution.mcp_integration.server_launcher:main" nvrx-attrsvc = "nvidia_resiliency_ext.services.attrsvc.__main__:main" nvrx-smonsvc = "nvidia_resiliency_ext.services.smonsvc.__main__:main" +nvrx-fact-agent = "nvidia_resiliency_ext.attribution.fact.agent:main" [tool.poetry.extras] attribution = [ diff --git a/src/nvidia_resiliency_ext/attribution/fact/__init__.py b/src/nvidia_resiliency_ext/attribution/fact/__init__.py new file mode 100644 index 00000000..31ea4c13 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""FACT attribution client and NVRx agent integration.""" diff --git a/src/nvidia_resiliency_ext/attribution/fact/agent.py b/src/nvidia_resiliency_ext/attribution/fact/agent.py new file mode 100644 index 00000000..ca9b792b --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/agent.py @@ -0,0 +1,1015 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import contextlib +import json +import logging +import os +import queue +import random +import signal +import socket +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass, replace +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Callable, Optional + +from torch.distributed import TCPStore +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + +from nvidia_resiliency_ext.attribution.fact.client import ( + FactAttributionResult, + FactAttributionService, + collect_recent_dmesg_text, +) +from nvidia_resiliency_ext.attribution.fact.history_client import ( + FactHistoryClient, + parse_duration, +) +from nvidia_resiliency_ext.attribution.fact.hot_cache import FactHotCache +from nvidia_resiliency_ext.attribution.fact.models import AvoidDecision +from nvidia_resiliency_ext.attribution.fact.repeat_offender_policy import ( + compute_repeat_offender_decision, +) +from nvidia_resiliency_ext.attribution.fact.rpc import ( + DEFAULT_MAX_RPC_BYTES, + default_socket_path, + recv_frame, + send_frame, +) +from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig, setup_logger +from nvidia_resiliency_ext.shared_utils.log_paths import get_source_cycle_log_file + +logger = logging.getLogger(LogConfig.name) + +DEFAULT_DMESG_WINDOW_S = 12.0 * 60.0 +DEFAULT_OBSERVATION_DEADLINE_S = 30.0 +DEFAULT_STORE_TIMEOUT_S = 60.0 +DEFAULT_FACT_HISTORY_LOOKBACK = "14d" +DEFAULT_FACT_HISTORY_MAX_CANDIDATE_NODES = 16 +DEFAULT_FACT_HISTORY_QUERY_TIMEOUT_S = 30.0 +DEFAULT_FACT_MIN_REPEAT_COUNT_FOR_AVOID = 2 +DEFAULT_FACT_MAX_ATTRIBUTION_AVOIDS_PER_CYCLE = 1 +_ATTRIBUTOR_FAILURE_PREFIX = "__nvrx_fact_attributor_failed__:" +_POST_RETRY_INITIAL_DELAY_S = 0.25 +_POST_RETRY_MAX_DELAY_S = 2.0 +_POST_RETRY_MIN_REMAINING_S = 0.5 +_GRPC_RESULT_DRAIN_TIMEOUT_S = 4.0 + + +@dataclass(frozen=True) +class FactAgentRequest: + run_id: str + cycle: int + rdzv_endpoint: str + local_node: str + is_store_host: bool = False + store_timeout_s: float = DEFAULT_STORE_TIMEOUT_S + job_id: Optional[str] = None + expected_nodes: tuple[str, ...] = () + ranks_per_node: int = 1 + cycle_start_time: Optional[datetime] = None + cycle_end_time: Optional[datetime] = None + dmesg_path: Optional[str] = None + result_path: Optional[str] = None + grpc_server_address: Optional[str] = None + grpc_node_id: Optional[str] = None + + @classmethod + def from_payload( + cls, + payload: dict[str, Any], + *, + run_id: Optional[str] = None, + rdzv_endpoint: Optional[str] = None, + local_node: Optional[str] = None, + is_store_host: bool = False, + store_timeout_s: float = DEFAULT_STORE_TIMEOUT_S, + job_id: Optional[str] = None, + ranks_per_node: int = 1, + cycle_start_time: Optional[datetime] = None, + cycle_end_time: Optional[datetime] = None, + dmesg_path: Optional[str] = None, + result_path: Optional[str] = None, + grpc_server_address: Optional[str] = None, + grpc_node_id: Optional[str] = None, + ) -> "FactAgentRequest": + if payload.get("event") != "cycle_failed": + raise ValueError("unsupported FACT agent event") + resolved_run_id = str(run_id or "").strip() + resolved_rdzv_endpoint = str(rdzv_endpoint or "").strip() + resolved_local_node = str(local_node or socket.getfqdn(socket.gethostname())) + if not resolved_run_id: + raise ValueError("cycle_failed requires run_id") + if not resolved_rdzv_endpoint: + raise ValueError("cycle_failed requires rdzv_endpoint") + raw_cycle = payload.get("cycle", payload.get("cycle_id")) + if raw_cycle is None: + raise ValueError("cycle_failed requires cycle") + cycle = int(raw_cycle) + expected_nodes_raw = payload.get("expected_nodes") or [] + if not isinstance(expected_nodes_raw, list): + raise ValueError("expected_nodes must be a list when provided") + expected_nodes = tuple(str(node) for node in expected_nodes_raw if str(node)) + return cls( + run_id=resolved_run_id, + cycle=cycle, + rdzv_endpoint=resolved_rdzv_endpoint, + local_node=resolved_local_node, + is_store_host=bool(is_store_host), + store_timeout_s=float(store_timeout_s), + job_id=str(job_id or resolved_run_id), + expected_nodes=expected_nodes, + ranks_per_node=max(1, int(ranks_per_node)), + cycle_start_time=cls._parse_datetime( + payload.get("cycle_start_time"), + fallback=cycle_start_time, + ), + cycle_end_time=cls._parse_datetime( + payload.get("cycle_end_time"), + fallback=cycle_end_time, + ), + dmesg_path=dmesg_path, + result_path=result_path, + grpc_server_address=grpc_server_address, + grpc_node_id=grpc_node_id, + ) + + @staticmethod + def _parse_datetime(value: Any, *, fallback: Optional[datetime] = None) -> Optional[datetime]: + if value is None or value == "": + return fallback + if isinstance(value, datetime): + return value + if isinstance(value, str): + text = value.strip() + if text.endswith("Z"): + text = f"{text[:-1]}+00:00" + parsed = datetime.fromisoformat(text) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed + raise ValueError("cycle_start_time must be an ISO-8601 datetime when provided") + + +class FactAgentKeys: + def __init__(self, run_id: str, cycle: int) -> None: + self.prefix = f"fact_agent:{run_id}:cycle{cycle}" + + @property + def attributor_id(self) -> str: + return f"{self.prefix}:attributor_id" + + @property + def done_count(self) -> str: + return f"{self.prefix}:done_count" + + +StoreFactory = Callable[[FactAgentRequest], Any] +FactClientFactory = Callable[[], FactAttributionService] +FactHistoryClientFactory = Callable[[], FactHistoryClient] +DmesgCollector = Callable[[float, str], str] +GrpcWriterFactory = Callable[[queue.Queue, str, str, logging.Logger], threading.Thread] + + +class FactAgent: + def __init__( + self, + *, + fact_url: str, + socket_path: Optional[str] = None, + dmesg_window_s: float = DEFAULT_DMESG_WINDOW_S, + observation_deadline_s: float = DEFAULT_OBSERVATION_DEADLINE_S, + fact_timeout_s: float = 60.0, + run_id: Optional[str] = None, + rdzv_endpoint: Optional[str] = None, + store_timeout_s: float = DEFAULT_STORE_TIMEOUT_S, + local_node: Optional[str] = None, + is_store_host: bool = False, + job_id: Optional[str] = None, + ranks_per_node: int = 1, + username: Optional[str] = None, + cluster: Optional[str] = None, + health_log_prefix: Optional[str] = None, + dmesg_artifact_enabled: bool = False, + result_artifact_enabled: bool = False, + grpc_server_address: Optional[str] = None, + grpc_node_id: Optional[str] = None, + fact_history_es_url: Optional[str] = None, + fact_history_es_auth_file: Optional[str] = None, + fact_history_lookback: str = DEFAULT_FACT_HISTORY_LOOKBACK, + fact_history_index: Optional[str] = None, + fact_history_max_candidate_nodes: int = DEFAULT_FACT_HISTORY_MAX_CANDIDATE_NODES, + fact_history_query_timeout_s: float = DEFAULT_FACT_HISTORY_QUERY_TIMEOUT_S, + fact_min_repeat_count_for_avoid: int = DEFAULT_FACT_MIN_REPEAT_COUNT_FOR_AVOID, + fact_max_attribution_avoids_per_cycle: int = ( + DEFAULT_FACT_MAX_ATTRIBUTION_AVOIDS_PER_CYCLE + ), + store_factory: Optional[StoreFactory] = None, + fact_client_factory: Optional[FactClientFactory] = None, + fact_history_client_factory: Optional[FactHistoryClientFactory] = None, + dmesg_collector: Optional[DmesgCollector] = None, + grpc_writer_factory: Optional[GrpcWriterFactory] = None, + ) -> None: + self.fact_url = fact_url + self.socket_path = socket_path or default_socket_path() + self.dmesg_window_s = dmesg_window_s + self.observation_deadline_s = observation_deadline_s + self.fact_timeout_s = fact_timeout_s + self.run_id = str(run_id).strip() if run_id else None + self.rdzv_endpoint = str(rdzv_endpoint).strip() if rdzv_endpoint else None + self.store_timeout_s = float(store_timeout_s) + self.local_node = local_node or socket.getfqdn(socket.gethostname()) + self.is_store_host = bool(is_store_host) + self.job_id = str(job_id).strip() if job_id else None + self.ranks_per_node = max(1, int(ranks_per_node)) + self.username = str(username).strip() if username else None + self.cluster = str(cluster).strip() if cluster else None + self.health_log_prefix = str(health_log_prefix).strip() if health_log_prefix else None + self.dmesg_artifact_enabled = bool(dmesg_artifact_enabled) + self.result_artifact_enabled = bool(result_artifact_enabled) + self.grpc_server_address = str(grpc_server_address).strip() if grpc_server_address else None + self.grpc_node_id = str(grpc_node_id).strip() if grpc_node_id else None + self.fact_history_es_url = str(fact_history_es_url).strip() if fact_history_es_url else None + self.fact_history_es_auth_file = ( + str(fact_history_es_auth_file).strip() if fact_history_es_auth_file else None + ) + self.fact_history_lookback = fact_history_lookback + self.fact_history_index = str(fact_history_index).strip() if fact_history_index else None + self.fact_history_max_candidate_nodes = max(1, int(fact_history_max_candidate_nodes)) + self.fact_history_query_timeout_s = max(0.1, float(fact_history_query_timeout_s)) + self.fact_min_repeat_count_for_avoid = max(1, int(fact_min_repeat_count_for_avoid)) + self.fact_max_attribution_avoids_per_cycle = max( + 0, + int(fact_max_attribution_avoids_per_cycle), + ) + self._store_factory = store_factory or self._connect_tcp_store + self._fact_client_factory = fact_client_factory or self._new_fact_client + self._fact_history_client_factory = ( + fact_history_client_factory or self._new_fact_history_client + ) + self._dmesg_collector = dmesg_collector or self._collect_dmesg + self._grpc_writer_factory = grpc_writer_factory or self._new_grpc_writer + self._grpc_writers: dict[tuple[str, str], tuple[queue.Queue, threading.Thread]] = {} + self._grpc_writers_lock = threading.Lock() + self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="nvrx-fact-agent") + self._stop_event = threading.Event() + self._hot_cache = FactHotCache() + self._avoid_decisions: dict[int, AvoidDecision] = {} + self._avoid_decisions_lock = threading.Lock() + + def _new_fact_client(self) -> FactAttributionService: + return FactAttributionService( + url=self.fact_url, + timeout_s=self.fact_timeout_s, + ) + + def _new_fact_history_client(self) -> FactHistoryClient: + if not self.fact_history_es_url or not self.fact_history_es_auth_file: + raise RuntimeError("FACT history is not configured") + return FactHistoryClient( + es_url=self.fact_history_es_url, + auth_file=self.fact_history_es_auth_file, + index=self.fact_history_index, + timeout_s=self.fact_history_query_timeout_s, + ) + + def _connect_tcp_store(self, request: FactAgentRequest) -> TCPStore: + host, port = parse_rendezvous_endpoint(request.rdzv_endpoint, default_port=-1) + if not host or port == -1: + raise ValueError(f"invalid rendezvous endpoint: {request.rdzv_endpoint!r}") + return TCPStore( + host, + port, + is_master=False, + timeout=timedelta(seconds=max(1.0, request.store_timeout_s)), + multi_tenant=True, + ) + + @staticmethod + def _collect_dmesg(window_s: float, local_node: str) -> str: + return collect_recent_dmesg_text(window_s=window_s, hostname=local_node) + + @staticmethod + def _new_grpc_writer( + write_queue: queue.Queue, + grpc_server_address: str, + node_id: str, + logger: logging.Logger, + ) -> threading.Thread: + from nvidia_resiliency_ext.fault_tolerance.per_cycle_logs import GrpcWriterThread + + return GrpcWriterThread( + write_queue=write_queue, + grpc_server_address=grpc_server_address, + node_id=node_id, + logger=logger, + ) + + def handle_payload(self, payload: dict[str, Any]) -> dict[str, Any]: + if payload.get("event") == "ping": + return {"accepted": True} + if payload.get("event") == "shutdown": + self.request_stop() + return {"accepted": True} + if payload.get("event") == "get_avoid_nodes": + return self._handle_get_avoid_nodes(payload) + request = self._request_from_payload(payload) + self._executor.submit(self.process_cycle_failed, request) + return {"accepted": True} + + def _handle_get_avoid_nodes(self, payload: dict[str, Any]) -> dict[str, Any]: + if not self.is_store_host: + return {"status": "skipped", "avoid_nodes": []} + raw_cycle = payload.get("cycle", payload.get("cycle_id")) + if raw_cycle is None: + return {"status": "skipped", "avoid_nodes": []} + cycle = int(raw_cycle) + with self._avoid_decisions_lock: + decision = self._avoid_decisions.get(cycle) + if decision is None: + return {"cycle_id": str(cycle), "status": "pending", "avoid_nodes": []} + return { + "cycle_id": str(cycle), + "status": decision.status, + "avoid_nodes": list(decision.avoid_nodes), + } + + def _request_from_payload(self, payload: dict[str, Any]) -> FactAgentRequest: + request = FactAgentRequest.from_payload( + payload, + run_id=self.run_id, + rdzv_endpoint=self.rdzv_endpoint, + local_node=self.local_node, + is_store_host=self.is_store_host, + store_timeout_s=self.store_timeout_s, + job_id=self.job_id, + ranks_per_node=self.ranks_per_node, + grpc_server_address=self.grpc_server_address, + grpc_node_id=self.grpc_node_id, + ) + dmesg_path = request.dmesg_path + result_path = request.result_path + if self.health_log_prefix: + if self.dmesg_artifact_enabled and request.grpc_server_address and not dmesg_path: + dmesg_path = get_source_cycle_log_file( + self.health_log_prefix, + "dmesg", + request.cycle, + ) + if self.result_artifact_enabled and request.grpc_server_address and not result_path: + result_path = get_source_cycle_log_file( + self.health_log_prefix, + "fact", + request.cycle, + ) + if dmesg_path != request.dmesg_path or result_path != request.result_path: + return replace(request, dmesg_path=dmesg_path, result_path=result_path) + return request + + def process_cycle_failed(self, request: FactAgentRequest) -> None: + try: + store = self._store_factory(request) + except Exception as exc: + logger.warning("FACT agent failed to connect to TCPStore: %s", exc) + return + keys = FactAgentKeys(request.run_id, request.cycle) + if request.is_store_host: + self._process_store_host(request, store, keys) + else: + self._submit_local_evidence(request, store, keys) + + def _process_store_host( + self, request: FactAgentRequest, store: Any, keys: FactAgentKeys + ) -> None: + expected_nodes = list(dict.fromkeys(request.expected_nodes or (request.local_node,))) + nranks = request.ranks_per_node * max(1, len(expected_nodes)) + end_time = datetime.now(timezone.utc) + workload_start_time = request.cycle_start_time or end_time - timedelta( + seconds=max(0.0, self.dmesg_window_s) + ) + try: + service = self._fact_client_factory() + attributor_id = service.create_failure_attributor( + job_id=request.job_id or request.run_id, + cycle_index=request.cycle, + nodes=expected_nodes, + ranks_per_node=request.ranks_per_node, + nranks=nranks, + start_time=workload_start_time, + end_time=end_time, + username=self.username, + cluster=self.cluster, + ) + store.set(keys.attributor_id, str(attributor_id).encode("utf-8")) + except Exception as exc: + logger.warning("FACT agent failed to create FACT attributor: %s", exc) + self._store_avoid_decision(AvoidDecision(cycle_id=request.cycle, status="skipped")) + self._publish_attributor_failure(store, keys, exc) + self._write_result_record( + request, + { + "record_type": "fact_result", + "status": "failed", + "phase": "create_attributor", + "run_id": request.run_id, + "cycle": request.cycle, + "job_id": request.job_id or request.run_id, + "error": str(exc), + }, + ) + return + + self._submit_local_evidence(request, store, keys) + completed_count = self._wait_for_completion_count(store, keys, len(expected_nodes)) + try: + result = service.get_attribution_result( + attributor_id=str(attributor_id), + observation_ids=[], + ) + except Exception as exc: + logger.warning("FACT agent attribution GET failed: %s", exc) + self._store_avoid_decision(AvoidDecision(cycle_id=request.cycle, status="skipped")) + self._write_result_record( + request, + { + "record_type": "fact_result", + "status": "failed", + "phase": "get_attribution", + "run_id": request.run_id, + "cycle": request.cycle, + "job_id": request.job_id or request.run_id, + "attributor_id": str(attributor_id), + "completed_node_count": completed_count, + "expected_node_count": len(expected_nodes), + "error": str(exc), + }, + ) + return + avoid_decision = self._compute_avoid_decision(request, result) + self._write_result_record( + request, + { + "record_type": "fact_result", + "status": "complete", + "run_id": request.run_id, + "cycle": request.cycle, + "job_id": request.job_id or request.run_id, + "expected_node_count": len(expected_nodes), + "completed_node_count": completed_count, + "avoid_nodes": list(avoid_decision.avoid_nodes), + **self._result_payload(result), + }, + ) + logger.info( + "FACT attribution completed for run_id=%s cycle=%s " + "completed_nodes=%s expected_nodes=%s faulty_nodes=%s", + request.run_id, + request.cycle, + completed_count, + len(expected_nodes), + result.faulty_nodes, + ) + + def _compute_avoid_decision( + self, + request: FactAgentRequest, + result: FactAttributionResult, + ) -> AvoidDecision: + cluster = self.cluster or "unknown" + job_id = request.job_id or request.run_id + current_suspects = sorted({str(node) for node in result.faulty_nodes if str(node)}) + cycle_end_time = request.cycle_end_time or datetime.now(timezone.utc) + + if not current_suspects: + decision = AvoidDecision(cycle_id=request.cycle, status="skipped") + self._store_avoid_decision(decision) + return decision + + if len(current_suspects) > self.fact_history_max_candidate_nodes: + decision = AvoidDecision(cycle_id=request.cycle, status="skipped") + self._store_avoid_decision(decision) + return decision + + history_records = [] + history_end_time = request.cycle_start_time or cycle_end_time + if self.fact_history_es_url and self.fact_history_es_auth_file: + try: + lookback = parse_duration( + self.fact_history_lookback, + default=timedelta(days=14), + ) + history_records = self._fact_history_client_factory().query_node_history( + cluster=cluster, + nodes=current_suspects, + start_time=history_end_time - lookback, + end_time=history_end_time, + ) + except Exception as exc: + logger.warning( + "FACT history query failed; no avoid_nodes for cycle %s: %s", + request.cycle, + exc, + ) + decision = AvoidDecision(cycle_id=request.cycle, status="skipped") + self._store_avoid_decision(decision) + self._hot_cache.add_current_cycle( + cluster=cluster, + nodes=current_suspects, + job_id=job_id, + cycle_id=request.cycle, + event_time=cycle_end_time, + ) + return decision + + hot_records = self._hot_cache.records_for( + cluster=cluster, + nodes=current_suspects, + before=cycle_end_time, + ) + decision = compute_repeat_offender_decision( + cycle_id=request.cycle, + current_suspect_nodes=current_suspects, + history_records=history_records, + hot_cache_records=hot_records, + max_candidate_nodes=self.fact_history_max_candidate_nodes, + min_repeat_count_for_avoid=self.fact_min_repeat_count_for_avoid, + max_avoids_per_cycle=self.fact_max_attribution_avoids_per_cycle, + ) + self._store_avoid_decision(decision) + self._hot_cache.add_current_cycle( + cluster=cluster, + nodes=current_suspects, + job_id=job_id, + cycle_id=request.cycle, + event_time=cycle_end_time, + ) + logger.info( + "FACT avoid decision for run_id=%s cycle=%s status=%s avoid_nodes=%s", + request.run_id, + request.cycle, + decision.status, + decision.avoid_nodes, + ) + return decision + + def _store_avoid_decision(self, decision: AvoidDecision) -> None: + with self._avoid_decisions_lock: + self._avoid_decisions[decision.cycle_id] = decision + + def _submit_local_evidence( + self, request: FactAgentRequest, store: Any, keys: FactAgentKeys + ) -> None: + operation_deadline = time.monotonic() + max(0.0, self.observation_deadline_s) + node = request.local_node + status: dict[str, Any] = { + "record_type": "fact_observation", + "run_id": request.run_id, + "cycle": request.cycle, + "job_id": request.job_id or request.run_id, + "node": node, + "source": "dmesg", + "status": "skipped", + "attributor_id": None, + "observation_id": None, + "lines_collected": 0, + "bytes_collected": 0, + "dmesg_path": "", + "dmesg_write_error": "", + "error": "", + } + + try: + dmesg_text = self._dmesg_collector(self.dmesg_window_s, node) + collection_end_time = datetime.now(timezone.utc) + status["lines_collected"] = len(dmesg_text.splitlines()) + status["bytes_collected"] = len(dmesg_text.encode("utf-8", errors="replace")) + except Exception as exc: + logger.warning("FACT agent failed to collect dmesg on %s: %s", node, exc) + status.update(status="collect_failed", error=str(exc)) + self._write_result_record(request, status) + self._write_terminal_completion(store, keys, node) + return + + if request.dmesg_path and dmesg_text: + try: + self._write_dmesg_artifact(request, request.dmesg_path, dmesg_text) + status["dmesg_path"] = request.dmesg_path + except Exception as exc: + logger.warning( + "FACT agent failed to write dmesg evidence %s: %s", + request.dmesg_path, + exc, + ) + status["dmesg_write_error"] = str(exc) + + try: + attributor_wait_s = request.store_timeout_s + if self.observation_deadline_s > 0: + remaining_s = max(0.001, operation_deadline - time.monotonic()) + attributor_wait_s = min(attributor_wait_s, remaining_s) + raw_attributor_id = self._store_get_bytes_with_deadline( + store, keys.attributor_id, attributor_wait_s + ) + if not raw_attributor_id: + raise RuntimeError("timed out waiting for attributor_id") + attributor_id = raw_attributor_id.decode("utf-8") + if attributor_id.startswith(_ATTRIBUTOR_FAILURE_PREFIX): + error = attributor_id[len(_ATTRIBUTOR_FAILURE_PREFIX) :] + raise RuntimeError(f"FACT attributor creation failed on store host: {error}") + status["attributor_id"] = attributor_id + except Exception as exc: + logger.warning("FACT agent could not read attributor_id on %s: %s", node, exc) + status.update(status="attributor_failed", error=str(exc)) + self._write_result_record(request, status) + self._write_terminal_completion(store, keys, node) + return + + end_time = collection_end_time + start_time = end_time - timedelta(seconds=max(0.0, self.dmesg_window_s)) + try: + observation_id = self._submit_dmesg_observation_with_retries( + attributor_id=attributor_id, + dmesg_text=dmesg_text, + start_time=start_time, + end_time=end_time, + default_hostname=node, + deadline_s=operation_deadline, + ) + if observation_id is None: + status["status"] = "empty" + else: + status.update(status="submitted", observation_id=observation_id) + except Exception as exc: + logger.warning("FACT agent failed to submit dmesg observation for %s: %s", node, exc) + status.update(status="post_failed", error=str(exc)) + self._write_result_record(request, status) + self._write_terminal_completion(store, keys, node) + + def _submit_dmesg_observation_with_retries( + self, + *, + attributor_id: str, + dmesg_text: str, + start_time: datetime, + end_time: datetime, + default_hostname: str, + deadline_s: float, + ) -> Any: + delay_s = _POST_RETRY_INITIAL_DELAY_S + attempt = 0 + while True: + attempt += 1 + try: + return self._fact_client_factory().submit_dmesg_text_observation( + attributor_id=attributor_id, + dmesg_text=dmesg_text, + start_time=start_time, + end_time=end_time, + default_hostname=default_hostname, + ) + except Exception: + remaining_s = deadline_s - time.monotonic() + if remaining_s <= _POST_RETRY_MIN_REMAINING_S: + raise + sleep_s = min( + delay_s * random.uniform(0.5, 1.5), + max(0.0, remaining_s - _POST_RETRY_MIN_REMAINING_S), + ) + if sleep_s <= 0: + raise + logger.info( + "FACT agent observation POST attempt %s failed for %s; " "retrying in %.2fs", + attempt, + default_hostname, + sleep_s, + exc_info=True, + ) + time.sleep(sleep_s) + delay_s = min(_POST_RETRY_MAX_DELAY_S, delay_s * 2.0) + + def _write_terminal_completion( + self, + store: Any, + keys: FactAgentKeys, + node: str, + ) -> None: + try: + store.add(keys.done_count, 1) + except Exception as exc: + logger.warning("FACT agent failed to publish completion count for %s: %s", node, exc) + + def _publish_attributor_failure( + self, + store: Any, + keys: FactAgentKeys, + exc: Exception, + ) -> None: + try: + store.set(keys.attributor_id, f"{_ATTRIBUTOR_FAILURE_PREFIX}{exc}".encode("utf-8")) + except Exception as store_exc: + logger.warning( + "FACT agent failed to publish attributor failure sentinel: %s", store_exc + ) + + def _wait_for_completion_count( + self, + store: Any, + keys: FactAgentKeys, + expected_node_count: int, + ) -> int: + deadline = time.monotonic() + max(0.0, self.observation_deadline_s) + completed_count = 0 + while time.monotonic() < deadline and completed_count < expected_node_count: + try: + completed_count = int(store.add(keys.done_count, 0)) + except Exception: + pass + if completed_count >= expected_node_count: + break + time.sleep(0.1) + return min(completed_count, expected_node_count) + + @staticmethod + def _store_get_bytes_with_deadline(store: Any, key: str, timeout_s: float) -> Optional[bytes]: + timeout_s = max(0.0, timeout_s) + wait_fn = getattr(store, "wait", None) + if callable(wait_fn): + try: + wait_fn([key], timedelta(seconds=max(0.001, timeout_s))) + return store.get(key) + except TypeError: + pass + except Exception: + return None + + deadline = time.monotonic() + max(0.0, timeout_s) + sleep_s = random.uniform(0.05, 0.15) + while True: + try: + if store.check([key]): + return store.get(key) + except Exception: + return None + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + time.sleep(min(sleep_s, remaining)) + sleep_s = min(1.0, sleep_s * random.uniform(1.25, 1.75)) + + def _write_dmesg_artifact(self, request: FactAgentRequest, path: str, payload: str) -> None: + self._append_text_artifact(request, path, payload) + + def _append_text_artifact(self, request: FactAgentRequest, path: str, payload: str) -> None: + if not request.grpc_server_address: + raise RuntimeError("FACT artifact requires gRPC log aggregation") + self._enqueue_grpc_artifact( + request.grpc_server_address, + request.grpc_node_id or request.local_node, + path, + self._ensure_trailing_newline(payload), + ) + + def _write_result_record(self, request: FactAgentRequest, payload: dict[str, Any]) -> None: + if not request.result_path: + return + if not request.grpc_server_address: + logger.warning( + "FACT artifact requires gRPC log aggregation; skipping %s", + request.result_path, + ) + return + try: + text = json.dumps(payload, separators=(",", ":"), sort_keys=True) + "\n" + self._append_text_artifact(request, request.result_path, text) + except Exception as exc: + logger.warning( + "FACT agent failed to write result artifact %s: %s", request.result_path, exc + ) + + def _enqueue_grpc_artifact( + self, + grpc_server_address: str, + node_id: str, + path: str, + payload: str, + ) -> None: + writer_queue, writer = self._get_grpc_writer(grpc_server_address, node_id) + writer_queue.put((path, payload)) + self._wait_for_grpc_writer_queue(writer_queue, writer) + + def _get_grpc_writer( + self, + grpc_server_address: str, + node_id: str, + ) -> tuple[queue.Queue, threading.Thread]: + key = (grpc_server_address, node_id) + with self._grpc_writers_lock: + existing = self._grpc_writers.get(key) + if existing is not None: + return existing + + write_queue: queue.Queue = queue.Queue() + writer = self._grpc_writer_factory(write_queue, grpc_server_address, node_id, logger) + writer.start() + self._grpc_writers[key] = (write_queue, writer) + return write_queue, writer + + @staticmethod + def _ensure_trailing_newline(payload: str) -> str: + if payload and not payload.endswith("\n"): + return payload + "\n" + return payload + + @staticmethod + def _result_payload(result: FactAttributionResult) -> dict[str, Any]: + return {"fact_attribution_result": asdict(result)} + + def request_stop(self) -> None: + self._stop_event.set() + + def stop(self) -> None: + self._stop_event.set() + self._executor.shutdown(wait=True, cancel_futures=False) + with self._grpc_writers_lock: + writers = list(self._grpc_writers.values()) + self._grpc_writers.clear() + self._wait_for_grpc_writer_queues(writers) + for _, writer in writers: + shutdown = getattr(writer, "shutdown", None) + if callable(shutdown): + with contextlib.suppress(Exception): + shutdown() + for _, writer in writers: + with contextlib.suppress(Exception): + writer.join(timeout=5.0) + + @staticmethod + def _wait_for_grpc_writer_queues( + writers: list[tuple[queue.Queue, threading.Thread]], + ) -> None: + deadline = time.monotonic() + _GRPC_RESULT_DRAIN_TIMEOUT_S + for write_queue, writer in writers: + FactAgent._wait_for_grpc_writer_queue(write_queue, writer, deadline=deadline) + + @staticmethod + def _wait_for_grpc_writer_queue( + write_queue: queue.Queue, + writer: threading.Thread, + *, + deadline: Optional[float] = None, + ) -> None: + is_alive = getattr(writer, "is_alive", None) + if not callable(is_alive): + return + deadline = deadline or (time.monotonic() + _GRPC_RESULT_DRAIN_TIMEOUT_S) + while is_alive() and not write_queue.empty() and time.monotonic() < deadline: + time.sleep(0.05) + if is_alive() and not write_queue.empty(): + logger.warning( + "FACT result artifact gRPC queue still has %s records pending", + write_queue.qsize(), + ) + + def serve_forever(self, *, max_rpc_bytes: int = DEFAULT_MAX_RPC_BYTES) -> None: + socket_path = Path(self.socket_path) + if socket_path.exists(): + socket_path.unlink() + socket_path.parent.mkdir(parents=True, exist_ok=True) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server: + try: + server.bind(str(socket_path)) + os.chmod(socket_path, 0o600) + server.listen(128) + server.settimeout(1.0) + logger.info("nvrx-fact-agent listening on %s", socket_path) + while not self._stop_event.is_set(): + try: + conn, _ = server.accept() + except socket.timeout: + continue + threading.Thread( + target=self._handle_connection, + args=(conn, max_rpc_bytes), + daemon=True, + ).start() + finally: + self.stop() + with contextlib.suppress(FileNotFoundError): + socket_path.unlink() + + def _handle_connection(self, conn: socket.socket, max_rpc_bytes: int) -> None: + with conn: + try: + payload = recv_frame(conn, max_bytes=max_rpc_bytes) + ack = self.handle_payload(payload) + send_frame(conn, ack) + except Exception as exc: + logger.warning("nvrx-fact-agent rejected RPC: %s", exc) + with contextlib.suppress(Exception): + send_frame(conn, {"accepted": False, "error": str(exc)}) + + +def get_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog="nvrx-fact-agent") + parser.add_argument( + "--fact-url", required=True, help="FACT API URL, e.g. http://host:8001/latest" + ) + parser.add_argument("--socket-path", default=default_socket_path()) + parser.add_argument("--dmesg-window", type=float, default=DEFAULT_DMESG_WINDOW_S) + parser.add_argument( + "--observation-deadline", type=float, default=DEFAULT_OBSERVATION_DEADLINE_S + ) + parser.add_argument("--fact-timeout", type=float, default=60.0) + parser.add_argument("--run-id", default=None) + parser.add_argument("--rdzv-endpoint", default=None) + parser.add_argument("--store-timeout", type=float, default=DEFAULT_STORE_TIMEOUT_S) + parser.add_argument("--local-node", default=None) + parser.add_argument("--is-store-host", action="store_true") + parser.add_argument("--job-id", default=None) + parser.add_argument("--ranks-per-node", type=int, default=1) + parser.add_argument("--username", default=None) + parser.add_argument("--cluster", default=None) + parser.add_argument("--health-log-prefix", default=None) + parser.add_argument("--dmesg-artifact-enabled", action="store_true") + parser.add_argument("--result-artifact-enabled", action="store_true") + parser.add_argument("--grpc-server-address", default=None) + parser.add_argument("--grpc-node-id", default=None) + parser.add_argument("--fact-history-es-url", default=None) + parser.add_argument("--fact-history-es-auth-file", default=None) + parser.add_argument("--fact-history-lookback", default=DEFAULT_FACT_HISTORY_LOOKBACK) + parser.add_argument("--fact-history-index", default=None) + parser.add_argument( + "--fact-history-max-candidate-nodes", + type=int, + default=DEFAULT_FACT_HISTORY_MAX_CANDIDATE_NODES, + ) + parser.add_argument( + "--fact-history-query-timeout", + type=float, + default=DEFAULT_FACT_HISTORY_QUERY_TIMEOUT_S, + ) + parser.add_argument( + "--fact-min-repeat-count-for-avoid", + type=int, + default=DEFAULT_FACT_MIN_REPEAT_COUNT_FOR_AVOID, + ) + parser.add_argument( + "--fact-max-attribution-avoids-per-cycle", + type=int, + default=DEFAULT_FACT_MAX_ATTRIBUTION_AVOIDS_PER_CYCLE, + ) + parser.add_argument("--max-rpc-bytes", type=int, default=DEFAULT_MAX_RPC_BYTES) + return parser + + +def main(argv: Optional[list[str]] = None) -> None: + args = get_arg_parser().parse_args(argv) + setup_logger(node_local_tmp_prefix="nvrxfactagent") + service = FactAgent( + fact_url=args.fact_url, + socket_path=args.socket_path, + dmesg_window_s=args.dmesg_window, + observation_deadline_s=args.observation_deadline, + fact_timeout_s=args.fact_timeout, + run_id=args.run_id, + rdzv_endpoint=args.rdzv_endpoint, + store_timeout_s=args.store_timeout, + local_node=args.local_node, + is_store_host=args.is_store_host, + job_id=args.job_id, + ranks_per_node=args.ranks_per_node, + username=args.username, + cluster=args.cluster, + health_log_prefix=args.health_log_prefix, + dmesg_artifact_enabled=args.dmesg_artifact_enabled, + result_artifact_enabled=args.result_artifact_enabled, + grpc_server_address=args.grpc_server_address, + grpc_node_id=args.grpc_node_id, + fact_history_es_url=args.fact_history_es_url, + fact_history_es_auth_file=args.fact_history_es_auth_file, + fact_history_lookback=args.fact_history_lookback, + fact_history_index=args.fact_history_index, + fact_history_max_candidate_nodes=args.fact_history_max_candidate_nodes, + fact_history_query_timeout_s=args.fact_history_query_timeout, + fact_min_repeat_count_for_avoid=args.fact_min_repeat_count_for_avoid, + fact_max_attribution_avoids_per_cycle=args.fact_max_attribution_avoids_per_cycle, + ) + + def _handle_stop_signal(signum: int, _frame: Any) -> None: + logger.info("nvrx-fact-agent received signal %s; requesting shutdown", signum) + service.request_stop() + + signal.signal(signal.SIGTERM, _handle_stop_signal) + signal.signal(signal.SIGINT, _handle_stop_signal) + service.serve_forever(max_rpc_bytes=args.max_rpc_bytes) + + +if __name__ == "__main__": + main() diff --git a/src/nvidia_resiliency_ext/attribution/fact/client.py b/src/nvidia_resiliency_ext/attribution/fact/client.py new file mode 100644 index 00000000..749be211 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/client.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import os +import re +import socket +import subprocess # nosec B404 +import time +from collections import OrderedDict +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Iterable, Optional +from urllib.parse import urlparse + +if TYPE_CHECKING: + import httpx + + +_FACT_API_PREFIX = "/latest" +_DMESG_COMMAND_TIMEOUT_S = 10.0 +_NVRX_GRPC_NODE_ID_RE = re.compile(r"^(.+)_\d+$") +_FACT_DMESG_REGEXES = ( + re.compile(r"\bNVRM:\s+Xid\b"), + re.compile(r"\bXid\s+\(PCI:"), + re.compile(r"\bSXid\b"), + re.compile(r"\bSXid\s+\(PCI:"), + re.compile(r"\bNV_ERR_[A-Z0-9_]*\b"), + re.compile(r"\bNV_WARN_[A-Z0-9_]*\b"), +) +_FACT_DMESG_SUBSTRINGS = ( + "NVRM: rpcRmApiAlloc_GSP: GspRmAlloc failed", + "NVRM: nvAssertFailedNoLog: Assertion failed:", + "CTRL-EVENT-EAP-FAILURE EAP authentication failed", + "CTRL-EVENT-EAP-SUCCESS EAP authentication completed successfully", + "System is powering down", + "Out of memory: Killed process", + "NMI watchdog: Watchdog detected hard LOCKUP", + "general protection fault", + "kernel stack frame pointer at", + "LustreError", + "connection2:0: ping timeout", + "detected conn error", + "Abrupt nvidia-imex daemon shutdown detected, robust channel recovery invoked!", + "Failed to collect nvlink status info!", + "not responding, still trying", + "Failed to update Rx Detect Link mask!", + "warthog-fake: INFO APS/WARTHOG Induced fatal error", + "Stopping nvidia-imex.service", + "Connection lost to node", + "Lost connection to GPU", + "Unable to handle kernel", +) +_LOKI_TIMESTAMP_END_OFFSET_S = 1.0 + + +@dataclass +class FactAttributionResult: + attributor_id: str + observation_ids: list[Any] + faulty_nodes: list[str] + attribution: dict[str, Any] + + +def _severity_from_dmesg(message: str) -> str: + lowered = message.lower() + if "xid" in lowered or "error" in lowered or "failed" in lowered: + return "err" + if "warn" in lowered: + return "warning" + return "info" + + +def _split_prefixed_dmesg_line(line: str, default_hostname: str) -> tuple[str, str]: + stripped = line.rstrip("\n") + if ": " not in stripped: + return default_hostname, stripped + + maybe_host, message = stripped.split(": ", 1) + if maybe_host and " " not in maybe_host and "/" not in maybe_host: + return _normalize_node_id(maybe_host), message + return default_hostname, stripped + + +def _normalize_node_id(node: str) -> str: + """Map NVRx gRPC log node IDs like ``host_pid`` back to the host name.""" + match = _NVRX_GRPC_NODE_ID_RE.match(node) + if match: + return match.group(1) + return node + + +def _fact_attributor_node_list(nodes: Iterable[str]) -> list[str]: + """Return explicit FACT attributor node names without gRPC log-id normalization.""" + return sorted({node for raw_node in nodes if (node := str(raw_node))}) + + +def normalize_fact_attribution_url(url: str) -> str: + """Return the FACT API base URL, accepting either service root or API root.""" + normalized = url.strip().rstrip("/") + parsed = urlparse(normalized) + if parsed.scheme not in ("http", "https") or not parsed.netloc: + raise ValueError("FACT attribution URL must include http(s) scheme and host") + if parsed.query or parsed.fragment: + raise ValueError("FACT attribution URL must not include query parameters or fragments") + if parsed.path in ("", "/"): + return f"{normalized}{_FACT_API_PREFIX}" + return normalized + + +def is_fact_relevant_dmesg_message(message: str) -> bool: + """Return whether a dmesg message matches FACT's current syslog extractors.""" + if any(pattern.search(message) for pattern in _FACT_DMESG_REGEXES): + return True + return any(term in message for term in _FACT_DMESG_SUBSTRINGS) + + +def _prefix_dmesg_text(text: str, hostname: Optional[str]) -> str: + if not hostname: + return text + prefixed = [] + for line in text.splitlines(keepends=True): + if line.endswith("\n"): + prefixed.append(f"{hostname}: {line}") + else: + prefixed.append(f"{hostname}: {line}\n") + return "".join(prefixed) + + +def collect_recent_dmesg_text( + *, + window_s: float, + hostname: Optional[str] = None, +) -> str: + """Collect a bounded recent dmesg window and optionally prefix every line.""" + since = datetime.now() - timedelta(seconds=max(0.0, float(window_s))) + proc = subprocess.run( # nosec B603 + ["dmesg", "--since", since.strftime("%Y-%m-%d %H:%M:%S.%f")], + capture_output=True, + text=True, + check=False, + timeout=_DMESG_COMMAND_TIMEOUT_S, + ) + if proc.returncode != 0: + err = (proc.stderr or proc.stdout or "").strip() + raise RuntimeError(f"dmesg --since failed with rc={proc.returncode}: {err}") + return _prefix_dmesg_text(proc.stdout or "", hostname) + + +def dmesg_lines_to_raw_loki_streams( + lines: Iterable[str], + *, + default_hostname: Optional[str] = None, + timestamp_start_ns: Optional[int] = None, + prefilter: bool = False, +) -> tuple[list[dict[str, Any]], list[str]]: + """Convert dmesg lines into FACT raw_loki_streams JSON.""" + hostname = default_hostname or socket.gethostname() + base_ns = timestamp_start_ns if timestamp_start_ns is not None else time.time_ns() + streams: "OrderedDict[str, list[list[str]]]" = OrderedDict() + accepted_offset = 0 + + for line in lines: + if not line.strip(): + continue + node, message = _split_prefixed_dmesg_line(line, hostname) + if prefilter and not is_fact_relevant_dmesg_message(message): + continue + payload = { + "body": message, + "severity": _severity_from_dmesg(message), + "attributes": { + "hostname": node, + "appname": "kernel", + "facility": 0, + }, + "resources": {}, + } + streams.setdefault(node, []).append( + [ + str(base_ns + accepted_offset), + json.dumps(payload, separators=(",", ":")), + ] + ) + accepted_offset += 1 + + raw_streams = [ + { + "stream": { + "job": "nvrx", + "app": "dmesg", + "hostname": node, + }, + "values": values, + } + for node, values in streams.items() + ] + return raw_streams, list(streams.keys()) + + +def dmesg_text_to_raw_loki_streams( + text: str, + *, + default_hostname: Optional[str] = None, + timestamp_start_ns: Optional[int] = None, + prefilter: bool = False, +) -> tuple[list[dict[str, Any]], list[str]]: + """Convert dmesg text into FACT raw_loki_streams JSON.""" + return dmesg_lines_to_raw_loki_streams( + text.splitlines(), + default_hostname=default_hostname, + timestamp_start_ns=timestamp_start_ns, + prefilter=prefilter, + ) + + +def _raw_loki_timestamp_anchor_ns(start_time: datetime, end_time: datetime) -> int: + """Anchor synthetic dmesg Loki timestamps inside the observation interval. + + Plain ``dmesg`` output carries monotonic kernel timestamps, not wall-clock + timestamps. Put the synthetic Loki times near collection end so short NVRx + cycles do not place evidence before the FACT workload start time. + """ + span_s = max(0.0, (end_time - start_time).total_seconds()) + if span_s <= 0.0: + anchor = end_time + else: + anchor = end_time - timedelta(seconds=min(_LOKI_TIMESTAMP_END_OFFSET_S, span_s / 2.0)) + return int(anchor.timestamp() * 1_000_000_000) + + +class FactAttributionService: + """Small client for FACT's attributor/observation/attribution API.""" + + def __init__( + self, + *, + url: str, + timeout_s: float = 60.0, + ) -> None: + self.timeout_s = timeout_s + self.base_url = normalize_fact_attribution_url(url) + + def create_failure_attributor( + self, + *, + job_id: str, + cycle_index: int, + nodes: Iterable[str], + ranks_per_node: int, + nranks: int, + start_time: datetime, + end_time: datetime, + username: Optional[str] = None, + cluster: Optional[str] = None, + ) -> str: + import httpx + + observation_nodes = _fact_attributor_node_list(nodes) + if not observation_nodes: + observation_nodes = [socket.gethostname()] + nranks = max(nranks, ranks_per_node * len(observation_nodes)) + attributor_info = self._build_attributor_info( + job_id=job_id, + cycle_index=cycle_index, + nodes=observation_nodes, + ranks_per_node=ranks_per_node, + nranks=nranks, + start_time=start_time, + end_time=end_time, + username=username, + cluster=cluster, + ) + with httpx.Client(timeout=self.timeout_s) as client: + return self._create_attributor(client, attributor_info) + + def submit_dmesg_text_observation( + self, + *, + attributor_id: str, + dmesg_text: str, + start_time: datetime, + end_time: datetime, + default_hostname: Optional[str] = None, + ) -> Optional[Any]: + import httpx + + raw_streams, _ = dmesg_text_to_raw_loki_streams( + dmesg_text, + default_hostname=default_hostname, + timestamp_start_ns=_raw_loki_timestamp_anchor_ns(start_time, end_time), + prefilter=True, + ) + if not raw_streams: + return None + + with httpx.Client(timeout=self.timeout_s) as client: + return self._post_observation( + client, + attributor_id=attributor_id, + source="syslog", + format_="raw_loki_streams", + body=json.dumps(raw_streams, separators=(",", ":")), + start_time=start_time, + end_time=end_time, + ) + + def get_attribution_result( + self, + *, + attributor_id: str, + observation_ids: Optional[list[Any]] = None, + ) -> FactAttributionResult: + import httpx + + with httpx.Client(timeout=self.timeout_s) as client: + attribution = self._get_attribution(client, attributor_id) + return FactAttributionResult( + attributor_id=attributor_id, + observation_ids=observation_ids or [], + faulty_nodes=self._extract_faulty_nodes(attribution), + attribution=attribution, + ) + + def _build_attributor_info( + self, + *, + job_id: str, + cycle_index: int, + nodes: list[str], + ranks_per_node: int, + nranks: int, + start_time: datetime, + end_time: datetime, + username: Optional[str] = None, + cluster: Optional[str] = None, + ) -> dict[str, Any]: + username = ( + username + or os.environ.get("SLURM_JOB_USER") + or os.environ.get("USER") + or os.environ.get("LOGNAME") + or "unknown" + ) + cluster = ( + cluster + or os.environ.get("SLURM_CLUSTER_NAME") + or os.environ.get("NVRX_CLUSTER_NAME") + or "unknown" + ) + tenant = os.environ.get("SLURM_JOB_ACCOUNT") or os.environ.get("NVRX_TENANT") or "unknown" + job_name = os.environ.get("SLURM_JOB_NAME") or "unknown" + return { + "workload": { + "id": f"{job_id}:cycle{cycle_index}:{time.time_ns()}", + "type": "slurm" if os.environ.get("SLURM_JOB_ID") else "", + "srun_cmd": "unknown", + "job_start_time": start_time.isoformat(), + "job_end_time": end_time.isoformat(), + "status": "FAILED", + "nranks": nranks, + "ranks_per_node": ranks_per_node, + "nodes": nodes, + "name": job_name, + "username": username, + "framework": os.environ.get("NVRX_FRAMEWORK", "unknown"), + "exit_code_signal": 1, + }, + "metadata": { + "cluster": cluster, + "agent": "nvrx-ft-launcher", + "tenant": tenant, + "ruleset": "default", + }, + } + + def _create_attributor(self, client: httpx.Client, info: dict[str, Any]) -> str: + response = client.post( + f"{self.base_url}/attributor", + json=info, + headers={"accept": "application/json"}, + ) + response.raise_for_status() + return str(response.json()["attributor_id"]) + + def _post_observation( + self, + client: httpx.Client, + *, + attributor_id: str, + source: str, + format_: str, + body: str, + start_time: datetime, + end_time: datetime, + ) -> Any: + payload = { + "context": { + "time_interval": { + "start": start_time.isoformat(), + "end": end_time.isoformat(), + }, + "resources": "AllJobResources", + "source": source, + "format": format_, + }, + "body": body, + } + response = client.post( + f"{self.base_url}/attributor/{attributor_id}/observation", + json=payload, + headers={"accept": "application/json"}, + ) + response.raise_for_status() + return response.json().get("observation_id") + + def _get_attribution(self, client: httpx.Client, attributor_id: str) -> dict[str, Any]: + response = client.get( + f"{self.base_url}/attributor/{attributor_id}/attribution", + headers={"accept": "application/json"}, + ) + response.raise_for_status() + return response.json() + + @staticmethod + def _extract_faulty_nodes(attribution: dict[str, Any]) -> list[str]: + nodes = [] + for item in attribution.get("attributions", []): + if item.get("type") != "NodeAttribution": + continue + if item.get("attributions"): + node = item.get("node") + if node: + nodes.append(str(node)) + return sorted(set(nodes)) diff --git a/src/nvidia_resiliency_ext/attribution/fact/fact_integration_design.md b/src/nvidia_resiliency_ext/attribution/fact/fact_integration_design.md new file mode 100644 index 00000000..3a54e588 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/fact_integration_design.md @@ -0,0 +1,641 @@ +# NVRx FACT Integration Design + +This is the internal design for the NVRx FACT integration. The user-facing +operator doc is `docs/source/fault_tolerance/fact_node_attribution.rst`. + +FACT, Failure Attribution and Characterization, provides node-level attribution +from host evidence. It does not produce the NVRx application-log +`STOP`/`RESTART` decision; that remains the existing application-log +attribution path. + +## Scope + +The integration has two legs: + +1. **Current-cycle attribution:** after a failed FT cycle, live agents collect + local dmesg, POST observations to FACT Attribution Service, and the + store-host agent GETs one attribution result. +2. **Repeat-offender avoid policy:** the store-host agent combines current FACT + suspect nodes with recent FACT node-history records plus an in-memory + hot cache, then returns `avoid_nodes` for the next rendezvous placement. + +Hard health/liveness failures are outside this policy. If a node fails a Node +Health Check, disappears, kernel-panics, loses power, or cannot rejoin, it is a +hard exclusion. FACT attribution is weaker suspect evidence and must fail open. + +## Ownership + +```text +FT launcher + parses CLI/YAML + starts one local nvrx-fact-agent per launcher process + sends local UDS notifications + applies avoid_nodes during rendezvous placement + +nvrx-fact-agent on every node + collects local dmesg evidence + POSTs local observation to FACT Attribution Service + queues optional dmesg/result artifacts through gRPC + +nvrx-fact-agent on the store host + creates the FACT attributor + performs the FACT attribution GET + queries FACT node-history + maintains the in-memory hot cache + computes and serves avoid_nodes over local UDS + +FACT services + own current-cycle attribution and durable historical node records +``` + +NVRx owns placement because it knows quorum, joined participants, and standby +capacity. FACT should eventually own the durable history query interface +because it owns FACT DB schema and ingestion paths. + +There is no separate policy daemon. Extend the launcher-managed +`nvrx-fact-agent`. + +## Configuration + +Current-cycle FACT evidence path: + +```text +--ft-fact-url +--ft-fact-agent-socket-path +--ft-fact-agent-rpc-timeout +--ft-fact-agent-store-timeout +--ft-health-log-prefix +--ft-enable-health-log-dmesg +--ft-enable-fact-result-artifact +``` + +History-based avoid policy: + +```text +--ft-fact-history-es-url +--ft-fact-history-es-auth-file +--ft-fact-history-lookback +--ft-fact-history-index +--ft-fact-history-max-candidate-nodes +--ft-fact-history-query-timeout +--ft-fact-min-repeat-count-for-avoid +--ft-fact-max-attribution-avoids-per-cycle +``` + +Defaults: + +```text +dmesg_window = 12m +observation_deadline = 30s + +fact_policy_ready_timeout = 60s + +fact_history_lookback = 14d +fact_history_index = +fact_history_max_candidate_nodes = 16 +fact_history_query_timeout = 30s +min_repeat_count_for_avoid = 2 +max_attribution_avoids_per_cycle = 1 +``` + +Repeat-offender policy is available when current-cycle FACT attribution is +enabled: + +```text +--ft-fact-url +``` + +With only `--ft-fact-url`, the policy uses the in-memory hot cache from earlier +cycles in the same NVRx run. If `--ft-fact-history-es-url` and +`--ft-fact-history-es-auth-file` are also set, the store-host agent adds durable +FACT history to the same decision. The auth-file contents must not be logged or +written to artifacts. Most knobs should stay defaulted; if the agent command +line keeps growing, prefer a generated config file over environment +variables. + +## Agent Lifecycle + +```text +FT parses config +FT computes session context +FT starts nvrx-fact-agent after session context and gRPC log funnel are known +FT sends UDS ping and waits for accepted=true +FT sends UDS cycle_failed after worker failure/stop +FT store-host queries UDS get_avoid_nodes before next placement +FT sends UDS shutdown on normal launcher shutdown +``` + +Agent startup/session args: + +```text +--fact-url +--socket-path +--run-id +--rdzv-endpoint +--store-timeout +--local-node +--is-store-host +--job-id +--ranks-per-node +--username +--cluster +--health-log-prefix +--dmesg-artifact-enabled +--result-artifact-enabled +--grpc-server-address +--grpc-node-id +--fact-history-es-url +--fact-history-es-auth-file +--fact-history-lookback +--fact-history-index +--fact-history-max-candidate-nodes +--fact-history-query-timeout +--fact-min-repeat-count-for-avoid +--fact-max-attribution-avoids-per-cycle +``` + +Only the store-host agent uses history args. Passing them to all agents is +acceptable because each launcher starts the same binary. + +## UDS RPCs + +### `ping` + +```json +{"event": "ping"} +``` + +```json +{"accepted": true} +``` + +### `cycle_failed` + +FT sends this after workers for the failed cycle are stopped. The ACK means the +agent accepted work into its local queue; it is not a FACT or policy result. + +Store-host payload: + +```json +{ + "event": "cycle_failed", + "cycle_id": "3", + "cycle_start_time": "2026-05-10T12:00:00+00:00", + "cycle_end_time": "2026-05-10T12:42:00+00:00", + "expected_nodes": ["node-a", "node-b"] +} +``` + +Leaf payload: + +```json +{ + "event": "cycle_failed", + "cycle_id": "3", + "cycle_start_time": "2026-05-10T12:00:00+00:00", + "cycle_end_time": "2026-05-10T12:42:00+00:00", + "expected_nodes": [] +} +``` + +Response: + +```json +{"accepted": true} +``` + +Field contract: + +| Field | Scope | Meaning | +| --- | --- | --- | +| `cycle_id` | Per failed cycle | NVRx cycle id. FACT node-history needs an equivalent cycle-distinct episode id. | +| `cycle_start_time` | Per failed cycle | FACT workload start time and upper bound for history lookup. | +| `cycle_end_time` | Per failed cycle | Hot-cache event time and recency tie-breaker. Defaults to agent receive time if omitted. | +| `expected_nodes` | Store-host only | Active nodes for the failed cycle, not every allocated/spare node. Used for FACT workload scope and completion target. | + +Rank shape is startup/session state: `ranks_per_node` is passed when the agent +starts, and the store host derives `nranks = ranks_per_node * len(expected_nodes)`. + +### `get_avoid_nodes` + +Before next placement, the store-host FT side queries its local agent: + +```json +{ + "event": "get_avoid_nodes", + "cycle_id": "3" +} +``` + +Possible responses: + +```text +{"cycle_id": "3", "status": "ready", "avoid_nodes": ["node-a"]} +{"cycle_id": "3", "status": "skipped", "avoid_nodes": []} +{"cycle_id": "3", "status": "pending", "avoid_nodes": []} +``` + +The current FT side makes a single local query at placement time. Missing, +malformed, skipped, pending, or failed responses are treated as an empty avoid +list. Candidate ranking, reason strings, timestamps, and history details are +internal agent state; FT consumes only `avoid_nodes`. + +## TCPStore Contract + +TCPStore carries only FACT evidence control-plane state: + +```text +fact_agent::cycle:attributor_id +fact_agent::cycle:done_count +``` + +| Key | Writer | Reader | Meaning | +| --- | --- | --- | --- | +| `attributor_id` | Store-host | All agents | FACT attributor id, or a failure sentinel. | +| `done_count` | All agents | Store-host | Atomic count of agents that reached terminal local outcome. | + +`done_count` increments for successful submission and terminal local failures: +empty evidence, collection failure, POST failure, or attributor-id failure. It +is not a count of successful FACT submissions. Nodes that never increment +before the deadline are missing/liveness evidence. + +No per-node status, observation ids, faulty nodes, result path, result status, +or policy output is written to TCPStore. These are JSONL artifact records when +enabled, or in-memory policy state on the store-host agent. + +## FACT Attribution Flow + +### 1. Create Attributor + +Store-host agent: + +```text +POST /attributor +``` + +Request body: + +```json +{ + "workload": { + "id": ":cycle:", + "type": "slurm", + "job_start_time": "", + "job_end_time": "", + "status": "FAILED", + "nodes": ["node-a", "node-b"], + "ranks_per_node": 4, + "nranks": 8, + "name": "", + "username": "" + }, + "metadata": { + "cluster": "", + "agent": "nvrx-ft-launcher", + "tenant": "", + "ruleset": "default" + } +} +``` + +The store-host writes the returned `attributor_id` to TCPStore, then runs the +same dmesg observation path locally. + +### 2. POST Dmesg Observation + +Every live agent: + +```text +collect dmesg --since now - 12m +queue raw dmesg artifact when enabled and non-empty +wait for TCPStore[attributor_id] +apply built-in FACT dmesg filter +convert matching lines to raw_loki_streams +POST /attributor/{attributor_id}/observation +queue fact_observation JSONL when enabled +increment done_count +``` + +The 12-minute window covers NCCL timeout cases where the kernel event may be +roughly 10 minutes old. The FACT workload `job_start_time` remains the actual +cycle start time; the dmesg collection window is independent. + +FACT observation body: + +```json +{ + "context": { + "time_interval": { + "start": "", + "end": "" + }, + "resources": "AllJobResources", + "source": "syslog", + "format": "raw_loki_streams" + }, + "body": "[{\"stream\":...,\"values\":...}]" +} +``` + +The agent emits Loki stream attributes with `hostname`, `appname=kernel`, and +the dmesg line body. Default `dmesg` output has monotonic kernel timestamps, so +the agent assigns synthetic Loki timestamps near collection end time to keep +short-cycle evidence inside the FACT workload window. + +An empty or fully prefiltered dmesg window produces no FACT POST, records +`status = "empty"` when result JSONL is enabled, and still increments +`done_count`. Observation POST failures are retried with jitter only while the +observation deadline still has useful time remaining. + +### 3. GET Attribution + +Store-host waits for `done_count >= len(expected_nodes)` or the observation +deadline, then performs one authoritative GET: + +```text +GET /attributor/{attributor_id}/attribution +``` + +The full FACT response is wrapped as `FactAttributionResult` for audit and for +repeat-offender policy input. + +## Artifacts and Durability + +FACT submission does not read artifacts; it receives POSTed contents. Artifacts +are optional postmortem evidence and require the launcher gRPC log funnel. There +is no direct shared-file fallback. + +For cycle `N`, paths are derived from `health_logging.prefix`: + +```text + -> _dmesg_cycle.log + -> _fact_cycle.log +``` + +### Dmesg Artifact + +Enabled by `--ft-enable-health-log-dmesg`. All live agents queue their raw +collected dmesg text before FACT filtering to one shared per-cycle file. +Production collection prefixes each line with the source node name. Empty raw +windows queue no dmesg chunk and do not create a 0-byte file. + +### Result JSONL Artifact + +Enabled by `--ft-enable-fact-result-artifact`. All agents queue compact JSONL +records to the root writer. + +Leaf record: + +```json +{ + "record_type": "fact_observation", + "run_id": "run-1", + "cycle": 3, + "job_id": "12345", + "node": "node-a", + "source": "dmesg", + "status": "submitted", + "attributor_id": "att-1", + "observation_id": "obs-1", + "bytes_collected": 2048, + "lines_collected": 12, + "dmesg_path": "/lustre/logs/job_health_dmesg_cycle3.log", + "error": "" +} +``` + +Leaf status values: + +| Status | Meaning | +| --- | --- | +| `submitted` | FACT returned an `observation_id`. | +| `empty` | Collection worked, but filtering left no lines to POST. | +| `collect_failed` | Local dmesg collection failed. | +| `attributor_failed` | No usable `attributor_id` was available. | +| `post_failed` | Dmesg was collected, but FACT observation POST failed. | + +Store-host record: + +```json +{ + "record_type": "fact_result", + "status": "complete", + "run_id": "run-1", + "cycle": 3, + "job_id": "12345", + "expected_node_count": 2, + "completed_node_count": 1, + "avoid_nodes": ["node-a"], + "fact_attribution_result": { + "attributor_id": "att-1", + "observation_ids": [], + "faulty_nodes": ["node-a"], + "attribution": {} + } +} +``` + +Records and dmesg payloads are queued as chunks, so bytes from different chunks +should not interleave. Ordering across nodes is not a correctness contract. + +Dmesg collection, FACT POST/GET, TCPStore completion, policy computation, and +gRPC artifact drain are best-effort. On normal launcher shutdown, +`FactAgentManager` asks the agent to exit over UDS so queued JSONL can drain, +but completeness remains best-effort. + +## Repeat-Offender Policy + +Current FACT attribution returns suspect nodes for this cycle: + +```text +current_suspect_nodes = fact_result.faulty_nodes +``` + +If history config is present, the store-host agent queries FACT node-history +only for those current suspects: + +```text +source = configured FACT node-history source +cluster == current cluster +node in current_suspect_nodes +lookback_start <= event_time < cycle_start_time +``` + +MVP logic uses: + +```text +cluster +node +episode_id +event_time +``` + +`episode_id` is used with `cluster` and `node` to deduplicate historical +episodes. The expected NVRx node-history contract is: + +```text +episode_id = _ +``` + +or an equivalent FACT-defined cycle-distinct id. The deployment adapter maps +these generic fields to the concrete FACT history backend. The FACT +node-history source does not currently expose `attributor_id` or +`observation_id`, so MVP dedupe cannot rely on those ids. + +The hot cache is in-memory only for MVP: + +```text +scope = current NVRx job/process +durability = none +hot_cache_episode_key = (job_id, cycle_id, node_id) +``` + +FACT history is the durable source once FACT PM catches up. Expected FACT PM +upload lag can be about 30 minutes, so back-to-back failures need the local hot +cache before history catches up. The hot cache is only a current-process +overlay for candidate nodes; MVP does not run a broad bidirectional +reconciliation loop with FACT history. + +MVP aggregation: + +```text +candidate_nodes = current_fact_result.faulty_nodes +prior_history = FACT node-history rows for candidate_nodes before cycle_start_time, + or [] when no history source is configured +hot_overlay = earlier NVRx events for candidate_nodes already processed by this + agent before the current policy computation +repeat_count(node) = 1 current event + prior_history_episodes + hot_overlay_episodes +``` + +Broad/systemic guard: + +```text +if len(current_suspect_nodes) > fact_history_max_candidate_nodes: + skip history lookup + do not add these suspects to future repeat counts + return avoid_nodes=[] +``` + +The default guard is 16 nodes. This is a broad-event boundary, not a precise +rack classifier. + +Decision table: + +| Evidence | Action | +| --- | --- | +| One current FACT attribution only | Suspect; no placement action. | +| Two or more total same-node events, including current | `avoid-for-retry` if feasible. | +| More than `fact_history_max_candidate_nodes` current suspects | Audit only; skip history query. | +| Health/liveness failure | Hard exclude outside this policy. | + +Ranking: + +```text +1. higher repeat_count +2. more recent prior same-node episode +3. lexical node name + +prior_last_seen(node) = + max(event_time from prior FACT history, + cycle_end_time from prior unreconciled hot-cache episodes) +``` + +Do not use the current cycle's `cycle_end_time` for this tie-breaker; it is the +same failure event currently being handled. If the top-ranked candidate is not +feasible to avoid, MVP skips all attribution-based avoids for that cycle. + +Family-aware escalation is out of MVP because `FactAttributionResult` and FACT +node-history do not currently expose a stable normalized failure family. Until +FACT provides one, the policy is same-node repeat only and stores raw symptoms +for audit. The desired future family taxonomy is subsystem-level: +`accelerator-memory`, `accelerator-gpu`, `accelerator-fabric`, +`network-fabric`, `storage`, `kernel-fatal`, `host-oom`, `UNKNOWN_FAMILY`. + +## Placement Handoff + +The store-host FT/rendezvous side queries local UDS: + +```text +get_avoid_nodes(cycle_id) -> avoid_nodes +``` + +Then it validates: + +```text +node is a joined participant +min_nodes / quorum still holds +rank and accelerator count still holds +standby capacity is available +max attribution avoids per cycle is respected +``` + +Avoided nodes are treated as standby/spare for the next retry. They are not +hard-excluded and may still join rendezvous. If the policy is absent, late, +malformed, skipped, or infeasible, rendezvous proceeds normally. + +## Module Placement + +Keep MVP FACT-specific code directly under `attribution/fact`: + +```text +nvidia_resiliency_ext/attribution/fact/ + models.py + history_client.py + hot_cache.py + repeat_offender_policy.py +``` + +| Module | Responsibility | +| --- | --- | +| `models.py` | Dataclasses for history records, hot-cache episodes, candidates, and decisions. | +| `history_client.py` | FACT history query and auth-file handling. | +| `hot_cache.py` | In-memory current-job episode store. | +| `repeat_offender_policy.py` | Count, rank, feasibility precheck, and decision construction. | + +Introduce a subpackage only if this grows beyond a few files or needs multiple +policy variants. + +## Tests + +Minimum coverage: + +```text +config parses CLI/YAML and passes settings into FactAgentManager +FactAgentManager passes session/history args to nvrx-fact-agent +cycle_failed ACK means queued, not complete +TCPStore only carries attributor_id and done_count +leaf agents never query FACT history +store-host queries history only for current FACT faulty_nodes +broad current suspect set is audit-only and does not update hot cache +history query failure/timeout makes get_avoid_nodes return empty avoid_nodes +episode_id is required for historical dedupe +repeat_count >= 2 produces one avoid candidate +top candidate infeasible means no avoids for MVP +get_avoid_nodes output is ignored when missing/malformed/late +result JSONL records include terminal local failures before done_count +dmesg artifact is one shared gRPC-written file per cycle +empty raw dmesg queues no dmesg artifact chunk +``` + +## Open Items + +1. Verify production-like FACT node-history rows expose `episode_id = + _` or an equivalent cycle-distinct id. +2. Define the auth-file format consumed by `history_client.py`. +3. Decide whether the long-term FACT-owned history API is a library or an + Attribution Service endpoint. +4. Have FACT review the built-in dmesg prefilter, especially missing + `mlx5*`/HCA port-flap and network/fabric cases. +5. Validate FACT Attribution surge capacity for 5k-20k near-simultaneous POSTs. +6. Decide whether to pre-create cycle-scoped attributors, after confirming FACT + tolerates idle or long-lived attributors. +7. Revisit `done_count` fan-in at 5k-20k nodes. Alternatives are fixed-deadline + GET, sharded counters, or hierarchical completion. +8. Confirm whether FACT GET is strongly consistent with completed observation + POSTs or needs bounded retry/ready signaling. +9. Confirm result JSONL size and gRPC/log-funnel limits for large FACT + responses. +10. Define production FACT credential delivery, refresh, and redaction. +11. Define a stable physical node id beyond scheduler hostname. +12. Confirm FACT timestamp contract for the node-history `event_time`; avoid + timestamp-window dedupe until then. diff --git a/src/nvidia_resiliency_ext/attribution/fact/history_client.py b/src/nvidia_resiliency_ext/attribution/fact/history_client.py new file mode 100644 index 00000000..474da4aa --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/history_client.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Iterable, Optional +from urllib.parse import quote + +from nvidia_resiliency_ext.attribution.fact.models import FactHistoryRecord + + +def parse_duration(value: str | float | int | None, *, default: timedelta) -> timedelta: + if value is None or value == "": + return default + if isinstance(value, (int, float)): + return timedelta(seconds=float(value)) + text = str(value).strip().lower() + if not text: + return default + unit = text[-1] + number_text = text[:-1] if unit.isalpha() else text + number = float(number_text) + if unit == "d": + return timedelta(days=number) + if unit == "h": + return timedelta(hours=number) + if unit == "m": + return timedelta(minutes=number) + if unit == "s" or not unit.isalpha(): + return timedelta(seconds=number) + raise ValueError(f"unsupported duration suffix in {value!r}") + + +class FactHistoryClient: + """Generic Elasticsearch-style FACT node-history client.""" + + def __init__( + self, + *, + es_url: str, + auth_file: str, + index: Optional[str] = None, + timeout_s: float = 30.0, + ) -> None: + self.es_url = es_url.rstrip("/") + self.auth_file = auth_file + self.index = str(index).strip() if index else None + self.timeout_s = float(timeout_s) + + def query_node_history( + self, + *, + cluster: str, + nodes: Iterable[str], + start_time: datetime, + end_time: datetime, + ) -> list[FactHistoryRecord]: + import httpx + + node_list = sorted({str(node) for node in nodes if str(node)}) + if not node_list: + return [] + payload = self._build_query( + cluster=cluster, + nodes=node_list, + start_time=start_time, + end_time=end_time, + ) + with httpx.Client(timeout=self.timeout_s, headers=self._auth_headers()) as client: + response = client.post(self._search_url(), json=payload) + response.raise_for_status() + return self._parse_response(response.json()) + + def _search_url(self) -> str: + if self.es_url.endswith("/_search"): + return self.es_url + if self.index: + return f"{self.es_url}/{quote(self.index, safe='*,')}/_search" + return f"{self.es_url}/_search" + + @staticmethod + def _build_query( + *, + cluster: str, + nodes: list[str], + start_time: datetime, + end_time: datetime, + ) -> dict[str, Any]: + return { + "size": 10_000, + "_source": ["cluster", "node", "episode_id", "event_time"], + "query": { + "bool": { + "filter": [ + {"term": {"cluster": cluster}}, + {"terms": {"node": nodes}}, + { + "range": { + "event_time": { + "gte": _isoformat(start_time), + "lt": _isoformat(end_time), + } + } + }, + ] + } + }, + } + + def _auth_headers(self) -> dict[str, str]: + text = Path(self.auth_file).read_text(encoding="utf-8").strip() + if not text: + return {} + try: + parsed = json.loads(text) + except json.JSONDecodeError: + return {"Authorization": _authorization_value(text)} + if not isinstance(parsed, dict): + return {} + headers = parsed.get("headers") + if isinstance(headers, dict): + return {str(key): str(value) for key, value in headers.items()} + if parsed.get("authorization"): + return {"Authorization": str(parsed["authorization"])} + if parsed.get("bearer_token"): + return {"Authorization": f"Bearer {parsed['bearer_token']}"} + if parsed.get("api_key"): + return {"Authorization": f"ApiKey {parsed['api_key']}"} + return {} + + @staticmethod + def _parse_response(payload: dict[str, Any]) -> list[FactHistoryRecord]: + hits = payload.get("hits", {}).get("hits", []) + if not isinstance(hits, list): + return [] + records = [] + for hit in hits: + if not isinstance(hit, dict): + continue + source = hit.get("_source") if isinstance(hit.get("_source"), dict) else {} + fields = hit.get("fields") if isinstance(hit.get("fields"), dict) else {} + cluster = _field_value(source, fields, "cluster") + node = _field_value(source, fields, "node") + episode_id = _field_value(source, fields, "episode_id") + event_time = _parse_datetime(_field_value(source, fields, "event_time")) + if not cluster or not node or not episode_id or event_time is None: + continue + records.append( + FactHistoryRecord( + cluster=str(cluster), + node=str(node), + episode_id=str(episode_id), + event_time=event_time, + ) + ) + return records + + +def _field_value(source: dict[str, Any], fields: dict[str, Any], name: str) -> Any: + if name in source: + value = source[name] + else: + value = fields.get(name) + if isinstance(value, list): + return value[0] if value else None + return value + + +def _parse_datetime(value: Any) -> datetime | None: + if isinstance(value, datetime): + return _ensure_aware(value) + if not value: + return None + text = str(value).strip() + if text.endswith("Z"): + text = f"{text[:-1]}+00:00" + try: + return _ensure_aware(datetime.fromisoformat(text)) + except ValueError: + return None + + +def _ensure_aware(value: datetime) -> datetime: + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + +def _isoformat(value: datetime) -> str: + return _ensure_aware(value).isoformat() + + +def _authorization_value(text: str) -> str: + for prefix in ("Bearer ", "ApiKey ", "Basic "): + if text.startswith(prefix): + return text + return f"Bearer {text}" diff --git a/src/nvidia_resiliency_ext/attribution/fact/hot_cache.py b/src/nvidia_resiliency_ext/attribution/fact/hot_cache.py new file mode 100644 index 00000000..d72b4fdb --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/hot_cache.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import OrderedDict +from datetime import datetime +from typing import Iterable + +from nvidia_resiliency_ext.attribution.fact.models import FactHistoryRecord, HotCacheEpisode + + +class FactHotCache: + """In-memory current-process FACT episode overlay.""" + + def __init__(self, max_episodes: int = 4096) -> None: + self.max_episodes = max(1, int(max_episodes)) + self._episodes: "OrderedDict[tuple[str, str, str], HotCacheEpisode]" = OrderedDict() + + def add_episode(self, episode: HotCacheEpisode) -> None: + key = (episode.cluster, episode.node, episode.episode_id) + self._episodes.pop(key, None) + self._episodes[key] = episode + while len(self._episodes) > self.max_episodes: + self._episodes.popitem(last=False) + + def add_current_cycle( + self, + *, + cluster: str, + nodes: Iterable[str], + job_id: str, + cycle_id: int, + event_time: datetime, + ) -> None: + episode_id = f"{job_id}_{cycle_id}" + for node in sorted({str(node) for node in nodes if str(node)}): + self.add_episode( + HotCacheEpisode( + cluster=cluster, + node=node, + episode_id=episode_id, + event_time=event_time, + ) + ) + + def records_for( + self, + *, + cluster: str, + nodes: Iterable[str], + before: datetime, + ) -> list[FactHistoryRecord]: + node_set = {str(node) for node in nodes if str(node)} + records = [] + for episode in self._episodes.values(): + if episode.cluster != cluster or episode.node not in node_set: + continue + if episode.event_time >= before: + continue + records.append( + FactHistoryRecord( + cluster=episode.cluster, + node=episode.node, + episode_id=episode.episode_id, + event_time=episode.event_time, + ) + ) + return records diff --git a/src/nvidia_resiliency_ext/attribution/fact/manager.py b/src/nvidia_resiliency_ext/attribution/fact/manager.py new file mode 100644 index 00000000..7a5bafd1 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/manager.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Launcher-side lifecycle management for nvrx-fact-agent.""" + +from __future__ import annotations + +import contextlib +import logging +import os +import shutil +import subprocess # nosec B404 +import sys +import tempfile +import time +from dataclasses import dataclass +from typing import Optional + +from nvidia_resiliency_ext.attribution.fact.rpc import notify_fact_agent +from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig + +logger = logging.getLogger(LogConfig.name) + +DEFAULT_FACT_AGENT_STARTUP_TIMEOUT = 5.0 +_FACT_AGENT_STOP_TIMEOUT = 10.0 +_FACT_AGENT_READY_POLL_INTERVAL = 0.1 + + +@dataclass(frozen=True) +class FactAgentEndpoint: + socket_path: str + + +class FactAgentManager: + """Start and stop one local nvrx-fact-agent process for this launcher.""" + + def __init__( + self, + *, + fact_url: Optional[str], + socket_path: Optional[str] = None, + rpc_timeout_s: float = 2.0, + startup_timeout_s: float = DEFAULT_FACT_AGENT_STARTUP_TIMEOUT, + log_file: Optional[str] = None, + run_id: Optional[str] = None, + rdzv_endpoint: Optional[str] = None, + store_timeout_s: Optional[float] = None, + local_node: Optional[str] = None, + is_store_host: bool = False, + job_id: Optional[str] = None, + ranks_per_node: Optional[int] = None, + username: Optional[str] = None, + cluster: Optional[str] = None, + health_log_prefix: Optional[str] = None, + dmesg_artifact_enabled: bool = False, + result_artifact_enabled: bool = False, + grpc_server_address: Optional[str] = None, + grpc_node_id: Optional[str] = None, + fact_history_es_url: Optional[str] = None, + fact_history_es_auth_file: Optional[str] = None, + fact_history_lookback: Optional[str] = None, + fact_history_index: Optional[str] = None, + fact_history_max_candidate_nodes: Optional[int] = None, + fact_history_query_timeout_s: Optional[float] = None, + fact_min_repeat_count_for_avoid: Optional[int] = None, + fact_max_attribution_avoids_per_cycle: Optional[int] = None, + ) -> None: + self.fact_url = str(fact_url).strip() if fact_url else None + self.socket_path = socket_path or _managed_fact_agent_socket_path() + self.rpc_timeout_s = max(0.1, float(rpc_timeout_s)) + self.startup_timeout_s = max(0.1, float(startup_timeout_s)) + self.log_file = log_file or _managed_fact_agent_log_path() + self.run_id = str(run_id).strip() if run_id else None + self.rdzv_endpoint = str(rdzv_endpoint).strip() if rdzv_endpoint else None + self.store_timeout_s = float(store_timeout_s) if store_timeout_s is not None else None + self.local_node = str(local_node).strip() if local_node else None + self.is_store_host = bool(is_store_host) + self.job_id = str(job_id).strip() if job_id else None + self.ranks_per_node = max(1, int(ranks_per_node)) if ranks_per_node is not None else None + self.username = str(username).strip() if username else None + self.cluster = str(cluster).strip() if cluster else None + self.health_log_prefix = str(health_log_prefix).strip() if health_log_prefix else None + self.dmesg_artifact_enabled = bool(dmesg_artifact_enabled) + self.result_artifact_enabled = bool(result_artifact_enabled) + self.grpc_server_address = str(grpc_server_address).strip() if grpc_server_address else None + self.grpc_node_id = str(grpc_node_id).strip() if grpc_node_id else None + self.fact_history_es_url = str(fact_history_es_url).strip() if fact_history_es_url else None + self.fact_history_es_auth_file = ( + str(fact_history_es_auth_file).strip() if fact_history_es_auth_file else None + ) + self.fact_history_lookback = ( + str(fact_history_lookback).strip() if fact_history_lookback else None + ) + self.fact_history_index = str(fact_history_index).strip() if fact_history_index else None + self.fact_history_max_candidate_nodes = fact_history_max_candidate_nodes + self.fact_history_query_timeout_s = fact_history_query_timeout_s + self.fact_min_repeat_count_for_avoid = fact_min_repeat_count_for_avoid + self.fact_max_attribution_avoids_per_cycle = fact_max_attribution_avoids_per_cycle + self.process: Optional[subprocess.Popen] = None + + @property + def is_enabled(self) -> bool: + return self.fact_url is not None + + def start_if_needed(self) -> Optional[FactAgentEndpoint]: + if not self.is_enabled: + return None + + assert self.fact_url is not None + cmd = _fact_agent_command() + [ + "--fact-url", + self.fact_url, + "--socket-path", + self.socket_path, + ] + cmd.extend(self._session_args()) + logger.info( + "Starting local nvrx-fact-agent (socket_path=%s, log_file=%s)", + self.socket_path, + self.log_file, + ) + + os.makedirs(os.path.dirname(self.socket_path) or ".", exist_ok=True) + os.makedirs(os.path.dirname(self.log_file) or ".", exist_ok=True) + log_fd = open(self.log_file, "w", encoding="utf-8") + try: + self.process = subprocess.Popen( # nosec B603 + cmd, + stdout=log_fd, + stderr=subprocess.STDOUT, + env=os.environ.copy(), + shell=False, + ) + finally: + log_fd.close() + + try: + self._wait_until_ready() + except Exception: + self.stop() + raise + + logger.info( + "nvrx-fact-agent is ready: PID=%s socket_path=%s", + self.process.pid if self.process else None, + self.socket_path, + ) + return FactAgentEndpoint(socket_path=self.socket_path) + + def stop(self) -> None: + proc = self.process + if proc is None: + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + return + if proc.poll() is not None: + logger.info( + "nvrx-fact-agent PID=%s already exited with returncode=%s", + proc.pid, + proc.returncode, + ) + self.process = None + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + return + + try: + notify_fact_agent( + socket_path=self.socket_path, + payload={"event": "shutdown"}, + timeout_s=self.rpc_timeout_s, + ) + logger.info("Requested graceful nvrx-fact-agent shutdown for PID=%s", proc.pid) + except Exception as exc: + logger.info( + "Graceful nvrx-fact-agent shutdown request failed for PID=%s: %s", + proc.pid, + exc, + ) + logger.info("Sending SIGTERM to nvrx-fact-agent PID=%s", proc.pid) + with contextlib.suppress(Exception): + proc.terminate() + try: + proc.wait(timeout=_FACT_AGENT_STOP_TIMEOUT) + except subprocess.TimeoutExpired: + logger.warning( + "nvrx-fact-agent PID=%s did not exit within %.0fs; killing", + proc.pid, + _FACT_AGENT_STOP_TIMEOUT, + ) + with contextlib.suppress(Exception): + proc.kill() + with contextlib.suppress(Exception): + proc.wait() + logger.info("nvrx-fact-agent PID=%s finished with returncode=%s", proc.pid, proc.returncode) + self.process = None + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + + def _wait_until_ready(self) -> None: + assert self.process is not None + deadline = time.monotonic() + self.startup_timeout_s + last_error = "not probed" + while time.monotonic() < deadline: + rc = self.process.poll() + if rc is not None: + raise RuntimeError( + f"nvrx-fact-agent exited before becoming ready " + f"(returncode={rc}, log_file={self.log_file})" + ) + try: + ack = notify_fact_agent( + socket_path=self.socket_path, + payload={"event": "ping"}, + timeout_s=self.rpc_timeout_s, + ) + if ack.get("accepted"): + return + last_error = f"ping rejected: {ack}" + except Exception as exc: + last_error = f"{type(exc).__name__}: {exc}" + time.sleep(_FACT_AGENT_READY_POLL_INTERVAL) + + raise TimeoutError( + f"nvrx-fact-agent did not become ready within {self.startup_timeout_s:.1f}s " + f"at {self.socket_path} (last_error={last_error}, log_file={self.log_file})" + ) + + def _session_args(self) -> list[str]: + args: list[str] = [] + if self.run_id: + args.extend(["--run-id", self.run_id]) + if self.rdzv_endpoint: + args.extend(["--rdzv-endpoint", self.rdzv_endpoint]) + if self.store_timeout_s is not None: + args.extend(["--store-timeout", str(self.store_timeout_s)]) + if self.local_node: + args.extend(["--local-node", self.local_node]) + if self.is_store_host: + args.append("--is-store-host") + if self.job_id: + args.extend(["--job-id", self.job_id]) + if self.ranks_per_node is not None: + args.extend(["--ranks-per-node", str(self.ranks_per_node)]) + if self.username: + args.extend(["--username", self.username]) + if self.cluster: + args.extend(["--cluster", self.cluster]) + if self.health_log_prefix: + args.extend(["--health-log-prefix", self.health_log_prefix]) + if self.dmesg_artifact_enabled: + args.append("--dmesg-artifact-enabled") + if self.result_artifact_enabled: + args.append("--result-artifact-enabled") + if self.grpc_server_address: + args.extend(["--grpc-server-address", self.grpc_server_address]) + if self.grpc_node_id: + args.extend(["--grpc-node-id", self.grpc_node_id]) + if self.fact_history_es_url: + args.extend(["--fact-history-es-url", self.fact_history_es_url]) + if self.fact_history_es_auth_file: + args.extend(["--fact-history-es-auth-file", self.fact_history_es_auth_file]) + if self.fact_history_lookback: + args.extend(["--fact-history-lookback", self.fact_history_lookback]) + if self.fact_history_index: + args.extend(["--fact-history-index", self.fact_history_index]) + if self.fact_history_max_candidate_nodes is not None: + args.extend( + [ + "--fact-history-max-candidate-nodes", + str(self.fact_history_max_candidate_nodes), + ] + ) + if self.fact_history_query_timeout_s is not None: + args.extend(["--fact-history-query-timeout", str(self.fact_history_query_timeout_s)]) + if self.fact_min_repeat_count_for_avoid is not None: + args.extend( + [ + "--fact-min-repeat-count-for-avoid", + str(self.fact_min_repeat_count_for_avoid), + ] + ) + if self.fact_max_attribution_avoids_per_cycle is not None: + args.extend( + [ + "--fact-max-attribution-avoids-per-cycle", + str(self.fact_max_attribution_avoids_per_cycle), + ] + ) + return args + + +def _fact_agent_command() -> list[str]: + exe = shutil.which("nvrx-fact-agent") + if exe: + return [exe] + return [sys.executable, "-m", "nvidia_resiliency_ext.attribution.fact.agent"] + + +def _managed_fact_agent_socket_path() -> str: + return os.path.join( + tempfile.gettempdir(), + f"nvrx-fact-agent-{os.getuid()}-{os.getpid()}.sock", + ) + + +def _managed_fact_agent_log_path() -> str: + return os.path.join( + tempfile.gettempdir(), + f"nvrx-fact-agent-{os.getuid()}-{os.getpid()}.log", + ) diff --git a/src/nvidia_resiliency_ext/attribution/fact/models.py b/src/nvidia_resiliency_ext/attribution/fact/models.py new file mode 100644 index 00000000..80f69647 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/models.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime + + +@dataclass(frozen=True) +class FactHistoryRecord: + cluster: str + node: str + episode_id: str + event_time: datetime + + +@dataclass(frozen=True) +class HotCacheEpisode: + cluster: str + node: str + episode_id: str + event_time: datetime + + +@dataclass(frozen=True) +class AvoidCandidate: + node: str + repeat_count: int + prior_last_seen: datetime | None = None + + +@dataclass(frozen=True) +class AvoidDecision: + cycle_id: int + status: str + avoid_nodes: list[str] = field(default_factory=list) + candidates: list[AvoidCandidate] = field(default_factory=list) diff --git a/src/nvidia_resiliency_ext/attribution/fact/repeat_offender_policy.py b/src/nvidia_resiliency_ext/attribution/fact/repeat_offender_policy.py new file mode 100644 index 00000000..61b12938 --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/repeat_offender_policy.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from datetime import datetime +from typing import Iterable + +from nvidia_resiliency_ext.attribution.fact.models import ( + AvoidCandidate, + AvoidDecision, + FactHistoryRecord, +) + + +def compute_repeat_offender_decision( + *, + cycle_id: int, + current_suspect_nodes: Iterable[str], + history_records: Iterable[FactHistoryRecord], + hot_cache_records: Iterable[FactHistoryRecord], + max_candidate_nodes: int, + min_repeat_count_for_avoid: int, + max_avoids_per_cycle: int, +) -> AvoidDecision: + suspects = sorted({str(node) for node in current_suspect_nodes if str(node)}) + if not suspects: + return AvoidDecision(cycle_id=cycle_id, status="skipped") + + if len(suspects) > max_candidate_nodes: + return AvoidDecision(cycle_id=cycle_id, status="skipped") + + prior_by_node: dict[str, dict[tuple[str, str, str], FactHistoryRecord]] = { + node: {} for node in suspects + } + for record in list(history_records) + list(hot_cache_records): + if record.node not in prior_by_node: + continue + prior_by_node[record.node][(record.cluster, record.node, record.episode_id)] = record + + candidates = [] + for node in suspects: + prior_records = list(prior_by_node[node].values()) + repeat_count = 1 + len(prior_records) + if repeat_count < min_repeat_count_for_avoid: + continue + prior_last_seen = _max_event_time(prior_records) + candidates.append( + AvoidCandidate( + node=node, + repeat_count=repeat_count, + prior_last_seen=prior_last_seen, + ) + ) + + candidates.sort( + key=lambda candidate: ( + -candidate.repeat_count, + -_timestamp(candidate.prior_last_seen), + candidate.node, + ) + ) + avoid_nodes = [candidate.node for candidate in candidates[: max(0, max_avoids_per_cycle)]] + return AvoidDecision( + cycle_id=cycle_id, + status="ready", + avoid_nodes=avoid_nodes, + candidates=candidates, + ) + + +def _max_event_time(records: Iterable[FactHistoryRecord]) -> datetime | None: + event_times = [record.event_time for record in records] + if not event_times: + return None + return max(event_times) + + +def _timestamp(value: datetime | None) -> float: + if value is None: + return 0.0 + return value.timestamp() diff --git a/src/nvidia_resiliency_ext/attribution/fact/rpc.py b/src/nvidia_resiliency_ext/attribution/fact/rpc.py new file mode 100644 index 00000000..8a76d11f --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/fact/rpc.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import os +import socket +import struct +import tempfile +from typing import Any + +DEFAULT_RPC_TIMEOUT_S = 2.0 +DEFAULT_MAX_RPC_BYTES = 16 * 1024 * 1024 + + +def default_socket_path() -> str: + return os.path.join(tempfile.gettempdir(), f"nvrx-fact-agent-{os.getuid()}.sock") + + +def json_dumps(payload: Any) -> bytes: + return json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8") + + +def send_frame(conn: socket.socket, payload: dict[str, Any]) -> None: + body = json_dumps(payload) + conn.sendall(struct.pack("!I", len(body)) + body) + + +def recv_exact(conn: socket.socket, size: int) -> bytes: + chunks = [] + remaining = size + while remaining > 0: + chunk = conn.recv(remaining) + if not chunk: + raise EOFError("socket closed while reading frame") + chunks.append(chunk) + remaining -= len(chunk) + return b"".join(chunks) + + +def recv_frame(conn: socket.socket, *, max_bytes: int) -> dict[str, Any]: + raw_size = recv_exact(conn, 4) + size = struct.unpack("!I", raw_size)[0] + if size > max_bytes: + raise ValueError(f"UDS request is too large: {size} bytes > {max_bytes} bytes") + body = recv_exact(conn, size) + payload = json.loads(body.decode("utf-8")) + if not isinstance(payload, dict): + raise ValueError("UDS request must be a JSON object") + return payload + + +def notify_fact_agent( + *, + socket_path: str, + payload: dict[str, Any], + timeout_s: float = DEFAULT_RPC_TIMEOUT_S, + max_bytes: int = DEFAULT_MAX_RPC_BYTES, +) -> dict[str, Any]: + """Send one framed JSON RPC to the local FACT agent and return its ACK.""" + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.settimeout(timeout_s) + sock.connect(socket_path) + send_frame(sock, payload) + return recv_frame(sock, max_bytes=max_bytes) diff --git a/src/nvidia_resiliency_ext/attribution/policy/__init__.py b/src/nvidia_resiliency_ext/attribution/policy/__init__.py new file mode 100644 index 00000000..24a4d71e --- /dev/null +++ b/src/nvidia_resiliency_ext/attribution/policy/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Attribution-driven placement policy notes.""" diff --git a/src/nvidia_resiliency_ext/fault_tolerance/__init__.py b/src/nvidia_resiliency_ext/fault_tolerance/__init__.py index c3a9cc28..3f76ab43 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/__init__.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from .config import FaultToleranceConfig # noqa: F401 +from .config import HealthLoggingConfig # noqa: F401 +from .config import HealthLogSourceConfig # noqa: F401 from .data import WorkloadAction # noqa: F401 from .data import WorkloadControlRequest # noqa: F401 from .rank_monitor_client import RankMonitorClient # noqa: F401 diff --git a/src/nvidia_resiliency_ext/fault_tolerance/cli_args.py b/src/nvidia_resiliency_ext/fault_tolerance/cli_args.py index d1bfba94..b1e437d8 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/cli_args.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/cli_args.py @@ -336,10 +336,12 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: default=None, dest="ft_attribution_endpoint", help=( - "Endpoint for the attribution service. Default: disabled. " + "Endpoint for the application-log attribution service that returns job-level " + "restart recommendations such as STOP/RESTART. Default: disabled. " "Set to localhost to let the TCPStore host process manage nvrx-attrsvc. " "Use an explicit endpoint such as http://host:port, grpc://host:port, " - "or unix:///path/to/socket for an externally managed attribution service." + "or unix:///path/to/socket for an externally managed service. " + "This is separate from FACT node attribution configured with --ft-fact-url." ), ) parser.add_argument( @@ -349,7 +351,8 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: default=DEFAULT_ATTRIBUTION_STARTUP_TIMEOUT, dest="ft_attribution_startup_timeout", help=( - "Seconds to wait for launcher-managed attribution service /healthz readiness. " + "Seconds to wait for launcher-managed application-log attribution service " + "/healthz readiness. " f"Default: {DEFAULT_ATTRIBUTION_STARTUP_TIMEOUT}." ), ) @@ -360,7 +363,8 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: default=None, dest="ft_attribution_llm_api_key_file", help=( - "Path to the LLM API key file for launcher-managed attribution service. " + "Path to the LLM API key file for launcher-managed application-log " + "attribution service. " "If unset, LLM_API_KEY_FILE from the launcher environment is used." ), ) @@ -370,7 +374,7 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: type=str, default=None, dest="ft_attribution_llm_base_url", - help="LLM base URL for launcher-managed attribution service.", + help="LLM base URL for launcher-managed application-log attribution service.", ) parser.add_argument( "--ft-attribution-llm-model", @@ -378,7 +382,7 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: type=str, default=None, dest="ft_attribution_llm_model", - help="LLM model identifier for launcher-managed attribution service.", + help="LLM model identifier for launcher-managed application-log attribution service.", ) parser.add_argument( "--ft-attribution-analysis-backend", @@ -387,7 +391,7 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: default=None, dest="ft_attribution_analysis_backend", choices=("mcp", "lib"), - help="Analysis backend for launcher-managed attribution service: mcp or lib.", + help="Analysis backend for launcher-managed application-log attribution service: mcp or lib.", ) parser.add_argument( "--ft-attribution-compute-timeout", @@ -395,7 +399,10 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: type=float, default=None, dest="ft_attribution_compute_timeout", - help="Analysis compute timeout in seconds for launcher-managed attribution service.", + help=( + "Analysis compute timeout in seconds for launcher-managed application-log " + "attribution service." + ), ) parser.add_argument( "--ft-attribution-log-level", @@ -404,7 +411,7 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: default=None, dest="ft_attribution_log_level", choices=("DEBUG", "INFO", "WARNING"), - help="Log level for launcher-managed attribution service.", + help="Log level for launcher-managed application-log attribution service.", ) parser.add_argument( "--ft-attribution-export-url", @@ -412,7 +419,111 @@ def _add_attribution_args(parser: argparse.ArgumentParser) -> None: type=str, default=None, dest="ft_attribution_export_url", - help="Complete result export URL for launcher-managed attribution service.", + help=( + "Complete result export URL for launcher-managed application-log " + "attribution service." + ), + ) + parser.add_argument( + "--ft-fact-url", + type=str, + default=None, + dest="ft_fact_url", + help=( + "FACT API URL used by nvrx-fact-agent for node-level attribution from host " + "evidence, e.g. http://host:8001/latest. ft_launcher starts the local " + "agent and passes this URL to it. Separate from --ft-attribution-endpoint." + ), + ) + parser.add_argument( + "--ft-fact-agent-socket-path", + type=str, + default=None, + dest="ft_fact_agent_socket_path", + help=( + "Advanced override for the local UDS path used by launcher-managed " + "nvrx-fact-agent. Defaults to a private per-launcher tmp path." + ), + ) + parser.add_argument( + "--ft-fact-agent-rpc-timeout", + type=float, + default=None, + dest="ft_fact_agent_rpc_timeout", + help="Timeout in seconds for the local nvrx-fact-agent ACK. Default: 2.", + ) + parser.add_argument( + "--ft-fact-policy-ready-timeout", + type=float, + default=None, + dest="ft_fact_policy_ready_timeout", + help=( + "Maximum seconds for the store-host launcher to wait for FACT avoid_nodes " + "before rendezvous fails open. Default: 60." + ), + ) + parser.add_argument( + "--ft-fact-agent-store-timeout", + type=float, + default=None, + dest="ft_fact_agent_store_timeout", + help="Timeout in seconds for nvrx-fact-agent TCPStore reads. Default: 60.", + ) + parser.add_argument( + "--ft-fact-history-es-url", + type=str, + default=None, + dest="ft_fact_history_es_url", + help="FACT history backend URL used for repeat-offender avoid policy.", + ) + parser.add_argument( + "--ft-fact-history-es-auth-file", + type=str, + default=None, + dest="ft_fact_history_es_auth_file", + help="Auth file for FACT history backend. Contents are read by nvrx-fact-agent.", + ) + parser.add_argument( + "--ft-fact-history-lookback", + type=str, + default=None, + dest="ft_fact_history_lookback", + help="FACT history lookback window for repeat-offender policy. Default: 14d.", + ) + parser.add_argument( + "--ft-fact-history-index", + type=str, + default=None, + dest="ft_fact_history_index", + help="Deployment-specific FACT node-history index or collection.", + ) + parser.add_argument( + "--ft-fact-history-max-candidate-nodes", + type=int, + default=None, + dest="ft_fact_history_max_candidate_nodes", + help="Skip history lookup when current FACT suspects exceed this count. Default: 16.", + ) + parser.add_argument( + "--ft-fact-history-query-timeout", + type=float, + default=None, + dest="ft_fact_history_query_timeout", + help="FACT history query timeout in seconds. Default: 30.", + ) + parser.add_argument( + "--ft-fact-min-repeat-count-for-avoid", + type=int, + default=None, + dest="ft_fact_min_repeat_count_for_avoid", + help="Minimum current+prior same-node count required to avoid a node. Default: 2.", + ) + parser.add_argument( + "--ft-fact-max-attribution-avoids-per-cycle", + type=int, + default=None, + dest="ft_fact_max_attribution_avoids_per_cycle", + help="Maximum attribution-based avoid nodes per cycle. Default: 1.", ) diff --git a/src/nvidia_resiliency_ext/fault_tolerance/config.py b/src/nvidia_resiliency_ext/fault_tolerance/config.py index bf5377bd..87abc2c4 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/config.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/config.py @@ -19,11 +19,32 @@ import logging import signal from dataclasses import dataclass, fields -from typing import Mapping, Optional +from typing import Any, Mapping, Optional import yaml +@dataclass +class HealthLogSourceConfig: + """Enablement for one FACT agent file output source.""" + + enabled: bool = False + + +@dataclass +class HealthLoggingConfig: + """Optional per-cycle evidence file configuration.""" + + prefix: Optional[str] = None + dmesg: HealthLogSourceConfig = dataclasses.field(default_factory=HealthLogSourceConfig) + fact_result: HealthLogSourceConfig = dataclasses.field(default_factory=HealthLogSourceConfig) + healthcheck: HealthLogSourceConfig = dataclasses.field(default_factory=HealthLogSourceConfig) + + @property + def is_any_source_enabled(self) -> bool: + return self.dmesg.enabled or self.fact_result.enabled or self.healthcheck.enabled + + @dataclass class FaultToleranceConfig: """ @@ -95,10 +116,33 @@ class FaultToleranceConfig: out-of-section timeouts. The first N iterations (relative to cycle start) are excluded from timeout monitoring as they can be significantly slower than steady-state iterations. Default: 5. Can be overridden by workload (e.g., Megatron-LM via init_workload_monitoring). - * Attribution service (optional, disabled unless `attribution_endpoint` is set): - - `attribution_endpoint` [str] endpoint of the attribution service + * Application-log attribution service (optional, disabled unless + `attribution_endpoint` is set): + - `attribution_endpoint` [str] endpoint of the service that returns + job-level restart recommendations such as STOP/RESTART - `attribution_export_url` [str] complete export posting URL for - launcher-managed attribution service postprocessing. + launcher-managed application-log attribution service postprocessing. + * FACT agent (optional): + - `fact_url` [str] FACT API URL used by `nvrx-fact-agent` for node-level + attribution from host evidence. The launcher starts the local agent + and passes this URL to it. + - `fact_agent_socket_path` [str|None] optional local UDS path override + for the launcher-managed `nvrx-fact-agent` + - `fact_agent_rpc_timeout` [float] timeout for local UDS ACK + - `fact_policy_ready_timeout` [float] maximum time to wait for an + avoid-node policy decision before rendezvous fails open + - `fact_agent_store_timeout` [float] timeout used by the agent for TCPStore reads + - `fact_history_es_url` [str|None] FACT history backend URL for + repeat-offender avoid policy + - `fact_history_es_auth_file` [str|None] auth file for the FACT history backend + * Dmesg evidence files (optional): + - `health_logging.prefix` [str] absolute file prefix for per-cycle + shared evidence artifacts + - `health_logging.dmesg.enabled` [bool] asks `nvrx-fact-agent` to write + the collected dmesg text to a shared per-cycle file on failed cycles + - `health_logging.fact_result.enabled` [bool] asks `nvrx-fact-agent` to + write per-node FACT submission records and the store-host FACT result + record through the launcher gRPC log funnel * `cycle_info_dir` [str|None] Full path to the NVRx cycle info directory (e.g. /nvrx/). If set, the rendezvous host writes cycle info JSON files and @@ -149,9 +193,25 @@ class FaultToleranceConfig: # Attribution service configuration (optional) attribution_endpoint: Optional[str] = None attribution_export_url: Optional[str] = None + # FACT agent configuration (optional) + fact_url: Optional[str] = None + fact_agent_socket_path: Optional[str] = None + fact_agent_rpc_timeout: float = 2.0 + fact_policy_ready_timeout: float = 60.0 + fact_agent_store_timeout: float = 60.0 + fact_history_es_url: Optional[str] = None + fact_history_es_auth_file: Optional[str] = None + fact_history_lookback: str = "14d" + fact_history_index: Optional[str] = None + fact_history_max_candidate_nodes: int = 16 + fact_history_query_timeout: float = 30.0 + fact_min_repeat_count_for_avoid: int = 2 + fact_max_attribution_avoids_per_cycle: int = 1 # NVRx cycle info: base directory for cycle_info JSON files cycle_info_dir: Optional[str] = None + # Standalone health logging configuration + health_logging: HealthLoggingConfig = dataclasses.field(default_factory=HealthLoggingConfig) @property def is_progress_tracking_enabled(self) -> bool: @@ -178,6 +238,11 @@ def from_kwargs(ignore_not_recognized: bool = True, **kwargs) -> 'FaultTolerance extra_args = {k: v for k, v in kwargs.items() if k not in fields_set} if extra_args and not ignore_not_recognized: raise ValueError(f"Not recognized args: {extra_args}") + health_logging = matching_args.get("health_logging") + if isinstance(health_logging, dict): + matching_args["health_logging"] = FaultToleranceConfig._parse_health_logging_dict( + health_logging + ) return FaultToleranceConfig(**matching_args) @staticmethod @@ -289,6 +354,10 @@ def from_args(args: argparse.Namespace): 'gpu_memory_reclaim_timeout', 'gpu_memory_tolerance_mb', 'gpu_memory_poll_interval', + 'fact_agent_rpc_timeout', + 'fact_policy_ready_timeout', + 'fact_agent_store_timeout', + 'fact_history_query_timeout', ] for field in fields(FaultToleranceConfig): cli_field_name = f"ft_{field.name}" @@ -304,12 +373,60 @@ def from_args(args: argparse.Namespace): for arg_name, arg_val in cli_ft_args.items(): setattr(ft_cfg, arg_name, arg_val) + # Health logging CLI overrides + health_logging = ft_cfg.health_logging + prefix = getattr(args, "ft_health_log_prefix", None) + if prefix is not None: + health_logging.prefix = prefix + dmesg_enabled = getattr(args, "ft_enable_health_log_dmesg", None) + if dmesg_enabled is not None: + health_logging.dmesg.enabled = bool(dmesg_enabled) + fact_result_enabled = getattr(args, "ft_enable_fact_result_artifact", None) + if fact_result_enabled is not None: + health_logging.fact_result.enabled = bool(fact_result_enabled) + ft_cfg.health_logging = health_logging + # Fix any type issues ft_cfg._fix_log_level_type() ft_cfg._fix_rank_termination_signal_type() return ft_cfg + @staticmethod + def _parse_health_logging_dict(data: Any) -> HealthLoggingConfig: + if isinstance(data, HealthLoggingConfig): + return data + if data is None: + return HealthLoggingConfig() + if not isinstance(data, dict): + raise ValueError( + f"Invalid health_logging config: expected mapping, got {type(data).__name__}" + ) + + prefix = data.get("prefix") + dmesg_cfg = data.get("dmesg", {}) + fact_result_cfg = data.get("fact_result", {}) + healthcheck_cfg = data.get("healthcheck", {}) + + def _parse_source(source_name: str, raw: Any) -> HealthLogSourceConfig: + if raw is None: + return HealthLogSourceConfig() + if isinstance(raw, HealthLogSourceConfig): + return raw + if not isinstance(raw, dict): + raise ValueError( + f"Invalid health_logging.{source_name} config: expected mapping, got {type(raw).__name__}" + ) + enabled = raw.get("enabled", False) + return HealthLogSourceConfig(enabled=bool(enabled)) + + return HealthLoggingConfig( + prefix=prefix, + dmesg=_parse_source("dmesg", dmesg_cfg), + fact_result=_parse_source("fact_result", fact_result_cfg), + healthcheck=_parse_source("healthcheck", healthcheck_cfg), + ) + def to_yaml_file(self, cfg_path: str) -> None: """ Convert the configuration object to a YAML file and save it to the specified path. diff --git a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py index 7c76a1c3..3b7fd003 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/ft_rendezvous_barrier.py @@ -863,6 +863,62 @@ def with_synthetic_infra_ranks(participants: List[Participant]) -> List[Particip sorted(standby_only_participants, key=lambda x: x[1]), ) + @staticmethod + def _participant_node_key(node_desc: _NodeDesc) -> str: + return str(node_desc.addr).split(":", 1)[0] + + def _fact_avoid_nodes(self) -> List[str]: + agent = self._agent + if agent is None: + return [] + get_avoid_nodes = getattr(agent, "get_fact_avoid_nodes_for_rendezvous", None) + if not callable(get_avoid_nodes): + return [] + try: + return [str(node).split(":", 1)[0] for node in get_avoid_nodes() if str(node)] + except Exception as exc: + log.info("FACT avoid-node lookup failed; proceeding without avoid nodes: %s", exc) + return [] + + def _apply_fact_avoid_nodes( + self, + active_candidate_participants: List[Participant], + standby_only_participants: List[Participant], + min_nodes: int, + ) -> Tuple[List[Participant], List[Participant]]: + avoid_nodes = self._fact_avoid_nodes() + if not avoid_nodes: + return active_candidate_participants, standby_only_participants + + avoid_set = set(avoid_nodes) + kept_active = [] + demoted = [] + for participant in active_candidate_participants: + node_desc, _, _, _ = participant + if self._participant_node_key(node_desc) in avoid_set: + demoted.append(participant) + else: + kept_active.append(participant) + + if not demoted: + return active_candidate_participants, standby_only_participants + + if len(kept_active) < min_nodes or not self._can_meet_segment_constraint( + kept_active, + min_nodes, + ): + log.info( + "Skipping FACT avoid_nodes=%s because placement would be infeasible", + avoid_nodes, + ) + return active_candidate_participants, standby_only_participants + + log.info( + "Applying FACT avoid_nodes=%s by assigning them standby ranks", + [self._participant_node_key(node_desc) for node_desc, _, _, _ in demoted], + ) + return kept_active, demoted + standby_only_participants + def _assign_group_ranks( self, active_candidate_participants: List[Participant], @@ -1530,6 +1586,11 @@ def _host_close_round( f"participants{replacement_group_info}{unhealthy_replacement_group_info}, " f"min={min_nodes} (fetch {fetch_elapsed*1000:.1f}ms, check {check_elapsed*1000:.1f}ms)" ) + active_candidate_participants, standby_only_participants = self._apply_fact_avoid_nodes( + complete_replacement_group_participants, + incomplete_replacement_group_participants, + min_nodes, + ) # Assign ranks BEFORE setting round_done=1 so Step 3 readers can get their rank # immediately without any additional waiting. self.assign_group_ranks( @@ -1537,8 +1598,8 @@ def _host_close_round( max_nodes, node_desc, slot_participants=current_round_participants, - active_candidate_participants=complete_replacement_group_participants, - standby_only_participants=incomplete_replacement_group_participants, + active_candidate_participants=active_candidate_participants, + standby_only_participants=standby_only_participants, ) self._report_cycle_start_as_host(self._round) self.store.set(self.round_done_key, "1".encode('utf-8')) diff --git a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py index 57bd7f99..8eeb6690 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/launcher.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/launcher.py @@ -35,6 +35,7 @@ from argparse import REMAINDER, ArgumentParser from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field +from datetime import datetime, timezone from string import Template from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union @@ -69,6 +70,12 @@ ) from torch.distributed.elastic.utils import macros +from nvidia_resiliency_ext.attribution.fact.client import normalize_fact_attribution_url +from nvidia_resiliency_ext.attribution.fact.manager import FactAgentManager +from nvidia_resiliency_ext.attribution.fact.rpc import ( + default_socket_path as default_fact_agent_socket_path, +) +from nvidia_resiliency_ext.attribution.fact.rpc import notify_fact_agent from nvidia_resiliency_ext.fault_tolerance.attribution_manager import ( DEFAULT_ATTRIBUTION_PORT, AttributionConfig, @@ -132,6 +139,16 @@ # [root, leaf0, ..., leaf_{N-1}] when ft_log_aggregator_count > 1. _GRPC_SERVER_PROCESSES: List[subprocess.Popen] = [] _ATTRIBUTION_MANAGER: Optional[AttributionManager] = None +_FACT_AGENT_MANAGER: Optional[FactAgentManager] = None + + +def _resolve_fact_username() -> Optional[str]: + return os.environ.get("SLURM_JOB_USER") or os.environ.get("USER") or os.environ.get("LOGNAME") + + +def _resolve_fact_cluster() -> Optional[str]: + return os.environ.get("SLURM_CLUSTER_NAME") or os.environ.get("NVRX_CLUSTER_NAME") + def init_node_health_check(endpoint: Optional[str]) -> None: global _NODE_HEALTH_CHECK_INSTANCE @@ -377,6 +394,7 @@ def __init__( workers_stop_timeout: float = 30, restart_policy: str = "any-failed", is_store_host: bool = False, + rdzv_endpoint: Optional[str] = None, rank_monitors: Optional[Dict[int, RankMonitorState]] = None, ): super().__init__(spec, exit_barrier_timeout) @@ -389,6 +407,7 @@ def __init__( self._term_timeout = term_timeout self._workers_stop_timeout = workers_stop_timeout self._is_store_host = is_store_host + self._rdzv_endpoint = rdzv_endpoint # Rank monitor state (process, IPC connections, listener tasks) per local rank # Pre-created rank monitors passed from config (created before gRPC) self._rank_monitors: Dict[int, RankMonitorState] = rank_monitors or dict() @@ -418,6 +437,9 @@ def __init__( self._children_pgids: Set[int] = set() self._restart_policy = restart_policy self._node_id = self._get_fq_hostname() + self._pending_fact_agent_cycle: Optional[int] = None + self._fact_agent_cycle_start_time: Optional[datetime] = None + self._last_fact_agent_cycle: Optional[int] = None DEFAULT_ROLE = "default" # FIXME @@ -529,6 +551,53 @@ def _open_rendezvous_for_restart(self): logger.error(f"Failed to open rendezvous: {e}") # For legacy rendezvous, no action needed - it uses different mechanism + def _attribution_recommends_stop(self, role: str) -> bool: + service = getattr(self._rdzv_handler, "_attribution_service", None) + if service is None: + return False + + get_last_result = getattr(service, "get_last_result", None) + if not callable(get_last_result): + return False + + try: + attrsvc_result = get_last_result() + except Exception as exc: + logger.warning( + "[%s] Application-log attribution result fetch failed: %s", role, exc + ) + return False + + if attrsvc_result is None: + logger.info( + "[%s] Application-log attribution returned no restart recommendation", role + ) + return False + + recommendation = getattr(attrsvc_result, "recommendation", None) + if getattr(attrsvc_result, "should_stop", False) is True: + logger.error( + "[%s] Application-log attribution recommended STOP for %s; " + "no restart will be attempted.", + role, + getattr(attrsvc_result, "log_path", "submitted log"), + ) + return True + + recommendation_text = ( + recommendation + if isinstance(recommendation, str) + else getattr(recommendation, "action", None) + ) + if recommendation_text: + logger.info( + "[%s] Application-log attribution recommendation for %s: %s", + role, + getattr(attrsvc_result, "log_path", "submitted log"), + recommendation_text, + ) + return False + def _handle_restart_decision( self, role: str, @@ -548,9 +617,14 @@ def _handle_restart_decision( True if restart was initiated (caller should continue monitoring loop) False if no restart (caller should stop workers and return failure) """ + if self._ft_cfg.fact_url: + self._pending_fact_agent_cycle = self._get_global_cycle_number() + attrsvc_recommends_stop = self._attribution_recommends_stop(role) self._progress_tracker.analyze_previous_cycle() should_terminate_early = self._progress_tracker.should_terminate_early() + if attrsvc_recommends_stop: + return False if should_terminate_early: logger.error( "[%s] Progress tracker detected no progress across restarts. " @@ -868,6 +942,119 @@ def _log_watchdog_event( event = events.Event(name=name, source=events.EventSource.AGENT, metadata=metadata) events.record(event) + def _get_fact_agent_node_candidates(self) -> List[str]: + get_active = getattr(self._rdzv_handler, "get_active_node_addrs", None) + active_addrs = get_active() if callable(get_active) else None + if not active_addrs: + return [] + nodes = [] + for addr in active_addrs: + node = str(addr).split(":", 1)[0] + if node: + nodes.append(node) + return sorted(set(nodes)) + + def _get_fact_agent_local_node(self) -> str: + this_node = getattr(self._rdzv_handler, "_this_node", None) + node_addr = getattr(this_node, "addr", None) + return str(node_addr or self._node_id or socket.getfqdn(socket.gethostname())) + + def _notify_fact_agent(self, _worker_group: WorkerGroup, cycle_index: int) -> None: + if not self._ft_cfg.fact_url: + return + + local_node = self._get_fact_agent_local_node() + expected_nodes = self._get_fact_agent_node_candidates() if self._is_store_host else [] + if self._is_store_host and not expected_nodes: + logger.warning( + "nvrx-fact-agent could not discover active nodes from rendezvous; " + "FACT attribution scope is limited to local node %s", + local_node, + ) + expected_nodes = [local_node] + + payload: Dict[str, Any] = { + "event": "cycle_failed", + "cycle": cycle_index, + "cycle_end_time": datetime.now(timezone.utc).isoformat(), + "expected_nodes": expected_nodes, + } + if self._fact_agent_cycle_start_time is not None: + payload["cycle_start_time"] = self._fact_agent_cycle_start_time.isoformat() + socket_path = self._ft_cfg.fact_agent_socket_path or default_fact_agent_socket_path() + try: + ack = notify_fact_agent( + socket_path=socket_path, + payload=payload, + timeout_s=self._ft_cfg.fact_agent_rpc_timeout, + ) + except Exception as exc: + logger.warning( + "nvrx-fact-agent notification failed for cycle %s via %s: %s", + cycle_index, + socket_path, + exc, + ) + return + + if not ack.get("accepted"): + logger.warning( + "nvrx-fact-agent rejected cycle %s request: %s", + cycle_index, + ack.get("error", ack), + ) + return + logger.info("nvrx-fact-agent accepted cycle %s request", cycle_index) + self._last_fact_agent_cycle = cycle_index + + def get_fact_avoid_nodes_for_rendezvous(self) -> List[str]: + if not self._ft_cfg.fact_url or not self._is_store_host: + return [] + cycle_index = self._last_fact_agent_cycle + if cycle_index is None: + return [] + socket_path = self._ft_cfg.fact_agent_socket_path or default_fact_agent_socket_path() + deadline = time.monotonic() + max(0.0, self._ft_cfg.fact_policy_ready_timeout) + response: Dict[str, Any] = {} + while True: + try: + response = notify_fact_agent( + socket_path=socket_path, + payload={"event": "get_avoid_nodes", "cycle": cycle_index}, + timeout_s=self._ft_cfg.fact_agent_rpc_timeout, + ) + except Exception as exc: + logger.info("nvrx-fact-agent avoid-node query failed: %s", exc) + return [] + + status = response.get("status") + if status == "ready": + break + if status != "pending": + logger.info( + "nvrx-fact-agent avoid-node decision unavailable for cycle %s: %s", + cycle_index, + response, + ) + return [] + remaining = deadline - time.monotonic() + if remaining <= 0: + logger.info( + "nvrx-fact-agent avoid-node decision not ready for cycle %s: %s", + cycle_index, + response, + ) + return [] + time.sleep(min(0.5, remaining)) + + avoid_nodes = response.get("avoid_nodes") + if not isinstance(avoid_nodes, list): + return [] + nodes = [str(node) for node in avoid_nodes if str(node)] + if nodes: + logger.info("nvrx-fact-agent suggested avoid_nodes=%s", nodes) + return nodes + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof @@ -932,6 +1119,11 @@ def _stop_workers(self, worker_group: WorkerGroup, *args, **kwargs) -> None: # that would cause cycle N+1 data to be written to cycle N log files. self._logs_specs.clear_all_pipes_from_reader() + fact_agent_cycle = self._pending_fact_agent_cycle + self._pending_fact_agent_cycle = None + if fact_agent_cycle is not None: + self._notify_fact_agent(worker_group, fact_agent_cycle) + # Record worker termination event after shutdown is complete record_profiling_event( ProfilingEvent.WORKER_TERMINATED, @@ -951,6 +1143,7 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: # At this point, rendezvous has completed and we're about to start workers. # The cycle number is used for profiling and environment variable setting. current_cycle = restart_count = self._get_global_cycle_number() + self._fact_agent_cycle_start_time = datetime.now(timezone.utc) # Send current cycle number to rank monitors for logging self._send_cycle_to_rank_monitors(restart_count) @@ -1550,6 +1743,7 @@ def launch_agent( workers_stop_timeout=config.workers_stop_timeout, restart_policy=config.restart_policy, is_store_host=is_store_host, + rdzv_endpoint=config.rdzv_endpoint, rank_monitors=config.rank_monitors, # Pass pre-created rank monitors ) @@ -2169,6 +2363,41 @@ def get_args_parser() -> ArgumentParser: enable_type=lambda x: x.lower() == 'true', ) + parser.add_argument( + "--ft-health-log-prefix", + action=env, + type=str, + default=None, + dest="ft_health_log_prefix", + help="Prefix for per-cycle health log files (must be absolute path when used, e.g. " + "/lustre/logs/job_health.log). " + "Dmesg health logs are derived by inserting the source and cycle before the extension, " + "for example /lustre/logs/job_health_dmesg_cycle0.log.", + ) + + parser.add_argument( + "--ft-enable-health-log-dmesg", + action=env, + type=lambda x: str(x).lower() not in ["false", "0", "no"], + default=None, + dest="ft_enable_health_log_dmesg", + help="Ask nvrx-fact-agent to queue shared per-cycle dmesg evidence files on " + "failed cycles. Requires --ft-fact-url, --ft-health-log-prefix, and launcher " + "gRPC log aggregation.", + ) + + parser.add_argument( + "--ft-enable-fact-result-artifact", + action=env, + type=lambda x: str(x).lower() not in ["false", "0", "no"], + default=None, + dest="ft_enable_fact_result_artifact", + help="Ask nvrx-fact-agent to write per-node FACT submission records and the " + "store-host FACT result record through the gRPC log funnel. Requires " + "--ft-fact-url, --ft-health-log-prefix, --ft-per-cycle-applog-prefix, " + "and --ft-enable-log-server true.", + ) + parser.add_argument( "-r", "--redirects", @@ -2787,7 +3016,7 @@ def _validate_attribution_requires_per_cycle_applog( ): raise ValueError( "--ft-attribution-endpoint requires --ft-per-cycle-applog-prefix to be specified. " - "Attribution service needs per-cycle application logs as analysis input." + "Application-log attribution needs per-cycle application logs as analysis input." ) @@ -2860,6 +3089,67 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - fault_tol_cfg = FaultToleranceConfig.from_args(args) _validate_attribution_requires_per_cycle_applog(args, fault_tol_cfg) + health_logging_cfg = fault_tol_cfg.health_logging + if fault_tol_cfg.fact_agent_rpc_timeout <= 0: + raise ValueError("--ft-fact-agent-rpc-timeout must be positive") + + if fault_tol_cfg.fact_policy_ready_timeout < 0: + raise ValueError("--ft-fact-policy-ready-timeout must be non-negative") + if fault_tol_cfg.fact_agent_store_timeout < 0: + raise ValueError("--ft-fact-agent-store-timeout must be non-negative") + fact_history_requested = bool( + fault_tol_cfg.fact_history_es_url or fault_tol_cfg.fact_history_es_auth_file + ) + if fact_history_requested: + if not fault_tol_cfg.fact_url: + raise ValueError("FACT history policy requires --ft-fact-url") + if not fault_tol_cfg.fact_history_es_url or not fault_tol_cfg.fact_history_es_auth_file: + raise ValueError( + "FACT history policy requires both --ft-fact-history-es-url and " + "--ft-fact-history-es-auth-file" + ) + if fault_tol_cfg.fact_history_max_candidate_nodes < 1: + raise ValueError("--ft-fact-history-max-candidate-nodes must be positive") + if fault_tol_cfg.fact_history_query_timeout <= 0: + raise ValueError("--ft-fact-history-query-timeout must be positive") + if fault_tol_cfg.fact_min_repeat_count_for_avoid < 1: + raise ValueError("--ft-fact-min-repeat-count-for-avoid must be positive") + if fault_tol_cfg.fact_max_attribution_avoids_per_cycle < 0: + raise ValueError("--ft-fact-max-attribution-avoids-per-cycle must be non-negative") + fact_artifact_enabled = ( + health_logging_cfg.dmesg.enabled or health_logging_cfg.fact_result.enabled + ) + if fact_artifact_enabled: + if not fault_tol_cfg.fact_url: + raise ValueError( + "FACT health logging artifacts require --ft-fact-url so " + "ft_launcher can notify nvrx-fact-agent on failed cycles." + ) + if not health_logging_cfg.prefix: + raise ValueError( + "fault_tolerance.health_logging.prefix (or --ft-health-log-prefix) is required " + "when FACT health logging artifacts are enabled." + ) + if not os.path.isabs(health_logging_cfg.prefix): + raise ValueError( + f"--ft-health-log-prefix must be an absolute path, got: {health_logging_cfg.prefix}. " + "Example: /lustre/logs/job_health.log" + ) + if health_logging_cfg.fact_result.enabled: + if not getattr(args, "ft_per_cycle_applog_prefix", None) or not getattr( + args, "ft_enable_log_server", False + ): + raise ValueError( + "--ft-enable-fact-result-artifact requires gRPC log aggregation " + "(--ft-per-cycle-applog-prefix and --ft-enable-log-server true) " + "so FACT result JSONL is written by a single root log writer." + ) + fact_agent_enabled = fault_tol_cfg.fact_url is not None + if fact_agent_enabled: + try: + normalize_fact_attribution_url(fault_tol_cfg.fact_url or "") + except ValueError as exc: + raise ValueError(f"Invalid --ft-fact-url: {exc}") from exc # Pass segment-related configs to rendezvous config rdzv_configs['segment'] = fault_tol_cfg.segment @@ -2929,6 +3219,12 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - if base_log_file: rdzv_configs['cycle_log_prefix'] = base_log_file + rdzv_endpoint_host = parse_rendezvous_endpoint(rdzv_endpoint, default_port=-1)[0] + host, _ = parse_rendezvous_endpoint(rdzv_endpoint, default_port=0) + is_tcp_store_host = _is_store_host_from_config(host, rdzv_configs) + grpc_server_address = None + node_id = None + if base_log_file: # Validate that the path is absolute (not relative) if not os.path.isabs(base_log_file): @@ -2945,13 +3241,7 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - "Using PipeBasedLogsSpecs automatically." ) - rdzv_endpoint_host = parse_rendezvous_endpoint(rdzv_endpoint, default_port=-1)[0] - host, _ = parse_rendezvous_endpoint(rdzv_endpoint, default_port=0) - is_tcp_store_host = _is_store_host_from_config(host, rdzv_configs) - # Configure gRPC if enabled - grpc_server_address = None - node_id = None log_funnel_ports = None if getattr(args, 'ft_enable_log_server', False): @@ -2998,7 +3288,7 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - if not _GRPC_SERVER_PROCESSES: logger.error( "Failed to start gRPC log server(s) on TCP store host. " - "Disabling gRPC log aggregation for all nodes (falling back to direct file writing)." + "Disabling gRPC log aggregation for all nodes." ) grpc_server_address = None node_id = None @@ -3038,6 +3328,60 @@ def config_from_args(args, launcher_pipe_read_fd=None, launcher_log_file=None) - local_ranks_filter=ranks, ) + if fact_agent_enabled: + if ( + health_logging_cfg.dmesg.enabled or health_logging_cfg.fact_result.enabled + ) and grpc_server_address is None: + raise RuntimeError( + "FACT artifacts require gRPC log aggregation, but no gRPC " + "server address is available." + ) + fact_agent_manager = FactAgentManager( + fact_url=fault_tol_cfg.fact_url, + socket_path=fault_tol_cfg.fact_agent_socket_path, + rpc_timeout_s=fault_tol_cfg.fact_agent_rpc_timeout, + run_id=args.rdzv_id, + rdzv_endpoint=rdzv_endpoint, + store_timeout_s=fault_tol_cfg.fact_agent_store_timeout, + local_node=args.local_addr or socket.getfqdn(), + is_store_host=is_tcp_store_host, + job_id=( + os.environ.get("SLURM_ARRAY_JOB_ID") + or os.environ.get("SLURM_JOB_ID") + or args.rdzv_id + ), + ranks_per_node=nproc_per_node, + username=_resolve_fact_username(), + cluster=_resolve_fact_cluster(), + health_log_prefix=health_logging_cfg.prefix, + dmesg_artifact_enabled=health_logging_cfg.dmesg.enabled, + result_artifact_enabled=health_logging_cfg.fact_result.enabled, + grpc_server_address=grpc_server_address, + grpc_node_id=node_id, + fact_history_es_url=fault_tol_cfg.fact_history_es_url, + fact_history_es_auth_file=fault_tol_cfg.fact_history_es_auth_file, + fact_history_lookback=fault_tol_cfg.fact_history_lookback, + fact_history_index=fault_tol_cfg.fact_history_index, + fact_history_max_candidate_nodes=fault_tol_cfg.fact_history_max_candidate_nodes, + fact_history_query_timeout_s=fault_tol_cfg.fact_history_query_timeout, + fact_min_repeat_count_for_avoid=fault_tol_cfg.fact_min_repeat_count_for_avoid, + fact_max_attribution_avoids_per_cycle=( + fault_tol_cfg.fact_max_attribution_avoids_per_cycle + ), + ) + global _FACT_AGENT_MANAGER + _FACT_AGENT_MANAGER = fact_agent_manager + try: + fact_agent_endpoint = fact_agent_manager.start_if_needed() + if fact_agent_endpoint is not None: + fault_tol_cfg.fact_agent_socket_path = fact_agent_endpoint.socket_path + except Exception as exc: + logger.warning( + "Failed to start local nvrx-fact-agent; FACT dmesg collection " + "will be best-effort and may be unavailable: %s", + exc, + ) + config = LaunchConfig( min_nodes=min_nodes, max_nodes=max_nodes, @@ -3551,19 +3895,23 @@ def main(args=None): logger.info("Agent exits with exit code = 0.") exit_code = 0 finally: - # Clean up gRPC server AFTER all logging is done - # (logging cleanup already happened in run()'s finally block) - global _GRPC_SERVER_PROCESSES, _ATTRIBUTION_MANAGER + # Stop log-producing helper processes before the gRPC log servers so + # their final artifact writes can drain. + global _GRPC_SERVER_PROCESSES, _ATTRIBUTION_MANAGER, _FACT_AGENT_MANAGER + if _FACT_AGENT_MANAGER is not None: + _FACT_AGENT_MANAGER.stop() + _FACT_AGENT_MANAGER = None + + if _ATTRIBUTION_MANAGER is not None: + _ATTRIBUTION_MANAGER.stop() + _ATTRIBUTION_MANAGER = None + if _GRPC_SERVER_PROCESSES: grpc_graceful_shutdown_timeout = float( getattr(args, 'ft_log_server_graceful_shutdown_timeout', 60.0) ) stop_grpc_log_servers(_GRPC_SERVER_PROCESSES, grpc_graceful_shutdown_timeout) - if _ATTRIBUTION_MANAGER is not None: - _ATTRIBUTION_MANAGER.stop() - _ATTRIBUTION_MANAGER = None - sys.exit(exit_code) diff --git a/src/nvidia_resiliency_ext/fault_tolerance/per_cycle_logs.py b/src/nvidia_resiliency_ext/fault_tolerance/per_cycle_logs.py index bbeb3274..a03d9173 100644 --- a/src/nvidia_resiliency_ext/fault_tolerance/per_cycle_logs.py +++ b/src/nvidia_resiliency_ext/fault_tolerance/per_cycle_logs.py @@ -57,6 +57,12 @@ from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler from nvidia_resiliency_ext.shared_utils.log_manager import LogConfig +from nvidia_resiliency_ext.shared_utils.log_paths import ( + get_source_cycle_log_file as get_source_cycle_log_file, +) +from nvidia_resiliency_ext.shared_utils.log_paths import ( + insert_suffix_before_ext as insert_suffix_before_ext, +) from nvidia_resiliency_ext.shared_utils.proto import log_aggregation_pb2, log_aggregation_pb2_grpc # Special marker string to signal pipe-based logging @@ -1584,6 +1590,14 @@ def get_cycle_log_file(self, cycle_index: int) -> str: ext = os.path.splitext(self._base_log_file)[1] or ".log" return f"{base_without_ext}_cycle{cycle_index}{ext}" + @property + def grpc_server_address(self) -> Optional[str]: + return self._grpc_server_address + + @property + def node_id(self) -> Optional[Union[int, str]]: + return self._node_id + def cleanup(self): """ Gracefully shut down the reader thread. diff --git a/src/nvidia_resiliency_ext/shared_utils/health_check.py b/src/nvidia_resiliency_ext/shared_utils/health_check.py index e8d2baec..cec6afcd 100644 --- a/src/nvidia_resiliency_ext/shared_utils/health_check.py +++ b/src/nvidia_resiliency_ext/shared_utils/health_check.py @@ -31,7 +31,7 @@ import defusedxml.ElementTree as ET import httpx -from nvidia_resiliency_ext.attribution import parse_attrsvc_response +from nvidia_resiliency_ext.attribution import AttrSvcResult, parse_attrsvc_response from nvidia_resiliency_ext.attribution.orchestration.http_api import ( ROUTE_LOGS, get_log_response, @@ -1382,8 +1382,8 @@ def __call__(self) -> None: daemon=True, ).start() - def get_last_result(self) -> Optional[bool]: - """Synchronously fetch whether attribution recommends stopping the last log.""" + def get_last_result(self) -> Optional[AttrSvcResult]: + """Synchronously fetch attribution for the most recently submitted log.""" log_path = self._last_submitted if not log_path: logger.debug("AttributionService GET skipped: no submitted log path") @@ -1423,9 +1423,9 @@ def _do_submit_log(self, log_path: str) -> None: "AttributionService POST %s failed: %s: %s", log_path, type(e).__name__, e ) - def _get_results(self, log_path: str) -> Optional[bool]: + def _get_results(self, log_path: str) -> Optional[AttrSvcResult]: """ - Get the stop decision for a previously submitted log file via GET. + Get the normalized attribution result for a previously submitted log file via GET. """ base_url = self._http_base_url() if base_url is None: @@ -1439,7 +1439,7 @@ def _get_results(self, log_path: str) -> Optional[bool]: payload = resp.json() if resp.text else {} attrsvc_result = parse_attrsvc_response(payload, log_path=log_path) logger.info(attrsvc_result.format_log_message()) - return attrsvc_result.should_stop + return attrsvc_result else: logger.warning( "AttributionService GET for %s returned %d", log_path, resp.status_code diff --git a/src/nvidia_resiliency_ext/shared_utils/log_paths.py b/src/nvidia_resiliency_ext/shared_utils/log_paths.py new file mode 100644 index 00000000..a6e34315 --- /dev/null +++ b/src/nvidia_resiliency_ext/shared_utils/log_paths.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers for cycle-scoped log and evidence artifact paths.""" + +from __future__ import annotations + +import os + + +def insert_suffix_before_ext(path: str, suffix: str) -> str: + """Insert ``suffix`` before the file extension of ``path``.""" + base_without_ext, ext = os.path.splitext(path) + return f"{base_without_ext}{suffix}{ext}" + + +def get_source_cycle_log_file(path_prefix: str, source_name: str, cycle_index: int) -> str: + """Build a source-specific cycle logfile path from ``path_prefix``.""" + return insert_suffix_before_ext(path_prefix, f"_{source_name}_cycle{cycle_index}") diff --git a/tests/attribution/unit/test_fact_agent.py b/tests/attribution/unit/test_fact_agent.py new file mode 100644 index 00000000..694edeff --- /dev/null +++ b/tests/attribution/unit/test_fact_agent.py @@ -0,0 +1,962 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import json +import os +import tempfile +import threading +import time +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +from nvidia_resiliency_ext.attribution.fact.agent import FactAgent, FactAgentKeys, FactAgentRequest +from nvidia_resiliency_ext.attribution.fact.client import FactAttributionResult +from nvidia_resiliency_ext.attribution.fact.models import FactHistoryRecord + + +class FakeStore: + def __init__(self): + self.data = {} + + def set(self, key, value): + self.data[key] = value if isinstance(value, bytes) else str(value).encode("utf-8") + + def get(self, key): + return self.data[key] + + def check(self, keys): + return all(key in self.data for key in keys) + + def add(self, key, amount): + value = int(self.data.get(key, b"0").decode("utf-8")) + amount + self.data[key] = str(value).encode("utf-8") + return value + + +class WaitStore(FakeStore): + def __init__(self): + super().__init__() + self.wait_calls = [] + + def wait(self, keys, timeout): + self.wait_calls.append((list(keys), timeout)) + if not self.check(keys): + raise RuntimeError("timed out") + + +class FlakyCountReadStore(FakeStore): + def __init__(self): + super().__init__() + self.done_count_reads = 0 + + def add(self, key, amount): + if amount != 0: + return super().add(key, amount) + self.done_count_reads += 1 + if self.done_count_reads == 1: + return 3 + raise RuntimeError("transient TCPStore read failure") + + +class FakeFactClient: + def __init__(self): + self.created = [] + self.submitted = [] + self.gets = [] + + def create_failure_attributor(self, **kwargs): + self.created.append(kwargs) + return "att-1" + + def submit_dmesg_text_observation(self, **kwargs): + self.submitted.append(kwargs) + if "Xid" not in kwargs["dmesg_text"]: + return None + return f"obs-{kwargs['default_hostname']}" + + def get_attribution_result(self, *, attributor_id, observation_ids): + self.gets.append({"attributor_id": attributor_id, "observation_ids": list(observation_ids)}) + return FactAttributionResult( + attributor_id=attributor_id, + observation_ids=list(observation_ids), + faulty_nodes=["node-a"], + attribution={"attributions": []}, + ) + + +class FakeHistoryClient: + def __init__(self, records): + self.records = records + self.queries = [] + + def query_node_history(self, **kwargs): + self.queries.append(kwargs) + return list(self.records) + + +class FailingCreateFactClient(FakeFactClient): + def create_failure_attributor(self, **kwargs): + self.created.append(kwargs) + raise RuntimeError("FACT unavailable") + + +class FlakyPostFactClient(FakeFactClient): + def __init__(self): + super().__init__() + self.failures_left = 1 + + def submit_dmesg_text_observation(self, **kwargs): + if self.failures_left: + self.failures_left -= 1 + raise RuntimeError("temporary overload") + return super().submit_dmesg_text_observation(**kwargs) + + +class FakeGrpcWriter: + def __init__(self): + self.started = False + self.shutdown_called = False + self.join_timeout = None + + def start(self): + self.started = True + + def shutdown(self): + self.shutdown_called = True + + def join(self, timeout=None): + self.join_timeout = timeout + + +def _recording_grpc_writer_factory(): + writer_records = [] + + def grpc_writer_factory(write_queue, address, node_id, logger): + writer = FakeGrpcWriter() + writer_records.append( + { + "address": address, + "node_id": node_id, + "queue": write_queue, + "writer": writer, + } + ) + return writer + + return writer_records, grpc_writer_factory + + +def _drain_grpc_writes(writer_records): + writes = [] + for writer_record in writer_records: + write_queue = writer_record["queue"] + while not write_queue.empty(): + path, payload = write_queue.get_nowait() + writes.append({"path": path, "payload": payload}) + return writes + + +def _request(**overrides): + values = { + "run_id": "run-1", + "cycle": 2, + "rdzv_endpoint": "127.0.0.1:29500", + "local_node": "node-a", + "store_timeout_s": 0.1, + "job_id": "job-1", + "dmesg_path": None, + "result_path": None, + } + values.update(overrides) + return FactAgentRequest(**values) + + +def test_cycle_payload_uses_session_context_and_derives_dmesg_artifact_path(tmp_path): + cycle_start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + service = FactAgent( + fact_url="http://fact.example/latest", + run_id="run-1", + rdzv_endpoint="store-host:29500", + store_timeout_s=12.0, + local_node="node-a", + is_store_host=True, + job_id="job-1", + ranks_per_node=4, + health_log_prefix=str(tmp_path / "job_health.log"), + dmesg_artifact_enabled=True, + result_artifact_enabled=True, + grpc_server_address="store-host:50051", + grpc_node_id="node-a_123", + ) + + request = service._request_from_payload( + { + "event": "cycle_failed", + "run_id": "payload-run-should-not-win", + "cycle": 3, + "cycle_start_time": cycle_start_time.isoformat(), + "dmesg_path": str(tmp_path / "payload_dmesg.log"), + "expected_nodes": ["node-a", "node-b"], + "result_path": str(tmp_path / "payload_result.log"), + } + ) + + assert request.run_id == "run-1" + assert request.rdzv_endpoint == "store-host:29500" + assert request.store_timeout_s == 12.0 + assert request.local_node == "node-a" + assert request.is_store_host is True + assert request.job_id == "job-1" + assert request.ranks_per_node == 4 + assert request.cycle_start_time == cycle_start_time + assert request.grpc_server_address == "store-host:50051" + assert request.grpc_node_id == "node-a_123" + assert request.expected_nodes == ("node-a", "node-b") + assert request.dmesg_path == str(tmp_path / "job_health_dmesg_cycle3.log") + assert request.result_path == str(tmp_path / "job_health_fact_cycle3.log") + + +def test_cycle_payload_does_not_derive_artifact_paths_without_grpc(tmp_path): + service = FactAgent( + fact_url="http://fact.example/latest", + run_id="run-1", + rdzv_endpoint="store-host:29500", + local_node="node-a", + is_store_host=True, + health_log_prefix=str(tmp_path / "job_health.log"), + dmesg_artifact_enabled=True, + result_artifact_enabled=True, + ) + + request = service._request_from_payload( + { + "event": "cycle_failed", + "cycle": 3, + "expected_nodes": ["node-a", "node-b"], + } + ) + + assert request.dmesg_path is None + assert request.result_path is None + + +def test_leaf_cycle_payload_derives_shared_result_artifact_path(tmp_path): + service = FactAgent( + fact_url="http://fact.example/latest", + run_id="run-1", + rdzv_endpoint="store-host:29500", + local_node="node-b", + is_store_host=False, + health_log_prefix=str(tmp_path / "job_health.log"), + result_artifact_enabled=True, + grpc_server_address="store-host:50051", + grpc_node_id="node-b_123", + ) + + request = service._request_from_payload( + { + "event": "cycle_failed", + "cycle": 3, + "expected_nodes": [], + } + ) + + assert request.result_path == str(tmp_path / "job_health_fact_cycle3.log") + + +def test_result_artifact_is_not_derived_without_grpc(tmp_path): + service = FactAgent( + fact_url="http://fact.example/latest", + run_id="run-1", + rdzv_endpoint="store-host:29500", + local_node="node-b", + is_store_host=False, + health_log_prefix=str(tmp_path / "job_health.log"), + result_artifact_enabled=True, + ) + + request = service._request_from_payload({"event": "cycle_failed", "cycle": 3}) + + assert request.result_path is None + + +def test_leaf_submission_uses_minimal_store_completion(): + store = FakeStore() + fact = FakeFactClient() + keys = FactAgentKeys("run-1", 2) + store.set(keys.attributor_id, b"att-1") + + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + ) + + service.process_cycle_failed(_request()) + + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert sorted(store.data) == [keys.attributor_id, keys.done_count] + assert fact.submitted[0]["attributor_id"] == "att-1" + + +def test_leaf_submission_appends_observation_jsonl(tmp_path): + store = FakeStore() + fact = FakeFactClient() + keys = FactAgentKeys("run-1", 2) + result_path = tmp_path / "fact_result.jsonl" + writer_records, grpc_writer_factory = _recording_grpc_writer_factory() + store.set(keys.attributor_id, b"att-1") + + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + ) + + service.process_cycle_failed( + _request( + local_node="node-b", + result_path=str(result_path), + grpc_server_address="log-host:50051", + grpc_node_id="node-b_123", + ) + ) + + writes = _drain_grpc_writes(writer_records) + assert [write["path"] for write in writes] == [str(result_path)] + records = [json.loads(write["payload"]) for write in writes] + assert records == [ + { + "record_type": "fact_observation", + "run_id": "run-1", + "cycle": 2, + "job_id": "job-1", + "node": "node-b", + "source": "dmesg", + "status": "submitted", + "attributor_id": "att-1", + "observation_id": "obs-node-b", + "lines_collected": 1, + "bytes_collected": len("node-b: [1.0] NVRM: Xid 95".encode("utf-8")), + "dmesg_path": "", + "dmesg_write_error": "", + "error": "", + } + ] + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + + +def test_empty_dmesg_writes_observation_jsonl_without_dmesg_file(tmp_path): + store = FakeStore() + fact = FakeFactClient() + keys = FactAgentKeys("run-1", 2) + dmesg_path = tmp_path / "job_health_dmesg_cycle2.log" + result_path = tmp_path / "fact_result.jsonl" + writer_records, grpc_writer_factory = _recording_grpc_writer_factory() + store.set(keys.attributor_id, b"att-1") + + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: "", + grpc_writer_factory=grpc_writer_factory, + ) + + service.process_cycle_failed( + _request( + local_node="node-b", + dmesg_path=str(dmesg_path), + result_path=str(result_path), + grpc_server_address="log-host:50051", + grpc_node_id="node-b_123", + ) + ) + + writes = _drain_grpc_writes(writer_records) + records = [json.loads(write["payload"]) for write in writes] + assert not dmesg_path.exists() + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert records == [ + { + "record_type": "fact_observation", + "run_id": "run-1", + "cycle": 2, + "job_id": "job-1", + "node": "node-b", + "source": "dmesg", + "status": "empty", + "attributor_id": "att-1", + "observation_id": None, + "lines_collected": 0, + "bytes_collected": 0, + "dmesg_path": "", + "dmesg_write_error": "", + "error": "", + } + ] + + +def test_leaf_submission_retries_post_within_deadline(tmp_path): + store = FakeStore() + fact = FlakyPostFactClient() + keys = FactAgentKeys("run-1", 2) + result_path = tmp_path / "fact_result.jsonl" + writer_records, grpc_writer_factory = _recording_grpc_writer_factory() + store.set(keys.attributor_id, b"att-1") + + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=30.0, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + ) + + with patch("nvidia_resiliency_ext.attribution.fact.agent.time.sleep"): + service.process_cycle_failed( + _request( + result_path=str(result_path), + grpc_server_address="log-host:50051", + grpc_node_id="node-a_123", + ) + ) + + records = [json.loads(write["payload"]) for write in _drain_grpc_writes(writer_records)] + assert len(fact.submitted) == 1 + assert records[0]["status"] == "submitted" + assert records[0]["observation_id"] == "obs-node-a" + + +def test_leaf_uses_store_wait_for_attributor_id(): + store = WaitStore() + fact = FakeFactClient() + keys = FactAgentKeys("run-1", 2) + store.set(keys.attributor_id, b"att-1") + + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + ) + + service.process_cycle_failed(_request()) + + assert store.wait_calls[0][0] == [keys.attributor_id] + assert fact.submitted[0]["attributor_id"] == "att-1" + + +def test_leaf_submission_queues_dmesg_evidence_file(tmp_path): + store = FakeStore() + fact = FakeFactClient() + keys = FactAgentKeys("run-1", 2) + dmesg_path = tmp_path / "job_health_dmesg_cycle2.log" + writer_records, grpc_writer_factory = _recording_grpc_writer_factory() + store.set(keys.attributor_id, b"att-1") + + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + ) + + service.process_cycle_failed( + _request( + dmesg_path=str(dmesg_path), + grpc_server_address="log-host:50051", + grpc_node_id="node-a_123", + ) + ) + + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert sorted(store.data) == [keys.attributor_id, keys.done_count] + assert not dmesg_path.exists() + writes = _drain_grpc_writes(writer_records) + assert [write["path"] for write in writes] == [str(dmesg_path)] + assert writes[0]["payload"] == "node-a: [1.0] NVRM: Xid 95\n" + + +def test_store_host_gets_result_without_tcpstore_status_fan_in(tmp_path): + store = FakeStore() + fact = FakeFactClient() + result_path = tmp_path / "fact_result.json" + cycle_start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + writer_records, grpc_writer_factory = _recording_grpc_writer_factory() + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=0.01, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + username="slurm-user", + cluster="slurm-cluster", + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a", "node-b"), + ranks_per_node=4, + cycle_start_time=cycle_start_time, + result_path=str(result_path), + grpc_server_address="log-host:50051", + grpc_node_id="node-a_123", + ) + ) + + keys = FactAgentKeys("run-1", 2) + assert sorted(store.data) == [keys.attributor_id, keys.done_count] + assert fact.created[0]["nodes"] == ["node-a", "node-b"] + assert fact.created[0]["ranks_per_node"] == 4 + assert fact.created[0]["nranks"] == 8 + assert fact.created[0]["start_time"] == cycle_start_time + assert fact.created[0]["username"] == "slurm-user" + assert fact.created[0]["cluster"] == "slurm-cluster" + assert fact.gets[0]["observation_ids"] == [] + records = [json.loads(write["payload"]) for write in _drain_grpc_writes(writer_records)] + assert [record["record_type"] for record in records] == ["fact_observation", "fact_result"] + assert records[0]["status"] == "submitted" + assert records[0]["observation_id"] == "obs-node-a" + result_payload = records[1] + assert result_payload["status"] == "complete" + assert result_payload["fact_attribution_result"] == { + "attributor_id": "att-1", + "observation_ids": [], + "faulty_nodes": ["node-a"], + "attribution": {"attributions": []}, + } + assert result_payload["expected_node_count"] == 2 + assert result_payload["completed_node_count"] == 1 + assert "faulty_nodes" not in result_payload + assert "submission_statuses" not in result_payload + + +def test_store_host_computes_avoid_nodes_from_history(): + store = FakeStore() + fact = FakeFactClient() + cycle_start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + cycle_end_time = cycle_start_time + timedelta(minutes=42) + history = FakeHistoryClient( + [ + FactHistoryRecord( + cluster="slurm-cluster", + node="node-a", + episode_id="job-0_1", + event_time=cycle_start_time - timedelta(hours=1), + ) + ] + ) + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=0.01, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + fact_history_client_factory=lambda: history, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + username="slurm-user", + cluster="slurm-cluster", + is_store_host=True, + fact_history_es_url="http://history.example", + fact_history_es_auth_file="/tmp/token", + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a", "node-b"), + cycle_start_time=cycle_start_time, + cycle_end_time=cycle_end_time, + ) + ) + + assert history.queries[0]["cluster"] == "slurm-cluster" + assert history.queries[0]["nodes"] == ["node-a"] + assert history.queries[0]["end_time"] == cycle_start_time + assert service.handle_payload({"event": "get_avoid_nodes", "cycle": 2}) == { + "cycle_id": "2", + "status": "ready", + "avoid_nodes": ["node-a"], + } + + +def test_hot_cache_overlays_history_for_back_to_back_cycles(): + store = FakeStore() + fact = FakeFactClient() + history = FakeHistoryClient([]) + cycle_start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=0.01, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + fact_history_client_factory=lambda: history, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + cluster="slurm-cluster", + is_store_host=True, + fact_history_es_url="http://history.example", + fact_history_es_auth_file="/tmp/token", + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a",), + cycle=2, + cycle_start_time=cycle_start_time, + cycle_end_time=cycle_start_time + timedelta(minutes=5, seconds=30), + ) + ) + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a",), + cycle=3, + cycle_start_time=cycle_start_time + timedelta(minutes=5), + cycle_end_time=cycle_start_time + timedelta(minutes=6), + ) + ) + + assert service.handle_payload({"event": "get_avoid_nodes", "cycle": 3})["avoid_nodes"] == [ + "node-a" + ] + + +def test_hot_cache_computes_avoid_nodes_without_fact_history(): + store = FakeStore() + fact = FakeFactClient() + cycle_start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=0.01, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + cluster="slurm-cluster", + is_store_host=True, + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a",), + cycle=2, + cycle_start_time=cycle_start_time, + cycle_end_time=cycle_start_time + timedelta(minutes=1), + ) + ) + assert service.handle_payload({"event": "get_avoid_nodes", "cycle": 2}) == { + "cycle_id": "2", + "status": "ready", + "avoid_nodes": [], + } + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a",), + cycle=3, + cycle_start_time=cycle_start_time + timedelta(minutes=5), + cycle_end_time=cycle_start_time + timedelta(minutes=6), + ) + ) + + assert service.handle_payload({"event": "get_avoid_nodes", "cycle": 3}) == { + "cycle_id": "3", + "status": "ready", + "avoid_nodes": ["node-a"], + } + + +def test_grpc_result_artifact_queues_shared_dmesg_artifact(tmp_path): + store = FakeStore() + fact = FakeFactClient() + dmesg_path = tmp_path / "job_health_dmesg_cycle2.log" + result_path = tmp_path / "fact_result.json" + writer_records = [] + + def grpc_writer_factory(write_queue, address, node_id, logger): + writer = FakeGrpcWriter() + writer_records.append( + { + "address": address, + "node_id": node_id, + "queue": write_queue, + "writer": writer, + } + ) + return writer + + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=0.01, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a", "node-b"), + ranks_per_node=4, + dmesg_path=str(dmesg_path), + result_path=str(result_path), + grpc_server_address="log-host:50051", + grpc_node_id="node-a_123", + ) + ) + + keys = FactAgentKeys("run-1", 2) + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert store.get(keys.attributor_id) == b"att-1" + assert sorted(store.data) == [keys.attributor_id, keys.done_count] + assert fact.gets[0]["observation_ids"] == [] + + assert len(writer_records) == 1 + assert writer_records[0]["address"] == "log-host:50051" + assert writer_records[0]["node_id"] == "node-a_123" + assert writer_records[0]["writer"].started + writes = [] + while not writer_records[0]["queue"].empty(): + path, payload = writer_records[0]["queue"].get_nowait() + writes.append({"path": path, "payload": payload}) + + assert not dmesg_path.exists() + assert [write["path"] for write in writes] == [ + str(dmesg_path), + str(result_path), + str(result_path), + ] + assert writes[0]["payload"] == "node-a: [1.0] NVRM: Xid 95\n" + observation_payload = json.loads(writes[1]["payload"]) + assert observation_payload["record_type"] == "fact_observation" + assert observation_payload["status"] == "submitted" + assert observation_payload["observation_id"] == "obs-node-a" + result_payload = json.loads(writes[2]["payload"]) + assert result_payload["status"] == "complete" + assert result_payload["fact_attribution_result"]["faulty_nodes"] == ["node-a"] + assert "faulty_nodes" not in result_payload + assert "submission_statuses" not in result_payload + + +def test_grpc_completion_uses_minimal_store_without_dmesg_artifact(): + store = FakeStore() + fact = FakeFactClient() + writer_records = [] + + def grpc_writer_factory(write_queue, address, node_id, logger): + writer = FakeGrpcWriter() + writer_records.append({"queue": write_queue, "writer": writer}) + return writer + + service = FactAgent( + fact_url="http://fact.example/latest", + observation_deadline_s=0.01, + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a", "node-b"), + grpc_server_address="log-host:50051", + grpc_node_id="node-a_123", + ) + ) + + keys = FactAgentKeys("run-1", 2) + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert store.get(keys.attributor_id) == b"att-1" + assert sorted(store.data) == [keys.attributor_id, keys.done_count] + assert writer_records == [] + + +def test_scalable_attributor_failure_only_uses_attributor_key(): + store = FakeStore() + fact = FailingCreateFactClient() + writer_records = [] + + def grpc_writer_factory(write_queue, address, node_id, logger): + writer = FakeGrpcWriter() + writer_records.append({"queue": write_queue, "writer": writer}) + return writer + + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + grpc_writer_factory=grpc_writer_factory, + ) + + service.process_cycle_failed( + _request( + is_store_host=True, + expected_nodes=("node-a", "node-b"), + grpc_server_address="log-host:50051", + grpc_node_id="node-a_123", + ) + ) + + keys = FactAgentKeys("run-1", 2) + assert ( + store.get(keys.attributor_id).decode("utf-8").startswith("__nvrx_fact_attributor_failed__:") + ) + assert not store.check([keys.done_count]) + assert sorted(store.data) == [keys.attributor_id] + assert writer_records == [] + + +def test_store_host_attributor_failure_publishes_leaf_sentinel(): + store = FakeStore() + fact = FailingCreateFactClient() + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + ) + + service.process_cycle_failed(_request(is_store_host=True, expected_nodes=("node-a", "node-b"))) + + keys = FactAgentKeys("run-1", 2) + service.process_cycle_failed(_request(local_node="node-b")) + + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert fact.submitted == [] + + +def test_leaf_observation_window_uses_collection_time_before_store_wait(): + keys = FactAgentKeys("run-1", 2) + store = FakeStore() + store.set(keys.attributor_id, b"att-1") + store_wait_started = False + original_check = store.check + collection_end_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + + def check(keys_to_check): + nonlocal store_wait_started + if keys_to_check == [keys.attributor_id]: + store_wait_started = True + return original_check(keys_to_check) + + class GuardedDateTime: + @classmethod + def now(cls, tz): + assert not store_wait_started + return collection_end_time + + store.check = check + fact = FakeFactClient() + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: f"{node}: [1.0] NVRM: Xid 95", + ) + + with patch("nvidia_resiliency_ext.attribution.fact.agent.datetime", GuardedDateTime): + service.process_cycle_failed(_request()) + + assert fact.submitted[0]["end_time"] == collection_end_time + + +def test_completion_count_poll_keeps_last_known_count_on_store_exception(): + store = FlakyCountReadStore() + service = FactAgent(fact_url="http://fact.example/latest", observation_deadline_s=1.0) + keys = FactAgentKeys("run-1", 2) + monotonic_values = iter([0.0, 0.0, 0.1, 2.0]) + + with ( + patch( + "nvidia_resiliency_ext.attribution.fact.agent.time.monotonic", + side_effect=lambda: next(monotonic_values), + ), + patch("nvidia_resiliency_ext.attribution.fact.agent.time.sleep"), + ): + completed_count = service._wait_for_completion_count( + store, + keys, + expected_node_count=5, + ) + + assert completed_count == 3 + + +def test_stop_waits_for_executor_tasks_before_draining_grpc_writers(): + writer_records, grpc_writer_factory = _recording_grpc_writer_factory() + service = FactAgent( + fact_url="http://fact.example/latest", + grpc_writer_factory=grpc_writer_factory, + ) + + service._executor.submit( + service._enqueue_grpc_artifact, + "log-host:50051", + "node-a_123", + "/tmp/fact_result.jsonl", + "{}\n", + ) + + service.stop() + + assert len(writer_records) == 1 + assert writer_records[0]["writer"].started + assert writer_records[0]["writer"].shutdown_called + assert writer_records[0]["writer"].join_timeout == 5.0 + + +def test_empty_collection_is_distinct_from_missing_node(): + store = FakeStore() + fact = FakeFactClient() + keys = FactAgentKeys("run-1", 2) + store.set(keys.attributor_id, b"att-1") + service = FactAgent( + fact_url="http://fact.example/latest", + store_factory=lambda request: store, + fact_client_factory=lambda: fact, + dmesg_collector=lambda window_s, node: "plain kernel line", + ) + + service.process_cycle_failed(_request()) + + assert int(store.get(keys.done_count).decode("utf-8")) == 1 + assert fact.submitted[0]["dmesg_text"] == "plain kernel line" + + +def test_serve_forever_exits_after_stop_without_new_connection(): + socket_path = os.path.join( + tempfile.gettempdir(), + f"nvrx-fact-agent-test-{os.getpid()}-{time.monotonic_ns()}.sock", + ) + service = FactAgent(fact_url="http://fact.example/latest", socket_path=socket_path) + thread = threading.Thread(target=service.serve_forever, daemon=True) + + try: + thread.start() + deadline = time.monotonic() + 5.0 + while not os.path.exists(socket_path) and time.monotonic() < deadline: + time.sleep(0.01) + assert os.path.exists(socket_path) + + service.stop() + thread.join(timeout=2.0) + + assert not thread.is_alive() + finally: + service.stop() + with contextlib.suppress(FileNotFoundError): + os.unlink(socket_path) diff --git a/tests/attribution/unit/test_fact_client.py b/tests/attribution/unit/test_fact_client.py new file mode 100644 index 00000000..6f8ba288 --- /dev/null +++ b/tests/attribution/unit/test_fact_client.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from nvidia_resiliency_ext.attribution.fact import client as fact_client +from nvidia_resiliency_ext.attribution.fact.client import ( + collect_recent_dmesg_text, + dmesg_text_to_raw_loki_streams, + is_fact_relevant_dmesg_message, + normalize_fact_attribution_url, +) + + +def test_dmesg_text_to_raw_loki_streams_uses_prefixed_hostname(): + streams, nodes = dmesg_text_to_raw_loki_streams( + "gb-nvl-134-compute01: [1247249.751385] NVRM: Xid (PCI:0000:e4:00): 95\n", + timestamp_start_ns=1777601387000000000, + ) + + assert nodes == ["gb-nvl-134-compute01"] + assert streams[0]["stream"]["hostname"] == "gb-nvl-134-compute01" + timestamp, body = streams[0]["values"][0] + payload = json.loads(body) + assert timestamp == "1777601387000000000" + assert payload["attributes"]["hostname"] == "gb-nvl-134-compute01" + assert payload["attributes"]["appname"] == "kernel" + assert payload["severity"] == "err" + assert "Xid" in payload["body"] + + +def test_dmesg_text_to_raw_loki_streams_uses_default_hostname_without_prefix(): + streams, nodes = dmesg_text_to_raw_loki_streams( + "[1247249.751385] plain kernel line\n", + default_hostname="default-node", + timestamp_start_ns=1777601387000000000, + ) + + assert nodes == ["default-node"] + payload = json.loads(streams[0]["values"][0][1]) + assert payload["attributes"]["hostname"] == "default-node" + assert payload["body"] == "[1247249.751385] plain kernel line" + + +def test_dmesg_text_to_raw_loki_streams_preserves_zero_timestamp(): + streams, _ = dmesg_text_to_raw_loki_streams( + "[1247249.751385] plain kernel line\n", + default_hostname="default-node", + timestamp_start_ns=0, + ) + + assert streams[0]["values"][0][0] == "0" + + +def test_dmesg_text_to_raw_loki_streams_normalizes_grpc_node_id(): + streams, nodes = dmesg_text_to_raw_loki_streams( + "gb-nvl-134-compute03_2549019: [1247249.751385] NVRM: Xid 95\n", + timestamp_start_ns=1777601387000000000, + ) + + assert nodes == ["gb-nvl-134-compute03"] + assert streams[0]["stream"]["hostname"] == "gb-nvl-134-compute03" + payload = json.loads(streams[0]["values"][0][1]) + assert payload["attributes"]["hostname"] == "gb-nvl-134-compute03" + + +def test_fact_attributor_node_list_preserves_plain_hostname_suffixes(): + nodes = fact_client._fact_attributor_node_list( + ["rack_a_node_1", "rack_a_node_2", "rack_a_node_1"] + ) + + assert nodes == ["rack_a_node_1", "rack_a_node_2"] + + +def test_dmesg_text_to_raw_loki_streams_prefilters_fact_patterns(): + text = "\n".join( + [ + "gb-nvl-134-compute01: [1.0] plain kernel line", + "gb-nvl-134-compute01: [2.0] NVRM: Xid (PCI:0000:e4:00): 95", + "gb-nvl-134-compute02: [3.0] SXid (PCI:0000:e5:00): 11012", + "gb-nvl-134-compute03: [4.0] mlx5_core 0000:01:00.0: port 1 link down", + ] + ) + + streams, nodes = dmesg_text_to_raw_loki_streams( + text, + timestamp_start_ns=1777601387000000000, + prefilter=True, + ) + + assert nodes == ["gb-nvl-134-compute01", "gb-nvl-134-compute02"] + bodies = [json.loads(value[1])["body"] for stream in streams for value in stream["values"]] + assert any("NVRM: Xid" in body for body in bodies) + assert any("SXid" in body for body in bodies) + assert all("plain kernel line" not in body for body in bodies) + assert all("mlx5_core" not in body for body in bodies) + + +def test_fact_relevant_dmesg_patterns_document_mlx5_gap(): + assert is_fact_relevant_dmesg_message("[1.0] NVRM: Xid (PCI:0000:e4:00): 95") + assert not is_fact_relevant_dmesg_message("[4.0] mlx5_core 0000:01:00.0: port 1 link down") + + +def test_collect_recent_dmesg_uses_subprocess_timeout(): + with patch.object(fact_client.subprocess, "run") as run: + run.return_value = SimpleNamespace(returncode=0, stdout="", stderr="") + + collect_recent_dmesg_text(window_s=12.0, hostname="node-a") + + assert run.call_args.kwargs["timeout"] == fact_client._DMESG_COMMAND_TIMEOUT_S + + +def test_raw_loki_timestamp_anchor_stays_near_collection_end(): + start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + end_time = start_time + timedelta(minutes=12) + + anchor_ns = fact_client._raw_loki_timestamp_anchor_ns(start_time, end_time) + + assert anchor_ns == int((end_time - timedelta(seconds=1)).timestamp() * 1_000_000_000) + + +def test_raw_loki_timestamp_anchor_stays_inside_short_interval(): + start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + end_time = start_time + timedelta(milliseconds=100) + + anchor_ns = fact_client._raw_loki_timestamp_anchor_ns(start_time, end_time) + + assert int(start_time.timestamp() * 1_000_000_000) < anchor_ns + assert anchor_ns < int(end_time.timestamp() * 1_000_000_000) + + +def test_normalize_fact_attribution_url_accepts_service_root_or_api_root(): + assert normalize_fact_attribution_url("http://fact.example:8001") == ( + "http://fact.example:8001/latest" + ) + assert normalize_fact_attribution_url("https://fact.example:8001/latest/") == ( + "https://fact.example:8001/latest" + ) + assert normalize_fact_attribution_url("https://proxy.example/fact/latest") == ( + "https://proxy.example/fact/latest" + ) + + +@pytest.mark.parametrize( + "url", + [ + "", + "fact.example:8001", + "http://fact.example:8001/latest?debug=true", + "http://fact.example:8001/latest#fragment", + ], +) +def test_normalize_fact_attribution_url_rejects_invalid_base_url(url): + with pytest.raises(ValueError): + normalize_fact_attribution_url(url) + + +def test_attributor_info_uses_slurm_job_name(monkeypatch): + monkeypatch.setenv("SLURM_JOB_NAME", "resnet-training") + service = fact_client.FactAttributionService(url="http://fact.example/latest") + timestamp = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + + info = service._build_attributor_info( + job_id="job-1", + cycle_index=2, + nodes=["node-a"], + ranks_per_node=4, + nranks=4, + start_time=timestamp, + end_time=timestamp, + username="slurm-user", + cluster="slurm-cluster", + ) + + assert info["workload"]["name"] == "resnet-training" + assert info["workload"]["username"] == "slurm-user" + assert info["metadata"]["cluster"] == "slurm-cluster" + + +def test_attributor_info_defaults_job_name_when_slurm_job_name_missing(monkeypatch): + monkeypatch.delenv("SLURM_JOB_NAME", raising=False) + service = fact_client.FactAttributionService(url="http://fact.example/latest") + timestamp = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + + info = service._build_attributor_info( + job_id="job-1", + cycle_index=2, + nodes=["node-a"], + ranks_per_node=4, + nranks=4, + start_time=timestamp, + end_time=timestamp, + ) + + assert info["workload"]["name"] == "unknown" diff --git a/tests/attribution/unit/test_fact_manager.py b/tests/attribution/unit/test_fact_manager.py new file mode 100644 index 00000000..6f6f32ee --- /dev/null +++ b/tests/attribution/unit/test_fact_manager.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock, patch + +from nvidia_resiliency_ext.attribution.fact import manager as fact_manager + + +def test_fact_agent_manager_disabled_when_fact_url_absent(): + manager = fact_manager.FactAgentManager(fact_url=None) + + assert manager.start_if_needed() is None + + +def test_fact_agent_manager_starts_agent_and_waits_for_ping(tmp_path): + socket_path = str(tmp_path / "fact-agent.sock") + log_file = str(tmp_path / "fact-agent.log") + process = MagicMock() + process.pid = 123 + process.poll.return_value = None + + with ( + patch.object(fact_manager, "_fact_agent_command", return_value=["nvrx-fact-agent"]), + patch.object(fact_manager.subprocess, "Popen", return_value=process) as popen, + patch.object(fact_manager, "notify_fact_agent", return_value={"accepted": True}) as notify, + ): + manager = fact_manager.FactAgentManager( + fact_url="http://fact.example:8001/latest", + socket_path=socket_path, + startup_timeout_s=0.5, + log_file=log_file, + ) + + endpoint = manager.start_if_needed() + + assert endpoint is not None + assert endpoint.socket_path == socket_path + assert popen.call_args.args[0] == [ + "nvrx-fact-agent", + "--fact-url", + "http://fact.example:8001/latest", + "--socket-path", + socket_path, + ] + notify.assert_called_with( + socket_path=socket_path, + payload={"event": "ping"}, + timeout_s=manager.rpc_timeout_s, + ) + + +def test_fact_agent_manager_passes_session_args(tmp_path): + socket_path = str(tmp_path / "fact-agent.sock") + log_file = str(tmp_path / "fact-agent.log") + process = MagicMock() + process.pid = 123 + process.poll.return_value = None + + with ( + patch.object(fact_manager, "_fact_agent_command", return_value=["nvrx-fact-agent"]), + patch.object(fact_manager.subprocess, "Popen", return_value=process) as popen, + patch.object(fact_manager, "notify_fact_agent", return_value={"accepted": True}), + ): + manager = fact_manager.FactAgentManager( + fact_url="http://fact.example:8001/latest", + socket_path=socket_path, + startup_timeout_s=0.5, + log_file=log_file, + run_id="run-1", + rdzv_endpoint="store-host:29500", + store_timeout_s=12.0, + local_node="node-a", + is_store_host=True, + job_id="job-1", + ranks_per_node=4, + username="slurm-user", + cluster="slurm-cluster", + health_log_prefix="/logs/job_health.log", + dmesg_artifact_enabled=True, + result_artifact_enabled=True, + grpc_server_address="store-host:50051", + grpc_node_id="node-a_123", + fact_history_es_url="http://history.example", + fact_history_es_auth_file="/tmp/history.auth", + fact_history_lookback="14d", + fact_history_index="history-*", + fact_history_max_candidate_nodes=16, + fact_history_query_timeout_s=30.0, + fact_min_repeat_count_for_avoid=2, + fact_max_attribution_avoids_per_cycle=1, + ) + + manager.start_if_needed() + + assert popen.call_args.args[0] == [ + "nvrx-fact-agent", + "--fact-url", + "http://fact.example:8001/latest", + "--socket-path", + socket_path, + "--run-id", + "run-1", + "--rdzv-endpoint", + "store-host:29500", + "--store-timeout", + "12.0", + "--local-node", + "node-a", + "--is-store-host", + "--job-id", + "job-1", + "--ranks-per-node", + "4", + "--username", + "slurm-user", + "--cluster", + "slurm-cluster", + "--health-log-prefix", + "/logs/job_health.log", + "--dmesg-artifact-enabled", + "--result-artifact-enabled", + "--grpc-server-address", + "store-host:50051", + "--grpc-node-id", + "node-a_123", + "--fact-history-es-url", + "http://history.example", + "--fact-history-es-auth-file", + "/tmp/history.auth", + "--fact-history-lookback", + "14d", + "--fact-history-index", + "history-*", + "--fact-history-max-candidate-nodes", + "16", + "--fact-history-query-timeout", + "30.0", + "--fact-min-repeat-count-for-avoid", + "2", + "--fact-max-attribution-avoids-per-cycle", + "1", + ] + + +def test_fact_agent_manager_stops_agent_on_startup_failure(tmp_path): + socket_path = str(tmp_path / "fact-agent.sock") + log_file = str(tmp_path / "fact-agent.log") + process = MagicMock() + process.pid = 123 + process.poll.return_value = None + + with ( + patch.object(fact_manager, "_fact_agent_command", return_value=["nvrx-fact-agent"]), + patch.object(fact_manager.subprocess, "Popen", return_value=process), + patch.object(fact_manager, "notify_fact_agent", side_effect=ConnectionRefusedError("nope")), + patch.object(fact_manager.time, "sleep", return_value=None), + ): + manager = fact_manager.FactAgentManager( + fact_url="http://fact.example:8001/latest", + socket_path=socket_path, + startup_timeout_s=0.1, + log_file=log_file, + ) + + try: + manager.start_if_needed() + except TimeoutError: + pass + else: + raise AssertionError("expected startup timeout") + + process.terminate.assert_called_once() + process.wait.assert_called() + + +def test_fact_agent_manager_prefers_graceful_shutdown(tmp_path): + socket_path = str(tmp_path / "fact-agent.sock") + log_file = str(tmp_path / "fact-agent.log") + process = MagicMock() + process.pid = 123 + process.poll.return_value = None + + with patch.object(fact_manager, "notify_fact_agent", return_value={"accepted": True}) as notify: + manager = fact_manager.FactAgentManager( + fact_url="http://fact.example:8001/latest", + socket_path=socket_path, + log_file=log_file, + ) + manager.process = process + + manager.stop() + + notify.assert_called_once_with( + socket_path=socket_path, + payload={"event": "shutdown"}, + timeout_s=manager.rpc_timeout_s, + ) + process.terminate.assert_not_called() + process.wait.assert_called_once_with(timeout=fact_manager._FACT_AGENT_STOP_TIMEOUT) diff --git a/tests/fault_tolerance/unit/test_config.py b/tests/fault_tolerance/unit/test_config.py index a4a3dc21..3c4558a3 100644 --- a/tests/fault_tolerance/unit/test_config.py +++ b/tests/fault_tolerance/unit/test_config.py @@ -186,6 +186,42 @@ def test_read_from_yaml(): assert ft.rank_out_of_section_timeout == 333.0 +def test_read_health_logging_from_yaml(): + yaml_lines = [ + "fault_tolerance:", + " health_logging:", + " prefix: /lustre/logs/job_health.log", + " dmesg:", + " enabled: true", + " fact_result:", + " enabled: true", + " healthcheck:", + " enabled: true", + ] + with tmp_yaml_file(yaml_lines) as temp_file: + ft = fault_tolerance.FaultToleranceConfig.from_yaml_file(temp_file) + assert ft.health_logging.prefix == "/lustre/logs/job_health.log" + assert ft.health_logging.dmesg.enabled is True + assert ft.health_logging.fact_result.enabled is True + assert ft.health_logging.healthcheck.enabled is True + + +def test_read_fact_agent_options_from_yaml(): + yaml_lines = [ + "fault_tolerance:", + " fact_url: http://fact-yaml.example:8001/latest", + " fact_agent_socket_path: /tmp/nvrx-fact-agent-test.sock", + " fact_agent_rpc_timeout: 1.5", + " fact_agent_store_timeout: 12", + ] + with tmp_yaml_file(yaml_lines) as temp_file: + ft = fault_tolerance.FaultToleranceConfig.from_yaml_file(temp_file) + assert ft.fact_url == "http://fact-yaml.example:8001/latest" + assert ft.fact_agent_socket_path == "/tmp/nvrx-fact-agent-test.sock" + assert ft.fact_agent_rpc_timeout == 1.5 + assert ft.fact_agent_store_timeout == 12 + + def test_read_from_yaml_nested(): YAML_LINES = [ "some_other_section:", @@ -238,3 +274,130 @@ def test_to_yaml_file(): ref_conf.to_yaml_file(temp_file.name) restored_conf = fault_tolerance.FaultToleranceConfig.from_yaml_file(temp_file.name) assert restored_conf == ref_conf + + +def test_health_logging_cli_overrides_yaml(): + from nvidia_resiliency_ext.fault_tolerance.launcher import get_args_parser + + yaml_lines = [ + "fault_tolerance:", + " health_logging:", + " prefix: /lustre/logs/from_yaml.log", + " dmesg:", + " enabled: false", + " fact_result:", + " enabled: false", + " healthcheck:", + " enabled: false", + ] + with tmp_yaml_file(yaml_lines) as temp_file: + parser = get_args_parser() + args = parser.parse_args( + [ + "--ft-cfg-path", + temp_file, + "--ft-health-log-prefix", + "/lustre/logs/from_cli.log", + "--ft-enable-health-log-dmesg", + "true", + "--ft-enable-fact-result-artifact", + "true", + "dummy.py", + ] + ) + ft = fault_tolerance.FaultToleranceConfig.from_args(args) + + assert ft.health_logging.prefix == "/lustre/logs/from_cli.log" + assert ft.health_logging.dmesg.enabled is True + assert ft.health_logging.fact_result.enabled is True + assert ft.health_logging.healthcheck.enabled is False + + +def test_fact_agent_cli_overrides_yaml(): + from nvidia_resiliency_ext.fault_tolerance.launcher import get_args_parser + + yaml_lines = [ + "fault_tolerance:", + " fact_url: http://fact-yaml.example:8001/latest", + " fact_agent_socket_path: /tmp/yaml-fact-agent.sock", + " fact_agent_rpc_timeout: 3", + " fact_policy_ready_timeout: 60", + " fact_agent_store_timeout: 45", + ] + with tmp_yaml_file(yaml_lines) as temp_file: + parser = get_args_parser() + args = parser.parse_args( + [ + "--ft-cfg-path", + temp_file, + "--ft-fact-url", + "http://fact-cli.example:8001/latest", + "--ft-fact-agent-socket-path", + "/tmp/cli-fact-agent.sock", + "--ft-fact-agent-rpc-timeout", + "1.25", + "--ft-fact-policy-ready-timeout", + "9", + "--ft-fact-agent-store-timeout", + "8", + "dummy.py", + ] + ) + ft = fault_tolerance.FaultToleranceConfig.from_args(args) + + assert ft.fact_url == "http://fact-cli.example:8001/latest" + assert ft.fact_agent_socket_path == "/tmp/cli-fact-agent.sock" + assert ft.fact_agent_rpc_timeout == 1.25 + assert ft.fact_policy_ready_timeout == 9 + assert ft.fact_agent_store_timeout == 8 + + +def test_fact_history_cli_overrides_yaml(): + from nvidia_resiliency_ext.fault_tolerance.launcher import get_args_parser + + yaml_lines = [ + "fault_tolerance:", + " fact_history_es_url: http://history-yaml.example", + " fact_history_es_auth_file: /tmp/history-yaml.auth", + " fact_history_lookback: 7d", + " fact_history_index: history-yaml-*", + " fact_history_max_candidate_nodes: 8", + " fact_history_query_timeout: 12", + " fact_min_repeat_count_for_avoid: 3", + " fact_max_attribution_avoids_per_cycle: 2", + ] + with tmp_yaml_file(yaml_lines) as temp_file: + parser = get_args_parser() + args = parser.parse_args( + [ + "--ft-cfg-path", + temp_file, + "--ft-fact-history-es-url", + "http://history-cli.example", + "--ft-fact-history-es-auth-file", + "/tmp/history-cli.auth", + "--ft-fact-history-lookback", + "14d", + "--ft-fact-history-index", + "history-cli-*", + "--ft-fact-history-max-candidate-nodes", + "16", + "--ft-fact-history-query-timeout", + "30", + "--ft-fact-min-repeat-count-for-avoid", + "2", + "--ft-fact-max-attribution-avoids-per-cycle", + "1", + "dummy.py", + ] + ) + ft = fault_tolerance.FaultToleranceConfig.from_args(args) + + assert ft.fact_history_es_url == "http://history-cli.example" + assert ft.fact_history_es_auth_file == "/tmp/history-cli.auth" + assert ft.fact_history_lookback == "14d" + assert ft.fact_history_index == "history-cli-*" + assert ft.fact_history_max_candidate_nodes == 16 + assert ft.fact_history_query_timeout == 30 + assert ft.fact_min_repeat_count_for_avoid == 2 + assert ft.fact_max_attribution_avoids_per_cycle == 1 diff --git a/tests/fault_tolerance/unit/test_launcher.py b/tests/fault_tolerance/unit/test_launcher.py index 429d1c62..5e9b49cd 100644 --- a/tests/fault_tolerance/unit/test_launcher.py +++ b/tests/fault_tolerance/unit/test_launcher.py @@ -23,12 +23,15 @@ import sys import tempfile import unittest +from datetime import datetime, timezone +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from nvidia_resiliency_ext import fault_tolerance from nvidia_resiliency_ext.fault_tolerance.config import FaultToleranceConfig +from nvidia_resiliency_ext.shared_utils.health_check import AttrSvcResult WORLD_SIZE = 4 DEFAULT_TIMEOUT = 90 @@ -130,6 +133,357 @@ def test_legacy_rdzv_impl_injects_use_libuv_false(): assert config.rdzv_configs["use_libuv"] is False +def test_fact_url_starts_launcher_managed_fact_agent(tmp_path): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "train.py", + ] + ) + fact_agent_manager = MagicMock() + fact_agent_manager.start_if_needed.return_value = SimpleNamespace( + socket_path=str(tmp_path / "managed-fact-agent.sock") + ) + + with ( + patch.object( + launcher.LocalElasticAgent, + "setup_rank_monitors_early", + return_value={}, + ), + patch.object( + launcher, + "FactAgentManager", + return_value=fact_agent_manager, + ) as manager_cls, + patch.object(launcher, "_FACT_AGENT_MANAGER", None), + patch.dict( + os.environ, + {"SLURM_JOB_USER": "slurm-user", "SLURM_CLUSTER_NAME": "slurm-cluster"}, + clear=False, + ), + ): + config, _, _ = launcher.config_from_args(args) + + manager_cls.assert_called_once() + manager_kwargs = manager_cls.call_args.kwargs + assert manager_kwargs["fact_url"] == "http://fact.example:8001/latest" + assert manager_kwargs["socket_path"] is None + assert manager_kwargs["rpc_timeout_s"] == 2.0 + assert manager_kwargs["run_id"] == "none" + assert manager_kwargs["rdzv_endpoint"] == "127.0.0.1:29500" + assert manager_kwargs["store_timeout_s"] == 60.0 + assert manager_kwargs["is_store_host"] is True + assert manager_kwargs["job_id"] == "none" + assert manager_kwargs["ranks_per_node"] == 1 + assert manager_kwargs["username"] == "slurm-user" + assert manager_kwargs["cluster"] == "slurm-cluster" + assert manager_kwargs["health_log_prefix"] is None + assert manager_kwargs["dmesg_artifact_enabled"] is False + assert manager_kwargs["result_artifact_enabled"] is False + assert manager_kwargs["grpc_server_address"] is None + assert manager_kwargs["grpc_node_id"] is None + assert manager_kwargs["fact_history_es_url"] is None + assert manager_kwargs["fact_history_es_auth_file"] is None + fact_agent_manager.start_if_needed.assert_called_once() + assert config.fault_tol_cfg.fact_agent_socket_path == str(tmp_path / "managed-fact-agent.sock") + + +def test_fact_agent_start_failure_is_nonfatal(caplog): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "train.py", + ] + ) + fact_agent_manager = MagicMock() + fact_agent_manager.start_if_needed.side_effect = RuntimeError("boom") + + with ( + patch.object( + launcher.LocalElasticAgent, + "setup_rank_monitors_early", + return_value={}, + ), + patch.object(launcher, "FactAgentManager", return_value=fact_agent_manager), + patch.object(launcher, "_FACT_AGENT_MANAGER", None), + caplog.at_level(logging.WARNING), + ): + config, _, _ = launcher.config_from_args(args) + + fact_agent_manager.start_if_needed.assert_called_once() + assert config.fault_tol_cfg.fact_url == "http://fact.example:8001/latest" + assert "Failed to start local nvrx-fact-agent" in caplog.text + + +def test_fact_result_artifact_flag_passes_to_fact_agent_manager(tmp_path): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "--ft-health-log-prefix", + str(tmp_path / "job_health.log"), + "--ft-enable-fact-result-artifact", + "true", + "--ft-per-cycle-applog-prefix", + str(tmp_path / "train.log"), + "--ft-enable-log-server", + "true", + "train.py", + ] + ) + fact_agent_manager = MagicMock() + grpc_proc = MagicMock() + + class FakePipeBasedLogsSpecs: + def __init__( + self, + base_log_file, + launcher_pipe_fd=None, + launcher_log_file=None, + grpc_server_address=None, + node_id=None, + ): + self.base_log_file = base_log_file + self.grpc_server_address = grpc_server_address + self.node_id = node_id + + with ( + patch.object( + launcher.LocalElasticAgent, + "setup_rank_monitors_early", + return_value={}, + ), + patch.object(launcher, "PipeBasedLogsSpecs", FakePipeBasedLogsSpecs), + patch.object(launcher, "_start_grpc_log_servers", return_value=[grpc_proc]), + patch.object(launcher, "FactAgentManager", return_value=fact_agent_manager) as manager_cls, + patch.object(launcher, "_FACT_AGENT_MANAGER", None), + patch.object(launcher, "_GRPC_SERVER_PROCESSES", None), + ): + launcher.config_from_args(args) + + assert manager_cls.call_args.kwargs["result_artifact_enabled"] is True + assert manager_cls.call_args.kwargs["grpc_server_address"] == "localhost:50051" + + +def test_fact_result_artifact_requires_grpc_log_aggregation(tmp_path): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "--ft-health-log-prefix", + str(tmp_path / "job_health.log"), + "--ft-enable-fact-result-artifact", + "true", + "train.py", + ] + ) + + with pytest.raises(ValueError, match="require gRPC log aggregation"): + launcher.config_from_args(args) + + +def test_fact_dmesg_artifact_requires_grpc_log_aggregation(tmp_path): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "--ft-health-log-prefix", + str(tmp_path / "job_health.log"), + "--ft-enable-health-log-dmesg", + "true", + "train.py", + ] + ) + + with pytest.raises(ValueError, match="require gRPC log aggregation"): + launcher.config_from_args(args) + + +def test_fact_agent_notification_includes_cycle_start_time(): + from nvidia_resiliency_ext.fault_tolerance import launcher + + agent = launcher.LocalElasticAgent.__new__(launcher.LocalElasticAgent) + agent._ft_cfg = SimpleNamespace( + fact_url="http://fact.example/latest", + fact_agent_socket_path="/tmp/fact-agent.sock", + fact_agent_rpc_timeout=2.0, + fact_policy_ready_timeout=60.0, + ) + agent._is_store_host = True + agent._node_id = "node-a" + agent._fact_agent_cycle_start_time = datetime(2026, 5, 10, 12, 0, 0, tzinfo=timezone.utc) + agent._rdzv_handler = SimpleNamespace( + _this_node=SimpleNamespace(addr="node-a"), + get_active_node_addrs=lambda: ["node-a:29500", "node-b:29500"], + ) + + with patch.object( + launcher, + "notify_fact_agent", + return_value={"accepted": True}, + ) as notify: + launcher.LocalElasticAgent._notify_fact_agent(agent, SimpleNamespace(), 3) + + payload = notify.call_args.kwargs["payload"] + cycle_end_time = datetime.fromisoformat(payload.pop("cycle_end_time")) + assert cycle_end_time.tzinfo is not None + assert payload == { + "event": "cycle_failed", + "cycle": 3, + "cycle_start_time": "2026-05-10T12:00:00+00:00", + "expected_nodes": ["node-a", "node-b"], + } + + +def test_fact_agent_rpc_timeout_must_be_positive(): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "--ft-fact-agent-rpc-timeout", + "0", + "train.py", + ] + ) + + with pytest.raises(ValueError, match="--ft-fact-agent-rpc-timeout must be positive"): + launcher.config_from_args(args) + + +def test_fact_policy_ready_timeout_must_be_non_negative(): + from nvidia_resiliency_ext.fault_tolerance import launcher + + parser = launcher.get_args_parser() + args = parser.parse_args( + [ + "--nnodes", + "1", + "--nproc-per-node", + "1", + "--rdzv-endpoint", + "127.0.0.1:29500", + "--ft-fact-url", + "http://fact.example:8001/latest", + "--ft-fact-policy-ready-timeout", + "-1", + "train.py", + ] + ) + + with pytest.raises(ValueError, match="--ft-fact-policy-ready-timeout must be non-negative"): + launcher.config_from_args(args) + + +def test_fact_avoid_nodes_waits_for_ready_policy(): + from nvidia_resiliency_ext.fault_tolerance import launcher + + agent = launcher.LocalElasticAgent.__new__(launcher.LocalElasticAgent) + agent._ft_cfg = SimpleNamespace( + fact_url="http://fact.example/latest", + fact_agent_socket_path="/tmp/fact-agent.sock", + fact_agent_rpc_timeout=2.0, + fact_policy_ready_timeout=1.0, + ) + agent._is_store_host = True + agent._last_fact_agent_cycle = 7 + + with patch.object( + launcher, + "notify_fact_agent", + side_effect=[ + {"cycle_id": "7", "status": "pending", "avoid_nodes": []}, + {"cycle_id": "7", "status": "ready", "avoid_nodes": ["node-a"]}, + ], + ) as notify, patch.object(launcher.time, "sleep") as sleep: + nodes = launcher.LocalElasticAgent.get_fact_avoid_nodes_for_rendezvous(agent) + + assert nodes == ["node-a"] + assert notify.call_count == 2 + sleep.assert_called_once() + + +def test_fact_avoid_nodes_does_not_wait_for_skipped_policy(): + from nvidia_resiliency_ext.fault_tolerance import launcher + + agent = launcher.LocalElasticAgent.__new__(launcher.LocalElasticAgent) + agent._ft_cfg = SimpleNamespace( + fact_url="http://fact.example/latest", + fact_agent_socket_path="/tmp/fact-agent.sock", + fact_agent_rpc_timeout=2.0, + fact_policy_ready_timeout=60.0, + ) + agent._is_store_host = True + agent._last_fact_agent_cycle = 7 + + with patch.object( + launcher, + "notify_fact_agent", + return_value={"cycle_id": "7", "status": "skipped", "avoid_nodes": []}, + ) as notify, patch.object(launcher.time, "sleep") as sleep: + nodes = launcher.LocalElasticAgent.get_fact_avoid_nodes_for_rendezvous(agent) + + assert nodes == [] + notify.assert_called_once() + sleep.assert_not_called() + + def test_rank_not_send_initial_hb(tmp_dir): # If one rank does not send initial heartbeat, # FT should terminate the rank, and launcher should kill all other ranks @@ -355,6 +709,7 @@ def _make_agent_spec(rdzv_round=1): """Minimal WorkerSpec-like object for testing launcher cycle-info env interaction.""" spec = MagicMock() spec.rdzv_handler = MagicMock() + spec.rdzv_handler._attribution_service = None spec.rdzv_handler.round.return_value = rdzv_round spec.rdzv_handler.get_active_node_addrs.return_value = ["node001", "node002"] spec.rdzv_handler.get_standby_node_addrs.return_value = ["node003"] @@ -508,6 +863,7 @@ def _make_agent(self): def test_handle_restart_decision_progress_terminate(self): """Returns False without restarting when progress tracker says terminate early.""" agent = self._make_agent() + agent._rdzv_handler._attribution_service = None agent._progress_tracker = MagicMock() agent._progress_tracker.should_terminate_early.return_value = True agent._remaining_restarts = 2 @@ -527,6 +883,7 @@ def test_handle_restart_decision_progress_terminate(self): def test_handle_restart_decision_restarts_remaining(self): """Returns True and decrements _remaining_restarts when restarts are available.""" agent = self._make_agent() + agent._rdzv_handler._attribution_service = None agent._progress_tracker = MagicMock() agent._progress_tracker.should_terminate_early.return_value = False agent._remaining_restarts = 2 @@ -544,9 +901,39 @@ def test_handle_restart_decision_restarts_remaining(self): mock_restart.assert_called_once() mock_open.assert_not_called() + def test_handle_restart_decision_attrsvc_stop_blocks_restart(self): + """Returns False without consuming a restart when attrsvc recommends STOP.""" + agent = self._make_agent() + attrsvc = MagicMock() + attrsvc.get_last_result.return_value = AttrSvcResult( + result={"state": "STOP"}, + recommendation="STOP", + should_stop=True, + log_path="/path/to/cycle_0.log", + ) + agent._rdzv_handler._attribution_service = attrsvc + agent._progress_tracker = MagicMock() + agent._progress_tracker.should_terminate_early.return_value = False + agent._remaining_restarts = 2 + + with ( + patch.object(agent, '_restart_workers') as mock_restart, + patch.object(agent, '_open_rendezvous_for_restart') as mock_open, + ): + result = agent._handle_restart_decision( + role="test", spec=self.spec, log_msg="[%s] restarting", open_rendezvous=True + ) + + self.assertFalse(result) + self.assertEqual(agent._remaining_restarts, 2) + attrsvc.get_last_result.assert_called_once() + mock_restart.assert_not_called() + mock_open.assert_not_called() + def test_handle_restart_decision_no_restarts_left(self): """Returns False when _remaining_restarts is 0.""" agent = self._make_agent() + agent._rdzv_handler._attribution_service = None agent._progress_tracker = MagicMock() agent._progress_tracker.should_terminate_early.return_value = False agent._remaining_restarts = 0 @@ -562,6 +949,7 @@ def test_handle_restart_decision_no_restarts_left(self): def test_handle_restart_decision_open_rendezvous_called_when_requested(self): """Calls _open_rendezvous_for_restart() when open_rendezvous=True.""" agent = self._make_agent() + agent._rdzv_handler._attribution_service = None agent._progress_tracker = MagicMock() agent._progress_tracker.should_terminate_early.return_value = False agent._remaining_restarts = 1 @@ -800,6 +1188,38 @@ def wait(self): assert (log_dir / "grpc_diag_leaf_1.log").is_file() +def test_main_stops_fact_agent_before_grpc_log_servers(): + from nvidia_resiliency_ext.fault_tolerance import launcher + + order = [] + args = SimpleNamespace(ft_log_server_graceful_shutdown_timeout=1.25) + grpc_proc = object() + fact_manager = MagicMock() + attr_manager = MagicMock() + fact_manager.stop.side_effect = lambda: order.append("fact") + attr_manager.stop.side_effect = lambda: order.append("attr") + + def stop_grpc(procs, timeout): + assert procs == [grpc_proc] + assert timeout == 1.25 + order.append("grpc") + + with ( + patch.object(launcher, "parse_args", return_value=args), + patch.object(launcher, "run", return_value=None), + patch.object(launcher, "stop_grpc_log_servers", side_effect=stop_grpc), + patch.object(launcher.sys, "exit", side_effect=SystemExit) as sys_exit, + patch.object(launcher, "_FACT_AGENT_MANAGER", fact_manager), + patch.object(launcher, "_ATTRIBUTION_MANAGER", attr_manager), + patch.object(launcher, "_GRPC_SERVER_PROCESSES", [grpc_proc]), + pytest.raises(SystemExit), + ): + launcher.main([]) + + assert order == ["fact", "attr", "grpc"] + sys_exit.assert_called_once_with(0) + + def test_managed_attribution_listen_port_rejects_log_funnel_overlap(): from nvidia_resiliency_ext.fault_tolerance.launcher import ( LogFunnelPorts, diff --git a/tests/fault_tolerance/unit/test_per_cycle_logs.py b/tests/fault_tolerance/unit/test_per_cycle_logs.py index 5f7ed66f..e386d220 100644 --- a/tests/fault_tolerance/unit/test_per_cycle_logs.py +++ b/tests/fault_tolerance/unit/test_per_cycle_logs.py @@ -16,6 +16,7 @@ MultiplexingReaderThread, PipeSubprocessHandler, _should_filter_line, + get_source_cycle_log_file, ) @@ -76,6 +77,11 @@ def test_filter_without_newline(self): assert _should_filter_line("Rank 1117: Error") is False +def test_get_source_cycle_log_file(): + path = get_source_cycle_log_file("/lustre/logs/job_health.log", "dmesg", 3) + assert path == "/lustre/logs/job_health_dmesg_cycle3.log" + + class TestMultiplexingReaderThread: """Tests for MultiplexingReaderThread class.""" diff --git a/tests/shared_utils/test_health_check.py b/tests/shared_utils/test_health_check.py index 50e2f85a..6d358d0d 100644 --- a/tests/shared_utils/test_health_check.py +++ b/tests/shared_utils/test_health_check.py @@ -35,6 +35,7 @@ def _attribution_item(raw_text, reason_code): "auto_resume_explanation": "", "attribution_text": "", "checkpoint_saved_flag": 0, + "action": reason_code, "primary_issues": [], "secondary_issues": [], } @@ -618,7 +619,7 @@ def test_non_http_endpoint_does_not_create_http_client(self, mock_client): mock_client.assert_not_called() @patch("nvidia_resiliency_ext.shared_utils.health_check.httpx.Client") - def test_get_results_returns_stop_decision(self, mock_client): + def test_get_results_returns_stop_recommendation(self, mock_client): client = mock_client.return_value.__enter__.return_value response = MagicMock() response.status_code = 200 @@ -640,9 +641,13 @@ def test_get_results_returns_stop_decision(self, mock_client): client.get.return_value = response service = AttributionService(endpoint="http://attr.example:8000/") - should_stop = service._get_results("/tmp/train.log") + result = service._get_results("/tmp/train.log") - self.assertTrue(should_stop) + self.assertIsNotNone(result) + self.assertEqual(result.recommendation.action, "STOP") + self.assertEqual(result.recommendation.reason, "STOP - DONT RESTART") + self.assertTrue(result.should_stop) + self.assertEqual(result.log_path, "/tmp/train.log") mock_client.assert_called_once_with(base_url="http://attr.example:8000", timeout=60.0) client.get.assert_called_once_with( "/logs", @@ -671,9 +676,11 @@ def test_get_results_maps_restart_recommendation_to_no_stop(self, mock_client): client.get.return_value = response service = AttributionService(endpoint="http://attr.example:8000/") - should_stop = service._get_results("/tmp/train.log") + result = service._get_results("/tmp/train.log") - self.assertFalse(should_stop) + self.assertIsNotNone(result) + self.assertFalse(result.should_stop) + self.assertEqual(result.recommendation.action, "RESTART") @patch("nvidia_resiliency_ext.shared_utils.health_check.httpx.Client") def test_get_results_maps_continue_recommendation_to_no_stop(self, mock_client): @@ -696,9 +703,11 @@ def test_get_results_maps_continue_recommendation_to_no_stop(self, mock_client): client.get.return_value = response service = AttributionService(endpoint="http://attr.example:8000/") - should_stop = service._get_results("/tmp/train.log") + result = service._get_results("/tmp/train.log") - self.assertFalse(should_stop) + self.assertIsNotNone(result) + self.assertFalse(result.should_stop) + self.assertEqual(result.recommendation.action, "CONTINUE") if __name__ == "__main__":