Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Security

- Fixes a critical logic inversion in path traversal protection during checkpoint extraction.
- Enables `weights_only=True` in `torch.load` calls for enhanced security when loading model weights. Note: This may be a breaking change for legacy checkpoints containing custom Python objects not in the allowed list.

### Dependencies

## [2.0.0] - 2026-XX-YY
Expand Down
28 changes: 20 additions & 8 deletions physicsnemo/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,21 @@ def _setup_logger(self):

@staticmethod
def _safe_members(tar, local_path):
resolved_local_path = os.path.join(os.path.realpath(local_path), "")
for member in tar.getmembers():
if (
".." in member.name
or os.path.isabs(member.name)
or os.path.realpath(os.path.join(local_path, member.name)).startswith(
os.path.realpath(local_path)
or not os.path.realpath(os.path.join(local_path, member.name)).startswith(
resolved_local_path + os.sep
)
Comment thread
RinZ27 marked this conversation as resolved.
and os.path.realpath(os.path.join(local_path, member.name)) != resolved_local_path
):
yield member
logging.getLogger("core.module").warning(
f"Skipping potentially malicious file: {member.name}"
)
else:
print(f"Skipping potentially malicious file: {member.name}")
yield member

@classmethod
def _backward_compat_arg_mapper(
Expand Down Expand Up @@ -747,7 +751,9 @@ def load(
model_bytes = archive.read("model.pt")

# Load state dict after closing archive
model_dict = torch.load(io.BytesIO(model_bytes), map_location=device)
model_dict = torch.load(
io.BytesIO(model_bytes), map_location=device, weights_only=True
)
Comment thread
RinZ27 marked this conversation as resolved.

# Load state_dict into the model
_load_state_dict_with_logging(self, model_dict, strict=strict)
Expand All @@ -773,7 +779,9 @@ def load(

# Load the model weights
model_dict = torch.load(
local_path.joinpath("model.pt"), map_location=device
local_path.joinpath("model.pt"),
map_location=device,
weights_only=True,
)

# Load state dict into the model
Expand Down Expand Up @@ -1054,7 +1062,9 @@ def _from_checkpoint_process(
model_bytes = archive.read("model.pt")

# Load state dict after closing archive
model_dict = torch.load(io.BytesIO(model_bytes), map_location=model.device)
model_dict = torch.load(
io.BytesIO(model_bytes), map_location=model.device, weights_only=True
)

# Load state_dict into the model
_load_state_dict_with_logging(model, model_dict, strict=strict)
Expand Down Expand Up @@ -1096,7 +1106,9 @@ def _from_checkpoint_process(

# Load the model weights
model_dict = torch.load(
local_path.joinpath("model.pt"), map_location=model.device
local_path.joinpath("model.pt"),
map_location=model.device,
weights_only=True,
)

# Load state_dict into the model
Expand Down