Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,30 @@ def _compute_data_structure_key_from_plan(items: List[WriteItem]) -> str:
return hashlib.sha256(str(structure_info).encode()).hexdigest()


def _compute_tensor_data_ptrs(items: List[WriteItem], tensors: List[Any]) -> Tuple[Tuple, ...]:
"""Compute a storage identity fingerprint for tensors cached via CUDA IPC."""
ptrs = []
Comment on lines +119 to +121

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Silent truncation in zip over items/tensors

zip(items, tensors) silently stops at the shorter of the two sequences if their lengths ever differ. Because gpu_items and gpu_data are always produced together by separate_cacheable, they should always be the same length in practice — but an assertion would make this contract explicit and catch future regressions early.

Suggested change
def _compute_tensor_data_ptrs(items: List[WriteItem], tensors: List[Any]) -> Tuple[Tuple, ...]:
"""Compute a storage identity fingerprint for tensors cached via CUDA IPC."""
ptrs = []
assert len(items) == len(tensors), (
f"items and tensors must be the same length, got {len(items)} vs {len(tensors)}"
)
ptrs = []
for item, tensor in zip(items, tensors):

for item, tensor in zip(items, tensors):
chunk = item.tensor_data.chunk if item.tensor_data is not None else None
chunk_info = (tuple(chunk.offsets), tuple(chunk.sizes)) if chunk is not None else None

if isinstance(tensor, torch.Tensor):
storage = tensor.untyped_storage()
tensor_info = (
str(tensor.device),
tensor.data_ptr(),
storage.data_ptr(),
tensor.storage_offset(),
tuple(tensor.size()),
tuple(tensor.stride()),
str(tensor.dtype),
)
else:
tensor_info = (type(tensor).__qualname__,)
ptrs.append((item.index.fqn, chunk_info, tensor_info))
return tuple(ptrs)


@_disable_gc()
def get_write_results_queue(mp_mode: str = 'spawn') -> mp.Queue:
"""Get or create a multiprocessing queue for write results.
Expand Down Expand Up @@ -160,6 +184,11 @@ class FileSystemWriterAsync(FileSystemWriter):
# Class-level cache to track identifiers that have been sent to worker across instances
_cached_identifiers: set = set()

# Training-side CUDA IPC tensor fingerprint cache.
# Key: SHA-256 of plan items (same as ConsistentDataIdentifier).
# Value: data/storage pointer fingerprint for the GPU tensors sent to the worker.
_cached_tensor_data_ptrs: ClassVar[Dict[str, Tuple[Tuple, ...]]] = {}

# Training-side shm tensor cache: reuses allocations across checkpoints.
# Only populated when use_cpu_shm_for_gpu_tensors=True AND use_cached_data_structure=True.
# Key: SHA-256 of plan items (same as ConsistentDataIdentifier).
Expand Down Expand Up @@ -404,17 +433,27 @@ def separate_cacheable(items, resolved_data, dequantized_flags, include_dequanti
)
elif cache_exists:
# --- original GPU IPC path, reuse ---
current_data_ptrs = _compute_tensor_data_ptrs(gpu_items, gpu_data)
cache_data_is_current = (
FileSystemWriterAsync._cached_tensor_data_ptrs.get(key) == current_data_ptrs
)
self.consistent_data_identifier = ConsistentDataIdentifier(key)
self.cached_tensor_data = None # Signal to reuse cached data
self.cached_tensor_data = None if cache_data_is_current else (gpu_items, gpu_data)
if not cache_data_is_current:
FileSystemWriterAsync._cached_tensor_data_ptrs[key] = current_data_ptrs
cache_action = "Reusing" if cache_data_is_current else "Refreshing"
logger.debug(
f"Reusing cached GPU tensors (key={key}), "
f"resolved {len(uncached_items)} uncached tensors fresh"
f"{cache_action} cached GPU tensors (key={key}), "
f"{len(uncached_items)} uncached tensors passed fresh"
)
elif gpu_items:
# --- original GPU IPC path, first time ---
self.consistent_data_identifier = ConsistentDataIdentifier(key)
self.cached_tensor_data = (gpu_items, gpu_data)
FileSystemWriterAsync._cached_identifiers.add(key)
FileSystemWriterAsync._cached_tensor_data_ptrs[key] = _compute_tensor_data_ptrs(
gpu_items, gpu_data
)
logger.debug(
f"Caching {len(gpu_items)} GPU tensors (key={key}), "
f"{len(uncached_items)} uncached tensors passed fresh"
Expand Down Expand Up @@ -493,6 +532,7 @@ def cleanup_tensor_caches(cls) -> None:
logger.info(f"Clearing shm tensor cache ({len(cls._shm_tensor_cache)} entries)")
cls._shm_tensor_cache.clear()
cls._cached_identifiers.clear()
cls._cached_tensor_data_ptrs.clear()

@classmethod
def register_shm_drain_callback(cls, fn: Optional[Callable[[], None]]) -> None:
Expand Down Expand Up @@ -1195,6 +1235,9 @@ def retrieve_write_results(self) -> Union[List[WriteResult], WRAPPED_EXCEPTION]:
FileSystemWriterAsync._cached_identifiers.discard(
self.consistent_data_identifier.key
)
FileSystemWriterAsync._cached_tensor_data_ptrs.pop(
self.consistent_data_identifier.key, None
)
try:
raise RuntimeError(
f'Worker failure: {write_results_or_exc}'
Expand Down
58 changes: 58 additions & 0 deletions tests/checkpointing/unit/test_async_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def test_cached_data_structure(self, tmp_path_dist_ckpt):
# reset inside the worker. _cached_identifiers lives in the main process and is
# NOT cleared by close(), so clear it here to avoid cross-test contamination.
FileSystemWriterAsync._cached_identifiers.clear()
FileSystemWriterAsync._cached_tensor_data_ptrs.clear()
async_queue = AsyncCallsQueue(persistent=True, is_daemon=True)

model = FSDP(Model((1024, 1024), 8))
Expand Down Expand Up @@ -329,6 +330,62 @@ def test_cached_data_structure(self, tmp_path_dist_ckpt):
ckpt_dir.cleanup()
async_queue.close()

def test_cached_data_structure_refreshes_when_tensor_data_ptr_changes(self, tmp_path_dist_ckpt):
"""
Verifies that the CUDA IPC cache is refreshed when the logical checkpoint
structure is unchanged but source tensors have been reallocated.
"""
Utils.initialize_distributed()

FileSystemWriterAsync._cached_identifiers.clear()
FileSystemWriterAsync._cached_tensor_data_ptrs.clear()
async_queue = AsyncCallsQueue(persistent=True, is_daemon=True)

model = FSDP(Model((1024, 1024), 8))
base_state_dict = model.state_dict()
planner = DefaultSavePlanner()
tracked_key = next(
key for key, value in base_state_dict.items() if isinstance(value, torch.Tensor)
)

state_dicts = []
data_ptrs = []
last_ckpt_dir = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Test correctness relies on implicit memory-reuse prevention

state_dicts.append(state_dict) keeps all three state dicts (and their tensors) alive throughout the loop, which prevents the CUDA allocator from recycling a freed address and producing a coincidental pointer match. This invariant is load-bearing for the len(set(data_ptrs)) == 3 assertion: if a state dict were dropped before the end of the loop, the allocator could reuse the address and collapse two entries to the same pointer, making the assertion meaningless. A short comment here would document the intent and guard against accidental simplification later.

for i in range(3):
state_dict = {
key: (
torch.full_like(value, float(i)) if isinstance(value, torch.Tensor) else value
)
for key, value in base_state_dict.items()
}
state_dicts.append(state_dict)
data_ptrs.append(state_dict[tracked_key].data_ptr())

ckpt_dir = TempNamedDir(tmp_path_dist_ckpt / f'cuda_ipc_ptr_ckpt_{i}', sync=True)
self.async_save_checkpoint(
ckpt_dir,
state_dict,
planner,
async_queue,
use_cached_data_structure=True,
)
async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False)
if last_ckpt_dir is not None:
last_ckpt_dir.cleanup()
last_ckpt_dir = ckpt_dir

assert len(set(data_ptrs)) == 3

loaded = self.load_checkpoint(last_ckpt_dir, deepcopy(state_dicts[-1]))
for key, tensor in loaded.items():
assert torch.all(tensor.cpu() == 2.0), (
f"Key '{key}': expected 2.0 from refreshed CUDA IPC cache, "
f"got unique values {tensor.cpu().unique().tolist()}"
)

last_ckpt_dir.cleanup()
async_queue.close()

def test_cpu_shm_for_gpu_tensors(self, tmp_path_dist_ckpt):
"""CPU shm path: D2H done in training process, worker streams from shm.

Expand All @@ -341,6 +398,7 @@ def test_cpu_shm_for_gpu_tensors(self, tmp_path_dist_ckpt):

# Clear class-level caches to avoid cross-test contamination
FileSystemWriterAsync._cached_identifiers.clear()
FileSystemWriterAsync._cached_tensor_data_ptrs.clear()
FileSystemWriterAsync._shm_tensor_cache.clear()
async_queue = AsyncCallsQueue(persistent=True, is_daemon=True)

Expand Down
Loading