@staticmethod
def _safe_members(tar, local_path):
    """Yield the members of ``tar`` that resolve inside ``local_path``.

    A member is skipped (and a warning is logged on the ``core.module``
    logger) when any of the following holds:

    * its name contains ``..``;
    * its name is an absolute path;
    * joining it onto ``local_path`` and resolving the result with
      ``os.path.realpath`` gives a path that is neither the extraction
      directory itself nor located underneath it.

    Parameters
    ----------
    tar
        An open archive exposing ``getmembers()``; each member must have a
        ``name`` attribute (e.g. ``tarfile.TarFile``).
    local_path
        Directory the archive is going to be extracted into.

    Yields
    ------
    The members that passed every check, in archive order.
    """
    # Resolve the extraction root once; it does not change per member.
    base_dir = os.path.realpath(local_path)
    # Exactly one trailing separator, so that "/a/b" is not treated as a
    # prefix of a sibling such as "/a/bc".  realpath() never produces a
    # doubled separator, so appending another os.sep here would make the
    # prefix test fail for every member.
    base_prefix = os.path.join(base_dir, "")
    logger = logging.getLogger("core.module")

    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(base_dir, member.name))
        # The root itself (e.g. a "." entry) is fine; anything else must
        # live strictly below the root.
        escapes_root = target != base_dir and not target.startswith(base_prefix)
        if ".." in member.name or os.path.isabs(member.name) or escapes_root:
            logger.warning(
                "Skipping potentially malicious file: %s", member.name
            )
        else:
            yield member
model_dict = torch.load( - local_path.joinpath("model.pt"), map_location=device + local_path.joinpath("model.pt"), + map_location=device, + weights_only=True, ) # Load state dict into the model @@ -1054,7 +1062,9 @@ def _from_checkpoint_process( model_bytes = archive.read("model.pt") # Load state dict after closing archive - model_dict = torch.load(io.BytesIO(model_bytes), map_location=model.device) + model_dict = torch.load( + io.BytesIO(model_bytes), map_location=model.device, weights_only=True + ) # Load state_dict into the model _load_state_dict_with_logging(model, model_dict, strict=strict) @@ -1096,7 +1106,9 @@ def _from_checkpoint_process( # Load the model weights model_dict = torch.load( - local_path.joinpath("model.pt"), map_location=model.device + local_path.joinpath("model.pt"), + map_location=model.device, + weights_only=True, ) # Load state_dict into the model