From dd6927f0de9a764ada581ee1944ed14d35318621 Mon Sep 17 00:00:00 2001 From: Hexin Wang Date: Fri, 16 May 2025 19:46:49 -0700 Subject: [PATCH 1/3] Test UT. --- examples/fault_tolerance/basic_ft_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/fault_tolerance/basic_ft_example.py b/examples/fault_tolerance/basic_ft_example.py index 52abf292..90af79cc 100644 --- a/examples/fault_tolerance/basic_ft_example.py +++ b/examples/fault_tolerance/basic_ft_example.py @@ -22,6 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler + # FT: import NVRx import nvidia_resiliency_ext.fault_tolerance as ft From 5e71458d9ec05be41a3070b5c1580e9f4c0369c5 Mon Sep 17 00:00:00 2001 From: Hexin Wang Date: Fri, 16 May 2025 19:49:16 -0700 Subject: [PATCH 2/3] . --- examples/fault_tolerance/basic_ft_example.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/fault_tolerance/basic_ft_example.py b/examples/fault_tolerance/basic_ft_example.py index 90af79cc..515c9ae2 100644 --- a/examples/fault_tolerance/basic_ft_example.py +++ b/examples/fault_tolerance/basic_ft_example.py @@ -22,11 +22,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler - # FT: import NVRx import nvidia_resiliency_ext.fault_tolerance as ft -# Simple example of using the FT library with PyTorch DDP. +# Simple example of using the FT library with PyTorch DDP.. # This script trains a dummy model on dummy data. CPU is used for training. # After each epoch, FT timeouts are calculated and saved to the file "./ft_state.json". # From f981737834e1cf69e5e42fc5ac883c0f8af10b23 Mon Sep 17 00:00:00 2001 From: Hexin Wang Date: Mon, 19 May 2025 18:54:27 -0700 Subject: [PATCH 3/3] Removed NVTE_FUSED_ATTN for Nemo >= 25.02 --- tests/ptl_resiliency/func/nemo20/straggler_test_llama3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ptl_resiliency/func/nemo20/straggler_test_llama3.py b/tests/ptl_resiliency/func/nemo20/straggler_test_llama3.py index 5f0968bb..a26672a9 100644 --- a/tests/ptl_resiliency/func/nemo20/straggler_test_llama3.py +++ b/tests/ptl_resiliency/func/nemo20/straggler_test_llama3.py @@ -75,7 +75,6 @@ def local_executor( "NCCL_NVLS_ENABLE": "0", "NVTE_DP_AMAX_REDUCE_INTERVAL": "0", "NVTE_ASYNC_AMAX_REDUCTION": "1", - "NVTE_FUSED_ATTN": "0", } executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)