Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions openff/nagl_models/_dynamic_fetch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import hashlib
import json
import re
import pathlib
import urllib.request

Expand All @@ -19,19 +20,27 @@

CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"


def get_release_metadata() -> list[dict]:
    """Download and parse the JSON release metadata served at ``RELEASES_URL``.

    Returns
    -------
    list[dict]
        The decoded JSON payload of the response.

    Raises
    ------
    urllib.error.URLError
        If the request to ``RELEASES_URL`` fails.
    json.JSONDecodeError
        If the response body is not valid JSON.
    """
    # Use the response as a context manager so the underlying connection is
    # closed deterministically instead of being left to garbage collection.
    with urllib.request.urlopen(RELEASES_URL) as response:
        return json.loads(response.read().decode("utf-8"))


@functools.lru_cache()
def get_model(filename: str) -> str:
def get_model(
filename: str, doi: None | str = None, file_hash: None | str = None
) -> str:
"""Return the path of a model as cached on disk, downloading if necessary."""
pathlib.Path(CACHE_DIR).mkdir(exist_ok=True)

cached_path = CACHE_DIR / filename

check_hash = file_hash
Comment thread
j-wags marked this conversation as resolved.
Outdated
if check_hash is None and filename in KNOWN_HASHES:
check_hash = KNOWN_HASHES[filename]

if cached_path.exists():
assert _get_sha256(cached_path) == KNOWN_HASHES[filename]
if check_hash:
assert _get_sha256(cached_path) == check_hash

return cached_path.as_posix()

Expand All @@ -55,12 +64,33 @@ def get_model(filename: str) -> str:
assert cached_path.exists()
assert path_to_file == cached_path.as_posix()

assert _get_sha256(cached_path) == KNOWN_HASHES[filename], (
f"Hash mismatch for {filename}"
)
if check_hash:
assert (
_get_sha256(cached_path) == check_hash
), f"Hash mismatch for {filename}"

return cached_path.as_posix()

if doi:
zenodo_id = re.findall("10.5072/zenodo.([0-9]+)", doi)[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we support non-Zenodo DOIs? Figshare etc?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in this version of the NAGLCharges spec, but other sources could be added in a higher NAGLCharges section version.


# TODO: remove "sandbox." so this points at the production Zenodo API before merging,
# or keep the sandbox URL behind a testing flag?
file_url = (
f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}"
)
path_to_file, _ = urllib.request.urlretrieve(
file_url, filename=cached_path.as_posix()
)
assert cached_path.exists()
assert path_to_file == cached_path.as_posix()

if check_hash:
assert (
_get_sha256(cached_path) == file_hash
), f"Hash mismatch for {filename}"
return cached_path.as_posix()

raise FileNotFoundError(
f"Could not find asset with name '{filename}' in any release"
)
Expand Down
40 changes: 36 additions & 4 deletions openff/nagl_models/tests/test_dynamic_fetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import pathlib
import shutil
import urllib.request
Expand Down Expand Up @@ -59,11 +60,27 @@ def test_get_known_models(monkeypatch, known_model):
assert "OPENFF_NAGL_MODELS" in get_model(known_model)


def test_access_internet_with_empty_cache():
cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"
@pytest.fixture
def hide_cache():
cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"
alt_dir = str(cache_dir) + "_temp"

if os.path.exists(alt_dir):
raise FileExistsError(f"Temporary directory already exists: {alt_dir}")

if os.path.exists(cache_dir):
shutil.move(cache_dir, alt_dir)

yield

if cache_path.exists():
shutil.rmtree(cache_path)
if os.path.exists(alt_dir):
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
shutil.move(alt_dir, cache_dir)


def test_access_internet_with_empty_cache(hide_cache):
cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"

disable_socket()

Expand Down Expand Up @@ -147,3 +164,18 @@ def test_all_models_loadable(model, monkeypatch):
)

GNNModel.load(get_model(model), eval_mode=True)


def test_get_model_by_doi(hide_cache):
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this record sandbox-only? This is my first Google result, which seems unlikely to be what you actually want to point to: https://zenodo.org/records/14335473

A comment or note about where this lives and how the hash was generated would be useful for future developers; I don't think anything else would be necessary here.

file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81",
)


def test_get_model_hash_comparison_fails():
    """A deliberately wrong ``file_hash`` must make ``get_model`` raise ``AssertionError``."""
    lookup = {
        "doi": "10.5072/zenodo.278300",
        "file_hash": "wrong_hash",
    }

    with pytest.raises(AssertionError):
        get_model("my_favorite_model.pt", **lookup)