diff --git a/src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py b/src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py index 1819ae3a..dc3084fb 100644 --- a/src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py +++ b/src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py @@ -19,7 +19,7 @@ """ import logging -from abc import ABC, abstractmethod +from abc import ABCMeta, ABC, abstractmethod from collections import deque from queue import Empty from time import sleep, time @@ -129,6 +129,21 @@ def execute_finalize_fns(self, validate_matching_call_idx: bool = True) -> int: "That probably means not all ranks are participating in async finalization" return self.call_idx +# Singleton metaclass +class Singleton(type): + _instances = {} + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + def clear(cls): + if cls in cls._instances: + del cls._instances[cls] + +# class SingletonABCMeta(Singleton, ABCMeta): +# pass + class AsyncCaller(ABC): """Wrapper around mp.Process that ensures correct semantic of distributed finalization. @@ -192,8 +207,13 @@ def sync_all_async_calls(self, is_alive: int) -> bool: return ten[0] == 0 @abstractmethod - def close(self): - """Terminate the async caller at exit of an application or some termination conditions""" + def close(self, abort=False): + """Terminate the async caller at exit of an application or some termination conditions + + Args: + abort (bool, optional): Default to False. Needs to be manually set to true when + the checkpoint async process needs to be aborted. + """ logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller") @abstractmethod @@ -283,15 +303,23 @@ def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = Fal is_done = True return is_done - def close(self): + def close(self, abort=False): """For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done` This method make sure the TemporalAsyncCaller terminated with all its assigned async request completed + + Args: + abort (bool, optional): Default to False. Needs to be manually set to true when + the checkpoint async process needs to be aborted. """ if self.process: logger.debug(f"rank: {self.rank}, joining self.process") - self.process.join() + if abort: + logger.warning(f"Temporal worker aborted in rank {torch.distributed.get_rank()}") + self.process.kill() + else: + self.process.join() self.process = None logger.debug( "TemporalAsyncCaller: Async process join finished " @@ -302,6 +330,14 @@ def close(self): def __del__(self): pass + def _debug_is_async_process_running(self): + """ + For unit test purpose + """ + if self.process is None: + return False + return self.process.is_alive() + class PersistentAsyncCaller(AsyncCaller): """Wrapper around mp.Process that ensures correct semantic of distributed finalization. @@ -360,7 +396,7 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None: ), ) self.process.start() - logger.info(f"PersistentAsyncCaller: {self.rank}, Started Async Caller") + logger.error(f" Aarti PersistentAsyncCaller: {self.rank}, Started Async Caller {self.process}") if async_req.preload_fn: self.preload_q.put(async_req.call_idx) @@ -429,21 +465,38 @@ def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = Fal return is_done - def close(self): + def close(self, abort=False): """Wait on the left async requests and terminate the PersistentAsyncCaller Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated + + Args: + abort (bool, optional): Default to False. Needs to be manually set to true when + the checkpoint async process needs to be aborted. """ logger.info(f"PersistentAsyncCaller: {self.rank}, Destroying Async Caller") if self.process: - self.queue.put('DONE') - self.queue.join() - self.process.join() + if abort: + logger.error(f"Persistent worker aborted in rank {torch.distributed.get_rank()}") + self.process.kill() + else: + self.queue.put('DONE') + self.queue.join() + self.process.join() + self.process = None def __del__(self): self.close() + def _debug_is_async_process_running(self): + """ + For unit test purpose + """ + if self.process is None: + return False + return self.process.is_alive() + @staticmethod @_disable_gc() def async_loop( @@ -514,7 +567,7 @@ class _ActiveAsyncRequest(NamedTuple): async_request: AsyncRequest -class AsyncCallsQueue: +class AsyncCallsQueue(metaclass=Singleton): """Manages a queue of async calls. Allows adding a new async call with `schedule_async_request` and finalizing @@ -549,6 +602,7 @@ def schedule_async_request(self, async_request: AsyncRequest) -> int: """ self.call_idx += 1 async_caller = self._get_async_caller() + logger.error(f"Aarti persistent caller created. But process not yet created persistentObject = {async_caller}") # Backward compatibility for local checkpointing built with the old AsyncRequest if len(async_request._fields) != len(AsyncRequest._fields): async_request = AsyncRequest(**async_request._asdict()) @@ -556,6 +610,7 @@ def schedule_async_request(self, async_request: AsyncRequest) -> int: async_caller.schedule_async_call( async_request._replace(call_idx=self.call_idx, finalize_fns=[]) ) + logger.error(f"Aarti persistent caller created. Now I expect process to be created") self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request)) return self.call_idx @@ -593,8 +648,39 @@ def get_num_unfinalized_calls(self): """Get the number of active async calls.""" return len(self.async_calls) - def close(self): - """Finalize all calls upon closing.""" - self.maybe_finalize_async_calls(blocking=True) + def close(self, abort=False): + """Finalize all calls upon closing. + + Args: + abort (bool, optional): Default to False. Needs to be manually set to true when + the checkpoint async process needs to be aborted. + """ + logger.error(f"Aarti 2222 {self} {abort} {self.persistent} {self.persistent_caller}") + if not abort: + self.maybe_finalize_async_calls(blocking=True) if self.persistent and self.persistent_caller: - self.persistent_caller.close() + self.persistent_caller.close(abort=abort) + # Reset all class params + self.call_idx = -1 + self.persistent_caller = None + # Clear the singleton registory of async worker + Singleton.clear(AsyncCallsQueue) + +def abort_nvrx_checkpoint(): + # we have a singleton persistent worker in our async calls queue + # close the async calls queue which will clear the singleton object + # to ensure a clean restart + + + # # When persistent_caller is singleton + # logger.error(f"Aarti 00000 abort_nvrx_checkpoint called") + # persistent_caller = AsyncCallsQueue(persistent=True)._get_async_caller() + # logger.error(f"Aarti 11111 close called on object {persistent_caller}") + # persistent_caller.close(abort=True) + + logger.error(f"Aarti 00000 abort_nvrx_checkpoint called") + async_queue_singleton = AsyncCallsQueue(persistent=True) + logger.error(f"Aarti 11111 close called on object {async_queue_singleton}") + async_queue_singleton.close(abort=True) + + # TBD: Create singleton for result_queue diff --git a/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py b/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py index 5e05e44b..773aa8ae 100644 --- a/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py +++ b/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py @@ -74,3 +74,16 @@ def finalize_async_save(self, blocking: bool = False, no_dist=True, terminate=Fa self._async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=no_dist) if terminate: self._async_calls_queue.close() + + def _get_async_calls_queue(self): + """ + Function introduced for unit test purpose to validate the state of the Async workers + """ + return self._async_calls_queue + + def close(self): + if self._async_calls_queue is not None: + self._async_calls_queue.close() + + def __del__(self): + self.close() diff --git a/src/nvidia_resiliency_ext/inprocess/abort.py b/src/nvidia_resiliency_ext/inprocess/abort.py index 820e44cf..aef68136 100644 --- a/src/nvidia_resiliency_ext/inprocess/abort.py +++ b/src/nvidia_resiliency_ext/inprocess/abort.py @@ -22,6 +22,7 @@ import torch from nvidia_resiliency_ext.attribution.trace_analyzer.trace_collector import TorchFRTraceCollector +from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, abort_nvrx_checkpoint from . import utils from .state import FrozenState @@ -185,3 +186,35 @@ def __call__(self, state: FrozenState) -> FrozenState: te_fp8.FP8GlobalStateManager.reset() return state + + +class AbortCheckpoint(Abort): + r''' + Aborts Async Checkpoint processes + + ''' + # def __init__(self): + # # Persistent worker once created is a singleton object + # # During abort, we mus reset this object to a clean state + # log = logging.getLogger(__name__) + # log.error(f"Aarti 1111 init AbortCheckpoint ") + # self.async_queue = AsyncCallsQueue(persistent=True) + # log.error(f"Aarti 2222 {self.__dict__}") + + def __call__(self, state: FrozenState) -> FrozenState: + # log = logging.getLogger(__name__) + # log.error(f"Aarti 33333 {self.__dict__}") + # self.async_queue = AsyncCallsQueue(persistent=True) + # log.error(f"Aarti 4444 {self.__dict__} and singleton persistent_worker = {self.async_queue.persistent_caller}") + # if self.async_queue is not None: + # self.async_queue.close(abort=True) + # del self.async_queue + # log.error(f"Aarti 5555 {self.__dict__}") + # # global _results_queue + # # if _results_queue is not None: + # # _results_queue._manager.shutdown() + # # del _results_queue + # # _results_queue = None + # # self.async_queue = None + abort_nvrx_checkpoint() + return state diff --git a/src/nvidia_resiliency_ext/inprocess/rank_assignment.py b/src/nvidia_resiliency_ext/inprocess/rank_assignment.py index abe5268c..bc4bdc4e 100644 --- a/src/nvidia_resiliency_ext/inprocess/rank_assignment.py +++ b/src/nvidia_resiliency_ext/inprocess/rank_assignment.py @@ -22,6 +22,7 @@ import enum import heapq import itertools +import logging import warnings from typing import Callable, Optional, Union @@ -174,6 +175,12 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: else: mode = Mode.INACTIVE active_rank = None + # Log deactivation if transitioning from ACTIVE to INACTIVE + if state.mode == Mode.ACTIVE: + log = logging.getLogger(__name__) + log.info( + f"[In-process] Rank deactivated (rank={state.rank}) due to max active world size limit ({active_world_size})" + ) state = dataclasses.replace( state, @@ -215,6 +222,12 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: else: mode = Mode.INACTIVE active_rank = None + # Log deactivation if transitioning from ACTIVE to INACTIVE + if state.mode == Mode.ACTIVE: + log = logging.getLogger(__name__) + log.info( + f"[In-process] Rank deactivated (rank={state.rank}) due to divisibility requirement (active_world_size={active_world_size}, divisor={divisor})" + ) state = dataclasses.replace( state, @@ -511,6 +524,7 @@ def __init__( self.tree = None self.rank_map = {} self.init_rank_map = {} + self.current_state = None def build_tree(self, state, store): key = [ @@ -560,6 +574,8 @@ def build_tree(self, state, store): def replace_with_inactive(self, terminated_active_ranks): replaced_terminate_active_ranks = set() + log = logging.getLogger(__name__) + for terminated_active_rank in sorted(terminated_active_ranks): terminated_active_node = self.rank_map[terminated_active_rank] @@ -571,6 +587,14 @@ def replace_with_inactive(self, terminated_active_ranks): ): if parent.inactive_nodes: _, inactive = parent.inactive_nodes.popitem() + + # If this rank is an inactive rank becoming active, log the transition + if self.current_state.initial_rank == inactive.state.initial_rank: + log.info( + f"[In-process] Hot spare activated: terminated rank (initial_rank={terminated_active_node.state.initial_rank}, active_rank={terminated_active_node.state.active_rank}) " + f"replaced by spare rank (initial_rank={inactive.state.initial_rank}, new_active_rank={terminated_active_node.state.active_rank})" + ) + inactive.activate(terminated_active_node.state.active_rank) replaced_terminate_active_ranks.add(terminated_active_rank) break @@ -601,12 +625,19 @@ def replace_with_backfill(self, unhandled_terminations): key=lambda node: node.state.active_rank, ) + log = logging.getLogger(__name__) for backfill_node, terminated_node in itertools.zip_longest( reversed(largest_active_nodes), terminated_nodes, fillvalue=None, ): if backfill_node is not None: + # If this rank is being backfilled, log it + if self.current_state.initial_rank == backfill_node.state.initial_rank: + log.info( + f"[In-process] Rank backfilled (initial_rank={backfill_node.state.initial_rank}) " + f"active_rank changed from {backfill_node.state.active_rank} to {terminated_node.state.active_rank})" + ) replaced_active.add(backfill_node.state.active_rank) backfill_node.state.active_rank = terminated_node.state.active_rank else: @@ -616,23 +647,41 @@ def replace_with_backfill(self, unhandled_terminations): def shift_ranks(self, replaced_active, unhandled_terminations): sorted_replaced_active = sorted(replaced_active) + log = logging.getLogger(__name__) for n in self.rank_map.values(): n.state.active_world_size -= len(unhandled_terminations) if n.state.active_rank is not None: + old_active_rank = n.state.active_rank count_less = bisect.bisect_left(sorted_replaced_active, n.state.active_rank) n.state.active_rank -= count_less + # If this rank's active_rank shifted, log it + if ( + self.current_state.initial_rank == n.state.initial_rank + and old_active_rank != n.state.active_rank + ): + log.info( + f"[In-process] Rank shifted (initial_rank={n.state.initial_rank}, active_rank changed from {old_active_rank} to {n.state.active_rank})" + ) + def filter_active_world_size(self): active_world_size = next(iter(self.rank_map.values())).state.active_world_size new_active_world_size = self.world_size_filter(active_world_size) assert new_active_world_size <= active_world_size + log = logging.getLogger(__name__) for leaf in self.tree.iter_leaves(): leaf.state.active_world_size = new_active_world_size if leaf.state.mode == Mode.ACTIVE and leaf.state.active_rank >= new_active_world_size: + # If this rank is being deactivated due to world size filter, log it + if self.current_state.initial_rank == leaf.state.initial_rank: + log.info( + f"[In-process] Rank deactivated (initial_rank={leaf.state.initial_rank}, " + f"active_rank={leaf.state.active_rank}) due to world size filter (>= {new_active_world_size})" + ) leaf.deactivate() def recompute_rank(self): @@ -667,6 +716,9 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: store = ctx.store terminated_ranks = ctx.terminated_ranks + # Store the current rank's initial_rank for logging purposes + self.current_state = state.freeze() + if self.tree is None: self.build_tree(state, store) @@ -685,7 +737,14 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: terminated_active_ranks = set( rank for rank in terminated_ranks if self.rank_map[rank].state.mode == Mode.ACTIVE ) + + log = logging.getLogger(__name__) for terminated_rank in terminated_ranks: + # If this rank is being terminated, log it + if self.current_state.initial_rank == self.rank_map[terminated_rank].state.initial_rank: + log.info( + f"[In-process] Rank terminated (initial_rank={self.current_state.initial_rank})" + ) self.rank_map[terminated_rank].terminate() replaced_terminate_active_ranks = self.replace_with_inactive(terminated_active_ranks) @@ -749,7 +808,12 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: terminated_ranks = utils.format_rank_set(terminated_ranks) raise RankDiscarded(f'{rank=} {terminated_ranks=}') elif rank >= world_size: + log = logging.getLogger(__name__) + old_rank = rank rank = ordered_terminated_ranks[rank - world_size] + log.info( + f"[In-process] Rank reassigned (rank changed from {old_rank} to {rank}) to fill gap" + ) state = dataclasses.replace( state, @@ -802,7 +866,11 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: terminated_ranks = utils.format_rank_set(terminated_ranks) raise RankDiscarded(f'{rank=} {terminated_ranks=}') else: + old_rank = rank rank = rank - sum(rank > terminated_rank for terminated_rank in terminated_ranks) + if old_rank != rank: + log = logging.getLogger(__name__) + log.info(f"[In-process] Rank shifted (rank changed from {old_rank} to {rank})") state = dataclasses.replace( state, @@ -912,7 +980,12 @@ def __call__(self, ctx: RankAssignmentCtx) -> RankAssignmentCtx: timeout=self.timeout, ) - if not self.condition(int(store.get(prefixed_key))): + group_count = int(store.get(prefixed_key)) + if not self.condition(group_count): + log = logging.getLogger(__name__) + log.info( + f"[In-process] Rank marked for termination (rank={rank}, group_key={key}, group_count={group_count}) due to failed group condition" + ) store.append(RANKS_TO_TERMINATE, f'{rank},') store.barrier( diff --git a/tests/checkpointing/unit/test_async_save.py b/tests/checkpointing/unit/test_async_save.py index d731afc7..ef8c5652 100644 --- a/tests/checkpointing/unit/test_async_save.py +++ b/tests/checkpointing/unit/test_async_save.py @@ -14,6 +14,7 @@ # limitations under the License. import torch +from nvidia_resiliency_ext.checkpointing.async_ckpt.core import abort_nvrx_checkpoint from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint from . import TempNamedDir @@ -57,3 +58,65 @@ def test_async_is_equivalent_to_sync(self, tmp_path_dist_ckpt): assert torch.equal( loaded_async_state_dict[k], state_dict[k] ), f"loaded_async_state_dict[{k}] != src_state_dict[{k}]" + ckpt_impl.close() + + + + def test_persistent_async_cp_abort(self, tmp_path_dist_ckpt): + Utils.initialize_distributed() + model = TestModel((1024, 1024), 10) + ckpt_impl = TorchAsyncCheckpoint(persistent_queue=True) + state_dict = model.state_dict() + + with ( + TempNamedDir(tmp_path_dist_ckpt / 'test_equivalence_async') as async_ckpt_dir, + TempNamedDir(tmp_path_dist_ckpt / 'test_equivalence_sync') as sync_ckpt_dir, + ): + # Save Sync CP state for reference + ckpt_impl.save(state_dict, sync_ckpt_dir / 'test') + + # Save and finalize async CP + ckpt_impl.async_save(state_dict, async_ckpt_dir / 'test') + ckpt_impl.finalize_async_save(blocking=True) + + # Validate that NVRx CP workers are initialized + async_calls_queue = ckpt_impl._get_async_calls_queue() + assert async_calls_queue is not None, "After saving async CP, we expect valid object" + async_process = async_calls_queue._get_async_caller() + assert async_process is not None, "After a valid CP save, we expect async process to be running" + assert async_process._debug_is_async_process_running(), "Valid async process expected" + + # Abort the CP workers to mock the action of inprocess restarts + abort_nvrx_checkpoint() + + # Validate clean-up of NVrx CP workers is done + async_calls_queue = ckpt_impl._get_async_calls_queue() + assert async_calls_queue is not None, "We expect a valid state of AsyncCallsQueue" + async_process = async_calls_queue._get_async_caller() + if async_process is not None: + assert async_process._debug_is_async_process_running() is False, "After abort async process stops" + + # Re-start CP process by doing another async CP state. + ckpt_impl.async_save(state_dict, async_ckpt_dir / 'test') + ckpt_impl.finalize_async_save(blocking=True) + + # Validate that NVRx CP workers are initialized + async_calls_queue = ckpt_impl._get_async_calls_queue() + assert async_calls_queue is not None, "After saving async CP, we expect valid object" + async_process = async_calls_queue._get_async_caller() + assert async_process is not None, "After a valid CP save, we expect async process to be running" + assert async_process._debug_is_async_process_running(), "Valid async process expected" + + # load and compare the re-started async-cp state with the reference sync CP + device = torch.device(f"cuda:{torch.cuda.current_device()}") + loaded_async_state_dict = torch.load(async_ckpt_dir / 'test', map_location=device) + loaded_sync_state_dict = torch.load(sync_ckpt_dir / 'test', map_location=device) + for k in loaded_sync_state_dict.keys(): + assert k in loaded_async_state_dict, f"{k} is not in loaded async state_dict" + assert torch.equal( + loaded_async_state_dict[k], loaded_sync_state_dict[k] + ), f"loaded_async_state_dict[{k}] != loaded_sync_state_dict[{k}]" + assert torch.equal( + loaded_async_state_dict[k], state_dict[k] + ), f"loaded_async_state_dict[{k}] != src_state_dict[{k}]" + ckpt_impl.close()