diff --git a/stac_fastapi/api/tests/test_api.py b/stac_fastapi/api/tests/test_api.py index 5f9aa73e3..2441efdf6 100644 --- a/stac_fastapi/api/tests/test_api.py +++ b/stac_fastapi/api/tests/test_api.py @@ -69,12 +69,13 @@ def _assert_dependency_not_applied(api, routes): response = client.request( method=route["method"].lower(), url=path, - content=route["payload"], + content=route.get("payload"), headers={"content-type": "application/json"}, ) + assert ( 200 <= response.status_code < 300 - ), "Authenticated requests should be accepted" + ), "Unauthenticated requests should be accepted" assert response.json() == "dummy response" def test_openapi_content_type(self): @@ -86,6 +87,53 @@ def test_openapi_content_type(self): == "application/vnd.oai.openapi+json;version=3.0" ) + def test_build_api_with_transaction_dependencies(self, collection, item): + settings = config.ApiSettings() + dependencies = [Depends(must_be_bob)] + api = self._build_api( + extensions=[ + TransactionExtension( + client=DummyTransactionsClient(), + settings=settings, + route_dependencies=dependencies, + ) + ] + ) + self._assert_dependency_applied( + api, + [ + {"path": "/collections", "method": "POST", "payload": collection}, + { + "path": "/collections/{collectionId}", + "method": "PUT", + "payload": collection, + }, + { + "path": "/collections/{collectionId}", + "method": "DELETE", + "payload": collection, + }, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "payload": item, + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "payload": item, + }, + ], + ) + self._assert_dependency_not_applied( + api, [{"path": "/collections/{collectionId}", "method": "GET"}] + ) + def test_build_api_with_route_dependencies(self, collection, item): routes = [ {"path": "/collections", "method": "POST", "payload": collection}, @@ -401,22 +449,22 @@ def test_add_default_method_route_dependencies_after_building_api( class DummyCoreClient(core.BaseCoreClient): def all_collections(self, *args, **kwargs): - ... + return "dummy response" def get_collection(self, *args, **kwargs): - ... + return "dummy response" def get_item(self, *args, **kwargs): - ... + return "dummy response" def get_search(self, *args, **kwargs): - ... + return "dummy response" def post_search(self, *args, **kwargs): - ... + return "dummy response" def item_collection(self, *args, **kwargs): - ... + return "dummy response" class DummyTransactionsClient(BaseTransactionsClient): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction/transaction.py index ca61c9e60..3d7f2a353 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction/transaction.py @@ -1,10 +1,11 @@ """Transaction extension.""" from enum import Enum -from typing import List, Optional, Type, Union +from typing import List, Optional, Sequence, Type, Union import attr from fastapi import APIRouter, Body, FastAPI +from fastapi.params import Depends from pydantic import TypeAdapter from stac_pydantic import Collection, Item, ItemCollection from stac_pydantic.shared import MimeTypes @@ -183,6 +184,7 @@ class TransactionExtension(ApiExtension): schema_href: Optional[str] = attr.ib(default=None) router: APIRouter = attr.ib(factory=APIRouter) response_class: Type[Response] = attr.ib(default=JSONResponse) + route_dependencies: Optional[Sequence[Depends]] = attr.ib(default=None) def register_create_item(self): """Register create item endpoint (POST /collections/{collection_id}/items).""" @@ -204,6 +206,7 @@ def register_create_item(self): response_model_exclude_none=True, methods=["POST"], endpoint=create_async_endpoint(self.client.create_item, PostItem), + dependencies=self.route_dependencies, ) def register_update_item(self): @@ -226,6 +229,7 @@ def register_update_item(self): response_model_exclude_none=True, methods=["PUT"], endpoint=create_async_endpoint(self.client.update_item, PutItem), + dependencies=self.route_dependencies, ) def register_patch_item(self): @@ -267,6 +271,7 @@ def register_patch_item(self): self.client.patch_item, PatchItem, ), + dependencies=self.route_dependencies, ) def register_delete_item(self): @@ -289,6 +294,7 @@ def register_delete_item(self): response_model_exclude_none=True, methods=["DELETE"], endpoint=create_async_endpoint(self.client.delete_item, ItemUri), + dependencies=self.route_dependencies, ) def register_create_collection(self): @@ -311,6 +317,7 @@ def register_create_collection(self): response_model_exclude_none=True, methods=["POST"], endpoint=create_async_endpoint(self.client.create_collection, Collection), + dependencies=self.route_dependencies, ) def register_update_collection(self): @@ -332,6 +339,7 @@ def register_update_collection(self): response_model_exclude_none=True, methods=["PUT"], endpoint=create_async_endpoint(self.client.update_collection, PutCollection), + dependencies=self.route_dependencies, ) def register_patch_collection(self): @@ -372,6 +380,7 @@ def register_patch_collection(self): self.client.patch_collection, PatchCollection, ), + dependencies=self.route_dependencies, ) def register_delete_collection(self): @@ -393,6 +402,7 @@ def register_delete_collection(self): response_model_exclude_none=True, methods=["DELETE"], endpoint=create_async_endpoint(self.client.delete_collection, CollectionUri), + dependencies=self.route_dependencies, ) def register(self, app: FastAPI) -> None: diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index aec905dff..035d82752 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -2,10 +2,11 @@ import abc from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union import attr from fastapi import APIRouter, FastAPI +from fastapi.params import Depends from pydantic import BaseModel from stac_fastapi.api.models import create_request_model @@ -113,6 +114,7 @@ class BulkTransactionExtension(ApiExtension): client: Union[AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient] = attr.ib() conformance_classes: List[str] = attr.ib(default=list()) schema_href: Optional[str] = attr.ib(default=None) + route_dependencies: Optional[Sequence[Depends]] = attr.ib(default=None) def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. @@ -136,5 +138,6 @@ def register(self, app: FastAPI) -> None: endpoint=create_async_endpoint( self.client.bulk_item_insert, items_request_model ), + dependencies=self.route_dependencies, ) app.include_router(router, tags=["Bulk Transaction Extension"])