diff --git a/test/test_util.py b/test/test_util.py index b50fedc..bd3f341 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -85,6 +85,29 @@ def test_write_stac(tmp_path, dataset, write_datasets, pre_existing_catalog): } +orig_import = __import__ +def import_mock(name, *args): + if name == 'pystac': + raise ModuleNotFoundError("No module named 'pystac'") + return orig_import(name, *args) + + +def test_write_stac_no_pystac(tmp_path, dataset): + # Import hooks are the recommended "clean" way to do this, but don't work + # in this case. + with mock.patch('builtins.__import__', side_effect=import_mock): + # pytest imports xcengine.util long before we can patch __import__, + # so we delete pystac from util's namespace (if present) instead. + # This gives a NameError on access rather than a ModuleNotFoundError + # on import, but the important thing is to break any implementation + # that tries to import pystac without checking if it's available. + import xcengine.util + xcengine.util.__dict__.pop("pystac", None) + from xcengine.util import write_stac + write_stac({"ds1": dataset}, tmp_path) + # We want nothing to happen here, so no explicit assertions. + + @pytest.mark.parametrize("eoap_mode", [False, True]) @pytest.mark.parametrize("ds2_format", [None, "zarr", "netcdf"]) def test_save_datasets(tmp_path, dataset, eoap_mode, ds2_format): diff --git a/xcengine/util.py b/xcengine/util.py index 71dd2b4..8c528c6 100644 --- a/xcengine/util.py +++ b/xcengine/util.py @@ -8,11 +8,9 @@ import shutil from typing import NamedTuple, Mapping -import pystac import xarray as xr from xarray import Dataset - def clear_directory(directory: pathlib.Path) -> None: for path in directory.iterdir(): if path.is_dir(): @@ -24,6 +22,12 @@ def clear_directory(directory: pathlib.Path) -> None: def write_stac( datasets: Mapping[str, xr.Dataset], stac_root: pathlib.Path ) -> None: + try: + import pystac + except ModuleNotFoundError: + # If pystac isn't present, we assume that stage-out is not required + # and exit quietly. + return catalog_path = stac_root / "catalog.json" if catalog_path.exists(): # Assume that the user code generated its own stage-out data