diff --git a/src/ewatercycle/_forcings/caravan.py b/src/ewatercycle/_forcings/caravan.py index 8a535f1c..96150497 100644 --- a/src/ewatercycle/_forcings/caravan.py +++ b/src/ewatercycle/_forcings/caravan.py @@ -151,7 +151,7 @@ def generate( # type: ignore[override] end_time: str, directory: str, variables: tuple[str, ...] = (), - shape: str | Path | None = None, + shape_in: str | Path | None = None, **kwargs, ) -> "CaravanForcing": """Retrieve caravan for a model. @@ -164,7 +164,8 @@ def generate( # type: ignore[override] directory: Directory in which forcing should be written. variables: Variables which are needed for model, if not specified will default to all. - shape: (Optional) Path to a shape file. + shape_in: (Optional) Path to a shape file of the basin, or the combined.shp + file of all basins. If none is specified, will be downloaded automatically. kwargs: Additional keyword arguments. basin_id: The ID of the desired basin. Data sets can be explored using @@ -188,8 +189,19 @@ def generate( # type: ignore[override] ds_basin = ds.sel(basin_id=basin_id.encode()) ds_basin_time = crop_ds(ds_basin, start_time, end_time) - if shape is None: + if shape_in is None: shape = get_shapefiles(Path(directory), basin_id) + elif Path(shape_in).name == "combined.shp": + shape = Path(directory) / f"{basin_id}.shp" + extract_basin_shapefile(basin_id, Path(shape_in), shape) + elif Path(shape).name != f"{basin_id}.shp": + msg = ( + "shape must either point to a shapefile of the basin ID" + "Or to the combined.shp file that contains all basins." + ) + raise ValueError(msg) + else: + shape = shape_in if len(variables) == 0: variables = ds_basin_time.data_vars.keys() # type: ignore[assignment] @@ -299,9 +311,6 @@ def extract_basin_shapefile( # kind of clunky but it works: select filtered polygon if i == basin_index: geom = feat.geometry - if geom.type != "Polygon": - msg = "Only polygons are supported" - raise ValueError(msg) # Add the signed area of the polygon and a timestamp # to the feature properties map.