From 54485006b355c5955129a8b8fbf0294a20e3126e Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 13:19:43 +0000 Subject: [PATCH 1/7] fix: add weights_only param to torch_load Add a `weights_only` parameter to `fs.torch_load` so callers can control the security-sensitive fallback behavior: - `weights_only=None` (default): preserve current fallback behavior - `weights_only=True`: strict mode, raise on failure (no fallback) - `weights_only=False`: skip the weights_only=True attempt entirely The default behavior is unchanged, keeping full backward compatibility. --- torch_geometric/io/fs.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torch_geometric/io/fs.py b/torch_geometric/io/fs.py index c88a04ad5a9a..b666b1ef4320 100644 --- a/torch_geometric/io/fs.py +++ b/torch_geometric/io/fs.py @@ -214,8 +214,38 @@ def torch_save(data: Any, path: str) -> None: f.write(buffer.getvalue()) -def torch_load(path: str, map_location: Any = None) -> Any: +def torch_load( + path: str, + map_location: Any = None, + weights_only: Optional[bool] = None, +) -> Any: + r"""Load a PyTorch file from a given path using :func:`fsspec`. + + Args: + path (str): The path to the file to load. + map_location: A simplified version of :attr:`torch.load`'s + :attr:`map_location`. + weights_only (bool, optional): If :obj:`True`, only weights will be + loaded and an error will be raised on failure (*i.e.*, no + fallback to :obj:`weights_only=False`). + If :obj:`False`, :obj:`weights_only=False` will be used + directly. + If :obj:`None` (default), the current fallback behavior is + preserved: first tries :obj:`weights_only=True`, then falls + back to :obj:`weights_only=False` on + :class:`~pickle.UnpicklingError`. (default: :obj:`None`) + """ if torch_geometric.typing.WITH_PT24: + if weights_only is True: + with fsspec.open(path, 'rb') as f: + return torch.load(f, map_location, weights_only=True) + + if weights_only is False: + with fsspec.open(path, 'rb') as f: + return torch.load(f, map_location, weights_only=False) + + # Default behavior (weights_only=None): try weights_only=True + # first, then fall back to weights_only=False on failure. try: with fsspec.open(path, 'rb') as f: return torch.load(f, map_location, weights_only=True) From 38f6e88b9f9707acb911e3b6e66881dbd0ae0636 Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 14:33:17 +0000 Subject: [PATCH 2/7] fix: add trust caveat to safe_globals warn --- torch_geometric/io/fs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/io/fs.py b/torch_geometric/io/fs.py index b666b1ef4320..ace75ef6fdf3 100644 --- a/torch_geometric/io/fs.py +++ b/torch_geometric/io/fs.py @@ -260,7 +260,8 @@ def torch_load( warnings.warn( f"{warn_msg} Please use " f"`torch.serialization.{match.group()}` to " - f"allowlist this global.", stacklevel=2) + f"allowlist this global if you trust it.", + stacklevel=2) else: warnings.warn(warn_msg, stacklevel=2) From 89fea339761926a8fcd9e01778a9639d16f5e1c1 Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 14:33:39 +0000 Subject: [PATCH 3/7] feat: deprecate weights_only fallback Emit a FutureWarning when the weights_only=True attempt fails and torch_load silently falls back to weights_only=False. In a future release this fallback will be removed - callers should pass an explicit weights_only value. --- torch_geometric/io/fs.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch_geometric/io/fs.py b/torch_geometric/io/fs.py index ace75ef6fdf3..4f68bde27883 100644 --- a/torch_geometric/io/fs.py +++ b/torch_geometric/io/fs.py @@ -265,6 +265,16 @@ def torch_load( else: warnings.warn(warn_msg, stacklevel=2) + warnings.warn( + "Falling back to `weights_only=False` because the file at" + f"'{path}' could not be loaded with `weights_only=True`." + "In a future release, this fallback will be removed. Pass " + "`weights_only=False` explicitly if you need to load " + "custom Python objects, allowlist doesn't work, and you" + "trust the source.", + FutureWarning, + stacklevel=2, + ) with fsspec.open(path, 'rb') as f: return torch.load(f, map_location, weights_only=False) else: From fda09d3eed74b48cad6aa2ccce158dfd6a8545eb Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 14:36:25 +0000 Subject: [PATCH 4/7] test: add test for weights_only fallback warning --- test/io/test_fs.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/io/test_fs.py b/test/io/test_fs.py index d9227c6a230b..e33c454e1fe1 100644 --- a/test/io/test_fs.py +++ b/test/io/test_fs.py @@ -128,5 +128,18 @@ def test_torch_save_load(tmp_fs_path): path = osp.join(tmp_fs_path, 'x.pt') fs.torch_save(x, path) - out = fs.torch_load(path) + out = fs.torch_load(path, weights_only=True) assert torch.equal(x, out) + + +@pytest.mark.skipif( + not torch_geometric.typing.WITH_PT24, + reason='weights_only requires PyTorch >= 2.4', +) +def test_torch_load_fallback_warning(tmp_fs_path): + # Save an object that cannot be loaded with weights_only=True: + path = osp.join(tmp_fs_path, 'data.pt') + fs.torch_save(object(), path) + + with pytest.warns(FutureWarning, match='weights_only'): + fs.torch_load(path) From f70303f9e474343c409ce3c383a39bb3fe1c204b Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 16:45:57 +0000 Subject: [PATCH 5/7] test: cover weights_only=False path --- test/io/test_fs.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/io/test_fs.py b/test/io/test_fs.py index e33c454e1fe1..e39ad880dc3f 100644 --- a/test/io/test_fs.py +++ b/test/io/test_fs.py @@ -143,3 +143,14 @@ def test_torch_load_fallback_warning(tmp_fs_path): with pytest.warns(FutureWarning, match='weights_only'): fs.torch_load(path) + + +@pytest.mark.skipif( + not torch_geometric.typing.WITH_PT24, + reason='weights_only requires PyTorch >= 2.4', +) +def test_torch_load_weights_only_false(tmp_fs_path): + path = osp.join(tmp_fs_path, 'data.pt') + fs.torch_save(object(), path) + out = fs.torch_load(path, weights_only=False) + assert isinstance(out, object) From 3d262f393505f944562a5946feb4b442ea4f9881 Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 16:46:13 +0000 Subject: [PATCH 6/7] docs: add changelog for weights_only --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4125aab5b53..a0c12305cdcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added +- Added `weights_only` parameter to `torch_load` for explicit control over safe deserialization ([#10666](https://github.com/pyg-team/pytorch_geometric/pull/10666)) + ### Changed - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) ### Deprecated +- Deprecated the implicit `weights_only=False` fallback in `torch_load`; pass `weights_only` explicitly ([#10666](https://github.com/pyg-team/pytorch_geometric/pull/10666)) - Deprecated support for `torch-spline-conv` in favor of `pyg-lib>=0.6.0` ([#10622](https://github.com/pyg-team/pytorch_geometric/pull/10622)) ### Removed From ffdc5357fcb55b169719fde429eec533eb043e60 Mon Sep 17 00:00:00 2001 From: Zaz Brown Date: Sat, 11 Apr 2026 17:00:42 +0000 Subject: [PATCH 7/7] feat: pass weights_only through InMemoryDataset.load --- torch_geometric/data/in_memory_dataset.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch_geometric/data/in_memory_dataset.py b/torch_geometric/data/in_memory_dataset.py index bcfcba02ee2e..c186250f45c6 100644 --- a/torch_geometric/data/in_memory_dataset.py +++ b/torch_geometric/data/in_memory_dataset.py @@ -126,9 +126,14 @@ def save(cls, data_list: Sequence[BaseData], path: str) -> None: data, slices = cls.collate(data_list) fs.torch_save((data.to_dict(), slices, data.__class__), path) - def load(self, path: str, data_cls: Type[BaseData] = Data) -> None: + def load( + self, + path: str, + data_cls: Type[BaseData] = Data, + weights_only: Optional[bool] = None, + ) -> None: r"""Loads the dataset from the file path :obj:`path`.""" - out = fs.torch_load(path) + out = fs.torch_load(path, weights_only=weights_only) assert isinstance(out, tuple) assert len(out) == 2 or len(out) == 3 if len(out) == 2: # Backward compatibility.