diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index a3008d9..a8f7616 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -1,9 +1,9 @@ import functools import hashlib import json +import re import pathlib import urllib.request - import platformdirs from packaging.version import Version @@ -19,19 +19,73 @@ CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" + +class HashComparisonFailedException(Exception): + """Exception raised when a NAGL file being loaded fails a comparison to a known or user-provided hash.""" + + +class UnableToParseDOIException(Exception): + """Exception raised when a Zenodo DOI is unable to be parsed according to the expected pattern.""" + + def get_release_metadata() -> list[dict]: return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8")) @functools.lru_cache() -def get_model(filename: str) -> str: - """Return the path of a model as cached on disk, downloading if necessary.""" +def get_model( + filename: str, + doi: None | str = None, + file_hash: None | str = None, +) -> str: + """ + Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is: + 1. Try to retrieve the file from the local cache + 2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models + 3. Try to fetch the file from the DOI, if provided + + This method will raise a HashComparisonFailedException as soon as a hash mismatch is encountered. So if + there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised + immediately, even if a file with a matching name that WOULD satisfy the hash check exists in release + metadata or at a provided Zenodo DOI. + + Parameters + ---------- + filename + The name of the file to search for. 
+ doi + The Zenodo DOI to use as a backup location for fetching the model file if it's not found in the local cache + or in the + [release metadata of an openff-nagl-models release](https://github.com/openforcefield/openff-nagl-models/releases) + on GitHub. For example: "10.5072/zenodo.278300" + file_hash + The sha256 hash of the model file to verify the correct contents. Hash checks are automatically performed + on some OpenFF-released NAGL models. But if the model isn't released by OpenFF and this argument is + not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedException + if unsuccessful. If a user provides a hash value here that disagrees with the known hash for the same file + name, the user-provided hash takes precedence. + + Returns + ------- + str + The path to the file if it was found. If the file wasn't found then a FileNotFoundError is raised. + + Raises + ------ + HashComparisonFailedException + FileNotFoundError + """ + pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) cached_path = CACHE_DIR / filename + if file_hash is None and filename in KNOWN_HASHES: + file_hash = KNOWN_HASHES[filename] + if cached_path.exists(): - assert _get_sha256(cached_path) == KNOWN_HASHES[filename] + if file_hash: + assert_hash_equal(cached_path, file_hash) return cached_path.as_posix() @@ -47,25 +101,63 @@ def get_model(filename: str) -> str: release = releases[version] for file in release["assets"]: if file["name"] == filename: - path_to_file, _ = urllib.request.urlretrieve( - url=file["browser_download_url"], - filename=cached_path.as_posix(), - ) - - assert cached_path.exists() - assert path_to_file == cached_path.as_posix() - - assert _get_sha256(cached_path) == KNOWN_HASHES[filename], ( - f"Hash mismatch for {filename}" + return _download_and_verify_file( + file["browser_download_url"], cached_path, file_hash ) - return cached_path.as_posix() + if doi: + try: + match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi) + if 
not match: + raise IndexError + prefix, zenodo_id = match.groups() + except (IndexError, AttributeError): + raise UnableToParseDOIException( + f"Unable to parse Zenodo DOI {doi}. DOI values are expected to look " + f"like '10.5281/zenodo.278300' (production) or '10.5072/zenodo.278300' (sandbox)" + ) + + if prefix == "5072": + file_url = ( + f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" + ) + else: + file_url = f"https://zenodo.org/api/records/{zenodo_id}/files/{filename}" + + try: + return _download_and_verify_file(file_url, cached_path, file_hash) + except urllib.error.HTTPError: + raise FileNotFoundError(f"No file at {file_url}") raise FileNotFoundError( f"Could not find asset with name '{filename}' in any release" ) +def assert_hash_equal(cached_path, expected_hash): + actual_hash = _get_sha256(cached_path) + if actual_hash != expected_hash: + raise HashComparisonFailedException( + f"NAGL model file hash check failed. Expected hash is " + f"{expected_hash} but actual hash is {actual_hash}" + ) + + +def _download_and_verify_file( + url: str, cached_path: pathlib.Path, file_hash: None | str = None +) -> str: + """Download a file from URL to cached_path and optionally verify its hash.""" + path_to_file, _ = urllib.request.urlretrieve(url, filename=cached_path.as_posix()) + + assert cached_path.exists() + assert path_to_file == cached_path.as_posix() + + if file_hash: + assert_hash_equal(cached_path, file_hash) + + return cached_path.as_posix() + + def _get_sha256(filename: str) -> str: """Get the SHA256 hash of a file from its path, assuming it's a binary file like a PyTorch model.""" hash = hashlib.sha256() diff --git a/openff/nagl_models/openff_nagl_models.py b/openff/nagl_models/openff_nagl_models.py index 23f1e68..dd55b84 100644 --- a/openff/nagl_models/openff_nagl_models.py +++ b/openff/nagl_models/openff_nagl_models.py @@ -2,6 +2,7 @@ This module only contains the function that will be the entry point that will be used to find the 
model files. """ + import importlib.resources import os import pathlib @@ -166,7 +167,8 @@ def list_available_nagl_models() -> list[pathlib.Path]: # look for all .pt files in the cache directory, but only those that are # expected to also be found in release assets cached_paths = [ - cached_file for cached_file in CACHE_DIR.rglob("*.pt") + cached_file + for cached_file in CACHE_DIR.rglob("*.pt") if cached_file.name in KNOWN_HASHES ] @@ -205,12 +207,12 @@ def get_models_by_type( -------- Getting the latest pre-release model for am1bcc:: - + >>> from openff.nagl_models.openff_nagl_models import get_models_by_type >>> get_models_by_type(model_type="am1bcc") [PosixPath('/.../openff-nagl-models/openff/nagl_models/models/am1bcc/openff-gnn-am1bcc-0.0.1-alpha.1.pt'), PosixPath('/.../openff-nagl-models/openff/nagl_models/models/am1bcc/openff-gnn-am1bcc-0.1.0-rc.1.pt')] - + """ from packaging.version import Version @@ -221,14 +223,12 @@ def get_models_by_type( "If you are using a custom model, " "please manually specify the path to the model file." 
) - + model_files = pathlib.Path(base_dir).glob("*.pt") - + # assume everything follows the openff-gnn-<model_type>-<version>.pt format n_name = len(f"openff-gnn-{model_type}-") - versions_to_paths = { - Version(f.stem[n_name:]): f for f in model_files - } + versions_to_paths = {Version(f.stem[n_name:]): f for f in model_files} versions = sorted(versions_to_paths.keys()) if production_only: versions = [v for v in versions if not v.is_prerelease] diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index 42ea22e..149ff08 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -1,4 +1,5 @@ import json +import os import pathlib import shutil import urllib.request @@ -9,7 +10,11 @@ import openff.nagl_models._dynamic_fetch from openff.nagl_models import __file__ as root -from openff.nagl_models._dynamic_fetch import get_model +from openff.nagl_models._dynamic_fetch import ( + get_model, + HashComparisonFailedException, + UnableToParseDOIException, +) def mocked_urlretrieve(url, filename): @@ -59,11 +64,27 @@ def test_get_known_models(monkeypatch, known_model): assert "OPENFF_NAGL_MODELS" in get_model(known_model) -def test_access_internet_with_empty_cache(): - cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" +@pytest.fixture +def hide_cache(): + cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" + alt_dir = str(cache_dir) + "_temp" + + if os.path.exists(alt_dir): + raise FileExistsError(f"Temporary directory already exists: {alt_dir}") + + if os.path.exists(cache_dir): + shutil.move(cache_dir, alt_dir) + + yield - if cache_path.exists(): - shutil.rmtree(cache_path) + if os.path.exists(alt_dir): + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + shutil.move(alt_dir, cache_dir) + + +def test_access_internet_with_empty_cache(hide_cache): + cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" disable_socket() @@ -147,3 
+168,53 @@ def test_all_models_loadable(model, monkeypatch): ) GNNModel.load(get_model(model), eval_mode=True) + + +def test_get_model_by_doi_and_hash(hide_cache): + # This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding + # SHA256 hash of the test file uploaded to that sandbox record + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", + ) + + +def test_get_model_by_doi_no_hash(hide_cache): + get_model("my_favorite_model.pt", doi="10.5072/zenodo.278300") + + +def test_get_model_hash_comparison_fails(): + with pytest.raises(HashComparisonFailedException): + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="wrong_hash", + ) + + +def test_user_provided_hash_conflicts_with_known_hash(): + with pytest.raises(HashComparisonFailedException): + get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt", file_hash="wrong_hash") + + +def test_malformed_doi(monkeypatch, hide_cache): + with monkeypatch.context() as m: + m.setattr( + urllib.request, + "urlretrieve", + mocked_urlretrieve, + ) + m.setattr( + openff.nagl_models._dynamic_fetch, + "get_release_metadata", + mocked_get_release_metadata, + ) + + with pytest.raises(UnableToParseDOIException): + get_model("my_favorite_model.pt", doi="zenodo.278300") + + +def test_no_matching_file_at_doi(): + with pytest.raises(FileNotFoundError, match="sandbox.zenodo"): + get_model("file_that_doesnt_exist.pt", doi="10.5072/zenodo.278300")