-
Notifications
You must be signed in to change notification settings - Fork 53
Refresh async checkpoint IPC cache on pointer change #314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -299,6 +299,7 @@ def test_cached_data_structure(self, tmp_path_dist_ckpt): | |
| # reset inside the worker. _cached_identifiers lives in the main process and is | ||
| # NOT cleared by close(), so clear it here to avoid cross-test contamination. | ||
| FileSystemWriterAsync._cached_identifiers.clear() | ||
| FileSystemWriterAsync._cached_tensor_data_ptrs.clear() | ||
| async_queue = AsyncCallsQueue(persistent=True, is_daemon=True) | ||
|
|
||
| model = FSDP(Model((1024, 1024), 8)) | ||
|
|
@@ -329,6 +330,62 @@ def test_cached_data_structure(self, tmp_path_dist_ckpt): | |
| ckpt_dir.cleanup() | ||
| async_queue.close() | ||
|
|
||
| def test_cached_data_structure_refreshes_when_tensor_data_ptr_changes(self, tmp_path_dist_ckpt): | ||
| """ | ||
| Verifies that the CUDA IPC cache is refreshed when the logical checkpoint | ||
| structure is unchanged but source tensors have been reallocated. | ||
| """ | ||
| Utils.initialize_distributed() | ||
|
|
||
| FileSystemWriterAsync._cached_identifiers.clear() | ||
| FileSystemWriterAsync._cached_tensor_data_ptrs.clear() | ||
| async_queue = AsyncCallsQueue(persistent=True, is_daemon=True) | ||
|
|
||
| model = FSDP(Model((1024, 1024), 8)) | ||
| base_state_dict = model.state_dict() | ||
| planner = DefaultSavePlanner() | ||
| tracked_key = next( | ||
| key for key, value in base_state_dict.items() if isinstance(value, torch.Tensor) | ||
| ) | ||
|
|
||
| state_dicts = [] | ||
| data_ptrs = [] | ||
| last_ckpt_dir = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| for i in range(3): | ||
| state_dict = { | ||
| key: ( | ||
| torch.full_like(value, float(i)) if isinstance(value, torch.Tensor) else value | ||
| ) | ||
| for key, value in base_state_dict.items() | ||
| } | ||
| state_dicts.append(state_dict) | ||
| data_ptrs.append(state_dict[tracked_key].data_ptr()) | ||
|
|
||
| ckpt_dir = TempNamedDir(tmp_path_dist_ckpt / f'cuda_ipc_ptr_ckpt_{i}', sync=True) | ||
| self.async_save_checkpoint( | ||
| ckpt_dir, | ||
| state_dict, | ||
| planner, | ||
| async_queue, | ||
| use_cached_data_structure=True, | ||
| ) | ||
| async_queue.maybe_finalize_async_calls(blocking=True, no_dist=False) | ||
| if last_ckpt_dir is not None: | ||
| last_ckpt_dir.cleanup() | ||
| last_ckpt_dir = ckpt_dir | ||
|
|
||
| assert len(set(data_ptrs)) == 3 | ||
|
|
||
| loaded = self.load_checkpoint(last_ckpt_dir, deepcopy(state_dicts[-1])) | ||
| for key, tensor in loaded.items(): | ||
| assert torch.all(tensor.cpu() == 2.0), ( | ||
| f"Key '{key}': expected 2.0 from refreshed CUDA IPC cache, " | ||
| f"got unique values {tensor.cpu().unique().tolist()}" | ||
| ) | ||
|
|
||
| last_ckpt_dir.cleanup() | ||
| async_queue.close() | ||
|
|
||
| def test_cpu_shm_for_gpu_tensors(self, tmp_path_dist_ckpt): | ||
| """CPU shm path: D2H done in training process, worker streams from shm. | ||
|
|
||
|
|
@@ -341,6 +398,7 @@ def test_cpu_shm_for_gpu_tensors(self, tmp_path_dist_ckpt): | |
|
|
||
| # Clear class-level caches to avoid cross-test contamination | ||
| FileSystemWriterAsync._cached_identifiers.clear() | ||
| FileSystemWriterAsync._cached_tensor_data_ptrs.clear() | ||
| FileSystemWriterAsync._shm_tensor_cache.clear() | ||
| async_queue = AsyncCallsQueue(persistent=True, is_daemon=True) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
zipoveritems/tensorszip(items, tensors)silently stops at the shorter of the two sequences if their lengths ever differ. Becausegpu_itemsandgpu_dataare always produced together byseparate_cacheable, they should always be the same length in practice — but an assertion would make this contract explicit and catch future regressions early.