Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions stac_fastapi/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _build_api(**overrides):
def _assert_dependency_applied(api, routes):
with TestClient(api.app) as client:
for route in routes:
print(route)
response = getattr(client, route["method"].lower())(route["path"])
assert (
response.status_code == 401
Expand Down Expand Up @@ -69,12 +70,13 @@ def _assert_dependency_not_applied(api, routes):
response = client.request(
method=route["method"].lower(),
url=path,
content=route["payload"],
content=route.get("payload"),
headers={"content-type": "application/json"},
)

assert (
200 <= response.status_code < 300
), "Authenticated requests should be accepted"
), "Unauthenticated requests should be accepted"
assert response.json() == "dummy response"

def test_openapi_content_type(self):
Expand All @@ -86,6 +88,53 @@ def test_openapi_content_type(self):
== "application/vnd.oai.openapi+json;version=3.0"
)

def test_build_api_with_transaction_dependencies(self, collection, item):
settings = config.ApiSettings()
dependencies = [Depends(must_be_bob)]
api = self._build_api(
extensions=[
TransactionExtension(
client=DummyTransactionsClient(),
settings=settings,
route_dependencies=dependencies,
)
]
)
self._assert_dependency_applied(
api,
[
{"path": "/collections", "method": "POST", "payload": collection},
{
"path": "/collections/{collectionId}",
"method": "PUT",
"payload": collection,
},
{
"path": "/collections/{collectionId}",
"method": "DELETE",
"payload": collection,
},
{
"path": "/collections/{collectionId}/items",
"method": "POST",
"payload": item,
},
{
"path": "/collections/{collectionId}/items/{itemId}",
"method": "PUT",
"payload": item,
},
{
"path": "/collections/{collectionId}/items/{itemId}",
"method": "DELETE",
"payload": item,
},
],
)
self._assert_dependency_not_applied(
api, [{"path": "/collections/{collectionId}", "method": "GET"}]
)

def test_build_api_with_route_dependencies(self, collection, item):
routes = [
{"path": "/collections", "method": "POST", "payload": collection},
Expand Down Expand Up @@ -401,22 +450,22 @@ def test_add_default_method_route_dependencies_after_building_api(

class DummyCoreClient(core.BaseCoreClient):
def all_collections(self, *args, **kwargs):
...
return "dummy response"

def get_collection(self, *args, **kwargs):
...
return "dummy response"

def get_item(self, *args, **kwargs):
...
return "dummy response"

def get_search(self, *args, **kwargs):
...
return "dummy response"

def post_search(self, *args, **kwargs):
...
return "dummy response"

def item_collection(self, *args, **kwargs):
...
return "dummy response"


class DummyTransactionsClient(BaseTransactionsClient):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Transaction extension."""

from enum import Enum
from typing import List, Optional, Type, Union
from typing import List, Optional, Sequence, Type, Union

import attr
from fastapi import APIRouter, Body, FastAPI
from fastapi.params import Depends
from pydantic import TypeAdapter
from stac_pydantic import Collection, Item, ItemCollection
from stac_pydantic.shared import MimeTypes
Expand Down Expand Up @@ -183,6 +184,7 @@ class TransactionExtension(ApiExtension):
schema_href: Optional[str] = attr.ib(default=None)
router: APIRouter = attr.ib(factory=APIRouter)
response_class: Type[Response] = attr.ib(default=JSONResponse)
route_dependencies: Optional[Sequence[Depends]] = attr.ib(default=None)

def register_create_item(self):
"""Register create item endpoint (POST /collections/{collection_id}/items)."""
Expand All @@ -204,6 +206,7 @@ def register_create_item(self):
response_model_exclude_none=True,
methods=["POST"],
endpoint=create_async_endpoint(self.client.create_item, PostItem),
dependencies=self.route_dependencies,
)

def register_update_item(self):
Expand All @@ -226,6 +229,7 @@ def register_update_item(self):
response_model_exclude_none=True,
methods=["PUT"],
endpoint=create_async_endpoint(self.client.update_item, PutItem),
dependencies=self.route_dependencies,
)

def register_patch_item(self):
Expand Down Expand Up @@ -267,6 +271,7 @@ def register_patch_item(self):
self.client.patch_item,
PatchItem,
),
dependencies=self.route_dependencies,
)

def register_delete_item(self):
Expand All @@ -289,6 +294,7 @@ def register_delete_item(self):
response_model_exclude_none=True,
methods=["DELETE"],
endpoint=create_async_endpoint(self.client.delete_item, ItemUri),
dependencies=self.route_dependencies,
)

def register_create_collection(self):
Expand All @@ -311,6 +317,7 @@ def register_create_collection(self):
response_model_exclude_none=True,
methods=["POST"],
endpoint=create_async_endpoint(self.client.create_collection, Collection),
dependencies=self.route_dependencies,
)

def register_update_collection(self):
Expand All @@ -332,6 +339,7 @@ def register_update_collection(self):
response_model_exclude_none=True,
methods=["PUT"],
endpoint=create_async_endpoint(self.client.update_collection, PutCollection),
dependencies=self.route_dependencies,
)

def register_patch_collection(self):
Expand Down Expand Up @@ -372,6 +380,7 @@ def register_patch_collection(self):
self.client.patch_collection,
PatchCollection,
),
dependencies=self.route_dependencies,
)

def register_delete_collection(self):
Expand All @@ -393,6 +402,7 @@ def register_delete_collection(self):
response_model_exclude_none=True,
methods=["DELETE"],
endpoint=create_async_endpoint(self.client.delete_collection, CollectionUri),
dependencies=self.route_dependencies,
)

def register(self, app: FastAPI) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import abc
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union

import attr
from fastapi import APIRouter, FastAPI
from fastapi.params import Depends
from pydantic import BaseModel

from stac_fastapi.api.models import create_request_model
Expand Down Expand Up @@ -113,6 +114,7 @@ class BulkTransactionExtension(ApiExtension):
client: Union[AsyncBaseBulkTransactionsClient, BaseBulkTransactionsClient] = attr.ib()
conformance_classes: List[str] = attr.ib(default=list())
schema_href: Optional[str] = attr.ib(default=None)
route_dependencies: Optional[Sequence[Depends]] = attr.ib(default=None)

def register(self, app: FastAPI) -> None:
"""Register the extension with a FastAPI application.
Expand All @@ -136,5 +138,6 @@ def register(self, app: FastAPI) -> None:
endpoint=create_async_endpoint(
self.client.bulk_item_insert, items_request_model
),
dependencies=self.route_dependencies,
)
app.include_router(router, tags=["Bulk Transaction Extension"])
Loading