Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 101 additions & 15 deletions src/nvidia_resiliency_ext/checkpointing/async_ckpt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import logging
from abc import ABC, abstractmethod
from abc import ABCMeta, ABC, abstractmethod
from collections import deque
from queue import Empty
from time import sleep, time
Expand Down Expand Up @@ -129,6 +129,21 @@ def execute_finalize_fns(self, validate_matching_call_idx: bool = True) -> int:
"That probably means not all ranks are participating in async finalization"
return self.call_idx

# Singleton metaclass
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]

def clear(cls):
if cls in cls._instances:
del cls._instances[cls]

# class SingletonABCMeta(Singleton, ABCMeta):
# pass


class AsyncCaller(ABC):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Expand Down Expand Up @@ -192,8 +207,13 @@ def sync_all_async_calls(self, is_alive: int) -> bool:
return ten[0] == 0

@abstractmethod
def close(self):
"""Terminate the async caller at exit of an application or some termination conditions"""
def close(self, abort=False):
"""Terminate the async caller at exit of an application or some termination conditions

Args:
abort (bool, optional): Default to False. Needs to be manually set to true when
the checkpoint async process needs to be aborted.
"""
logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

@abstractmethod
Expand Down Expand Up @@ -283,15 +303,23 @@ def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = Fal
is_done = True
return is_done

def close(self):
def close(self, abort=False):
"""For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`

This method make sure the TemporalAsyncCaller terminated
with all its assigned async request completed

Args:
abort (bool, optional): Default to False. Needs to be manually set to true when
the checkpoint async process needs to be aborted.
"""
if self.process:
logger.debug(f"rank: {self.rank}, joining self.process")
self.process.join()
if abort:
logger.warning(f"Temporal worker aborted in rank {torch.distributed.get_rank()}")
self.process.kill()
else:
self.process.join()
self.process = None
logger.debug(
"TemporalAsyncCaller: Async process join finished "
Expand All @@ -302,6 +330,14 @@ def close(self):
def __del__(self):
pass

def _debug_is_async_process_running(self):
"""
For unit test purpose
"""
if self.process is None:
return False
return self.process.is_alive()


class PersistentAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Expand Down Expand Up @@ -360,7 +396,7 @@ def schedule_async_call(self, async_req: AsyncRequest) -> None:
),
)
self.process.start()
logger.info(f"PersistentAsyncCaller: {self.rank}, Started Async Caller")
logger.error(f" Aarti PersistentAsyncCaller: {self.rank}, Started Async Caller {self.process}")

if async_req.preload_fn:
self.preload_q.put(async_req.call_idx)
Expand Down Expand Up @@ -429,21 +465,38 @@ def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = Fal

return is_done

def close(self):
def close(self, abort=False):
"""Wait on the left async requests and terminate the PersistentAsyncCaller

Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated

Args:
abort (bool, optional): Default to False. Needs to be manually set to true when
the checkpoint async process needs to be aborted.
"""
logger.info(f"PersistentAsyncCaller: {self.rank}, Destroying Async Caller")
if self.process:
self.queue.put('DONE')
self.queue.join()
self.process.join()
if abort:
logger.error(f"Persistent worker aborted in rank {torch.distributed.get_rank()}")
self.process.kill()
else:
self.queue.put('DONE')
self.queue.join()
self.process.join()

self.process = None

def __del__(self):
self.close()

def _debug_is_async_process_running(self):
"""
For unit test purpose
"""
if self.process is None:
return False
return self.process.is_alive()

@staticmethod
@_disable_gc()
def async_loop(
Expand Down Expand Up @@ -514,7 +567,7 @@ class _ActiveAsyncRequest(NamedTuple):
async_request: AsyncRequest


class AsyncCallsQueue:
class AsyncCallsQueue(metaclass=Singleton):
"""Manages a queue of async calls.

Allows adding a new async call with `schedule_async_request` and finalizing
Expand Down Expand Up @@ -549,13 +602,15 @@ def schedule_async_request(self, async_request: AsyncRequest) -> int:
"""
self.call_idx += 1
async_caller = self._get_async_caller()
logger.error(f"Aarti persistent caller created. But process not yet created persistentObject = {async_caller}")
# Backward compatibility for local checkpointing built with the old AsyncRequest
if len(async_request._fields) != len(AsyncRequest._fields):
async_request = AsyncRequest(**async_request._asdict())
async_request = async_request.freeze()
async_caller.schedule_async_call(
async_request._replace(call_idx=self.call_idx, finalize_fns=[])
)
logger.error(f"Aarti persistent caller created. Now I expect process to be created")
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
return self.call_idx

Expand Down Expand Up @@ -593,8 +648,39 @@ def get_num_unfinalized_calls(self):
"""Get the number of active async calls."""
return len(self.async_calls)

def close(self):
"""Finalize all calls upon closing."""
self.maybe_finalize_async_calls(blocking=True)
def close(self, abort=False):
"""Finalize all calls upon closing.

Args:
abort (bool, optional): Default to False. Needs to be manually set to true when
the checkpoint async process needs to be aborted.
"""
logger.error(f"Aarti 2222 {self} {abort} {self.persistent} {self.persistent_caller}")
if not abort:
self.maybe_finalize_async_calls(blocking=True)
if self.persistent and self.persistent_caller:
self.persistent_caller.close()
self.persistent_caller.close(abort=abort)
# Reset all class params
self.call_idx = -1
self.persistent_caller = None
# Clear the singleton registory of async worker
Singleton.clear(AsyncCallsQueue)

def abort_nvrx_checkpoint():
# we have a singleton persistent worker in our async calls queue
# close the async calls queue which will clear the singleton object
# to ensure a clean restart


# # When persistent_caller is singleton
# logger.error(f"Aarti 00000 abort_nvrx_checkpoint called")
# persistent_caller = AsyncCallsQueue(persistent=True)._get_async_caller()
# logger.error(f"Aarti 11111 close called on object {persistent_caller}")
# persistent_caller.close(abort=True)

logger.error(f"Aarti 00000 abort_nvrx_checkpoint called")
async_queue_singleton = AsyncCallsQueue(persistent=True)
logger.error(f"Aarti 11111 close called on object {async_queue_singleton}")
async_queue_singleton.close(abort=True)

# TBD: Create singleton for result_queue
13 changes: 13 additions & 0 deletions src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,16 @@ def finalize_async_save(self, blocking: bool = False, no_dist=True, terminate=Fa
self._async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=no_dist)
if terminate:
self._async_calls_queue.close()

def _get_async_calls_queue(self):
"""
Function introduced for unit test purpose to validate the state of the Async workers
"""
return self._async_calls_queue

def close(self):
if self._async_calls_queue is not None:
self._async_calls_queue.close()

def __del__(self):
self.close()
33 changes: 33 additions & 0 deletions src/nvidia_resiliency_ext/inprocess/abort.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch

from nvidia_resiliency_ext.attribution.trace_analyzer.trace_collector import TorchFRTraceCollector
from nvidia_resiliency_ext.checkpointing.async_ckpt.core import AsyncCallsQueue, abort_nvrx_checkpoint

from . import utils
from .state import FrozenState
Expand Down Expand Up @@ -185,3 +186,35 @@ def __call__(self, state: FrozenState) -> FrozenState:
te_fp8.FP8GlobalStateManager.reset()

return state


class AbortCheckpoint(Abort):
r'''
Aborts Async Checkpoint processes

'''
# def __init__(self):
# # Persistent worker once created is a singleton object
# # During abort, we mus reset this object to a clean state
# log = logging.getLogger(__name__)
# log.error(f"Aarti 1111 init AbortCheckpoint ")
# self.async_queue = AsyncCallsQueue(persistent=True)
# log.error(f"Aarti 2222 {self.__dict__}")

def __call__(self, state: FrozenState) -> FrozenState:
# log = logging.getLogger(__name__)
# log.error(f"Aarti 33333 {self.__dict__}")
# self.async_queue = AsyncCallsQueue(persistent=True)
# log.error(f"Aarti 4444 {self.__dict__} and singleton persistent_worker = {self.async_queue.persistent_caller}")
# if self.async_queue is not None:
# self.async_queue.close(abort=True)
# del self.async_queue
# log.error(f"Aarti 5555 {self.__dict__}")
# # global _results_queue
# # if _results_queue is not None:
# # _results_queue._manager.shutdown()
# # del _results_queue
# # _results_queue = None
# # self.async_queue = None
abort_nvrx_checkpoint()
return state
Loading
Loading