-
Notifications
You must be signed in to change notification settings - Fork 1
Implement fetching by doi and custom hashes #61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
710500e
a1d0800
68524b6
ff6fd70
e064711
f9856db
57f222e
9522d57
f560c1f
c6eb311
1d67899
4c00163
4389a51
d5eb021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import functools | ||
| import hashlib | ||
| import json | ||
| import re | ||
| import pathlib | ||
| import urllib.request | ||
|
|
||
|
|
@@ -19,19 +20,27 @@ | |
|
|
||
| CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" | ||
|
|
||
|
|
||
| def get_release_metadata() -> list[dict]: | ||
| return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8")) | ||
|
|
||
|
|
||
| @functools.lru_cache() | ||
| def get_model(filename: str) -> str: | ||
| def get_model( | ||
| filename: str, doi: None | str = None, file_hash: None | str = None | ||
| ) -> str: | ||
| """Return the path of a model as cached on disk, downloading if necessary.""" | ||
| pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) | ||
|
|
||
| cached_path = CACHE_DIR / filename | ||
|
|
||
| check_hash = file_hash | ||
| if check_hash is None and filename in KNOWN_HASHES: | ||
| check_hash = KNOWN_HASHES[filename] | ||
|
|
||
| if cached_path.exists(): | ||
| assert _get_sha256(cached_path) == KNOWN_HASHES[filename] | ||
| if check_hash: | ||
| assert _get_sha256(cached_path) == check_hash | ||
|
|
||
| return cached_path.as_posix() | ||
|
|
||
|
|
@@ -55,12 +64,33 @@ def get_model(filename: str) -> str: | |
| assert cached_path.exists() | ||
| assert path_to_file == cached_path.as_posix() | ||
|
|
||
| assert _get_sha256(cached_path) == KNOWN_HASHES[filename], ( | ||
| f"Hash mismatch for {filename}" | ||
| ) | ||
| if check_hash: | ||
| assert ( | ||
| _get_sha256(cached_path) == check_hash | ||
| ), f"Hash mismatch for {filename}" | ||
|
|
||
| return cached_path.as_posix() | ||
|
|
||
| if doi: | ||
| zenodo_id = re.findall("10.5072/zenodo.([0-9]+)", doi)[0] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will we support non-Zenodo DOIs? Figshare etc?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not in this version of the NAGLCharges spec, but other sources could be added in a higher NAGLCharges section version. |
||
|
|
||
| # Remove "sandbox." to convert this to "real" zenodo before merge | ||
| # Or keep in with a testing flag? | ||
| file_url = ( | ||
| f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" | ||
| ) | ||
| path_to_file, _ = urllib.request.urlretrieve( | ||
| file_url, filename=cached_path.as_posix() | ||
| ) | ||
| assert cached_path.exists() | ||
| assert path_to_file == cached_path.as_posix() | ||
|
|
||
| if check_hash: | ||
| assert ( | ||
| _get_sha256(cached_path) == file_hash | ||
| ), f"Hash mismatch for {filename}" | ||
| return cached_path.as_posix() | ||
|
|
||
| raise FileNotFoundError( | ||
| f"Could not find asset with name '{filename}' in any release" | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import json | ||
| import os | ||
| import pathlib | ||
| import shutil | ||
| import urllib.request | ||
|
|
@@ -59,11 +60,27 @@ def test_get_known_models(monkeypatch, known_model): | |
| assert "OPENFF_NAGL_MODELS" in get_model(known_model) | ||
|
|
||
|
|
||
| def test_access_internet_with_empty_cache(): | ||
| cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" | ||
| @pytest.fixture | ||
| def hide_cache(): | ||
| cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" | ||
| alt_dir = str(cache_dir) + "_temp" | ||
|
|
||
| if os.path.exists(alt_dir): | ||
| raise FileExistsError(f"Temporary directory already exists: {alt_dir}") | ||
|
|
||
| if os.path.exists(cache_dir): | ||
| shutil.move(cache_dir, alt_dir) | ||
|
|
||
| yield | ||
|
|
||
| if cache_path.exists(): | ||
| shutil.rmtree(cache_path) | ||
| if os.path.exists(alt_dir): | ||
| if os.path.exists(cache_dir): | ||
| shutil.rmtree(cache_dir) | ||
| shutil.move(alt_dir, cache_dir) | ||
|
|
||
|
|
||
| def test_access_internet_with_empty_cache(hide_cache): | ||
| cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" | ||
|
|
||
| disable_socket() | ||
|
|
||
|
|
@@ -147,3 +164,18 @@ def test_all_models_loadable(model, monkeypatch): | |
| ) | ||
|
|
||
| GNNModel.load(get_model(model), eval_mode=True) | ||
|
|
||
|
|
||
| def test_get_model_by_doi(hide_cache): | ||
| get_model( | ||
| "my_favorite_model.pt", | ||
| doi="10.5072/zenodo.278300", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This record must be sand-box only? This is my first Google result, which seems unlikely to be what you actually want to point to: https://zenodo.org/records/14335473 A comment or note about where this lives and how the hash was generated would be useful for future developers, I don't think anything else would be necessary here |
||
| file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", | ||
| ) | ||
|
|
||
|
|
||
| def test_get_model_hash_comparison_fails(): | ||
| with pytest.raises(AssertionError): | ||
| get_model( | ||
| "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="wrong_hash" | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.