Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Added `weights_only` parameter to `torch_load` for explicit control over safe deserialization ([#10666](https://github.com/pyg-team/pytorch_geometric/pull/10666))

### Changed

- Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))

### Deprecated

- Deprecated the implicit `weights_only=False` fallback in `torch_load`; pass `weights_only` explicitly ([#10666](https://github.com/pyg-team/pytorch_geometric/pull/10666))
- Deprecated support for `torch-spline-conv` in favor of `pyg-lib>=0.6.0` ([#10622](https://github.com/pyg-team/pytorch_geometric/pull/10622))

### Removed
Expand Down
26 changes: 25 additions & 1 deletion test/io/test_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,29 @@ def test_torch_save_load(tmp_fs_path):
path = osp.join(tmp_fs_path, 'x.pt')

fs.torch_save(x, path)
out = fs.torch_load(path)
out = fs.torch_load(path, weights_only=True)
assert torch.equal(x, out)


@pytest.mark.skipif(
not torch_geometric.typing.WITH_PT24,
reason='weights_only requires PyTorch >= 2.4',
)
def test_torch_load_fallback_warning(tmp_fs_path):
# Save an object that cannot be loaded with weights_only=True:
path = osp.join(tmp_fs_path, 'data.pt')
fs.torch_save(object(), path)

with pytest.warns(FutureWarning, match='weights_only'):
fs.torch_load(path)


@pytest.mark.skipif(
not torch_geometric.typing.WITH_PT24,
reason='weights_only requires PyTorch >= 2.4',
)
def test_torch_load_weights_only_false(tmp_fs_path):
path = osp.join(tmp_fs_path, 'data.pt')
fs.torch_save(object(), path)
out = fs.torch_load(path, weights_only=False)
assert isinstance(out, object)
9 changes: 7 additions & 2 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,14 @@ def save(cls, data_list: Sequence[BaseData], path: str) -> None:
data, slices = cls.collate(data_list)
fs.torch_save((data.to_dict(), slices, data.__class__), path)

def load(self, path: str, data_cls: Type[BaseData] = Data) -> None:
def load(
self,
path: str,
data_cls: Type[BaseData] = Data,
weights_only: Optional[bool] = None,
) -> None:
r"""Loads the dataset from the file path :obj:`path`."""
out = fs.torch_load(path)
out = fs.torch_load(path, weights_only=weights_only)
assert isinstance(out, tuple)
assert len(out) == 2 or len(out) == 3
if len(out) == 2: # Backward compatibility.
Expand Down
45 changes: 43 additions & 2 deletions torch_geometric/io/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,38 @@ def torch_save(data: Any, path: str) -> None:
f.write(buffer.getvalue())


def torch_load(path: str, map_location: Any = None) -> Any:
def torch_load(
path: str,
map_location: Any = None,
weights_only: Optional[bool] = None,
) -> Any:
r"""Load a PyTorch file from a given path using :func:`fsspec`.

Args:
path (str): The path to the file to load.
map_location: A simplified version of :attr:`torch.load`'s
:attr:`map_location`.
weights_only (bool, optional): If :obj:`True`, only weights will be
loaded and an error will be raised on failure (*i.e.*, no
fallback to :obj:`weights_only=False`).
If :obj:`False`, :obj:`weights_only=False` will be used
directly.
If :obj:`None` (default), the current fallback behavior is
preserved: first tries :obj:`weights_only=True`, then falls
back to :obj:`weights_only=False` on
:class:`~pickle.UnpicklingError`. (default: :obj:`None`)
"""
if torch_geometric.typing.WITH_PT24:
if weights_only is True:
with fsspec.open(path, 'rb') as f:
return torch.load(f, map_location, weights_only=True)

if weights_only is False:
with fsspec.open(path, 'rb') as f:
return torch.load(f, map_location, weights_only=False)

# Default behavior (weights_only=None): try weights_only=True
# first, then fall back to weights_only=False on failure.
try:
with fsspec.open(path, 'rb') as f:
return torch.load(f, map_location, weights_only=True)
Expand All @@ -230,10 +260,21 @@ def torch_load(path: str, map_location: Any = None) -> Any:
warnings.warn(
f"{warn_msg} Please use "
f"`torch.serialization.{match.group()}` to "
f"allowlist this global.", stacklevel=2)
f"allowlist this global if you trust it.",
stacklevel=2)
else:
warnings.warn(warn_msg, stacklevel=2)

warnings.warn(
"Falling back to `weights_only=False` because the file at"
f"'{path}' could not be loaded with `weights_only=True`."
"In a future release, this fallback will be removed. Pass "
"`weights_only=False` explicitly if you need to load "
"custom Python objects, allowlist doesn't work, and you"
"trust the source.",
FutureWarning,
stacklevel=2,
)
with fsspec.open(path, 'rb') as f:
return torch.load(f, map_location, weights_only=False)
else:
Expand Down