diff --git a/ami/base/pagination.py b/ami/base/pagination.py index 9ebca7b21..2ec26c1ef 100644 --- a/ami/base/pagination.py +++ b/ami/base/pagination.py @@ -1,17 +1,146 @@ -from rest_framework.pagination import LimitOffsetPagination +from django.core.exceptions import ValidationError +from django.forms import BooleanField +from rest_framework.pagination import LimitOffsetPagination, remove_query_param, replace_query_param +from rest_framework.response import Response from .permissions import add_collection_level_permissions +# Query parameter name used to opt out of the total count in paginated list responses. +# Pass ``?with_counts=false`` to skip COUNT(*) for performance on large tables. +WITH_TOTAL_COUNT_PARAM = "with_counts" + class LimitOffsetPaginationWithPermissions(LimitOffsetPagination): + """ + LimitOffsetPagination that lets callers opt out of the expensive COUNT(*) query. + + Default behavior matches DRF's upstream LimitOffsetPagination: ``count`` is + computed (via a capped COUNT(*), see ``LARGE_QUERYSET_THRESHOLD``) and + returned in the response. Callers that don't need the total can pass + ``?with_counts=false`` to skip the count entirely and receive ``count: null`` + instead. In that mode ``next`` / ``previous`` links are still computed + correctly by fetching one extra row to detect whether a following page exists. + + A follow-up PR will flip the default to ``false`` and teach the UI to + request counts only when needed. Until then the default preserves existing + behavior so no frontend changes are required. + """ + + # Sentinel used internally when COUNT(*) is skipped. + _SKIP_COUNT = object() + + # Maximum rows scanned when with_counts=true is requested. If the filtered + # result set contains at least this many rows the full COUNT(*) is abandoned + # and the response falls back to ``count: null``. + LARGE_QUERYSET_THRESHOLD = 10_000 + + def paginate_queryset(self, queryset, request, view=None): + self.request = request + self.limit = self.get_limit(request) + if self.limit is None: + return None + self.offset = self.get_offset(request) + + if self._should_skip_count(request): + # Fetch one extra item to detect whether a next page exists without + # issuing a COUNT(*) on the full table. + page = list(queryset[self.offset : self.offset + self.limit + 1]) + self._has_next = len(page) > self.limit + self.count = self._SKIP_COUNT # type: ignore[assignment] + return page[: self.limit] + + # with_counts=true path: attempt a capped count so we never run a + # full COUNT(*) against a huge result set. + self.count = self._get_capped_count(queryset) + if self.count is self._SKIP_COUNT: + # Result set exceeds LARGE_QUERYSET_THRESHOLD - fall back to the + # probe-based fast path (count stays null in the response). + page = list(queryset[self.offset : self.offset + self.limit + 1]) + self._has_next = len(page) > self.limit + return page[: self.limit] + + if self.count > self.limit and self.template is not None: + self.display_page_controls = True + if self.count == 0 or self.offset > self.count: + return [] + return list(queryset[self.offset : self.offset + self.limit]) + + def get_next_link(self): + if self.count is self._SKIP_COUNT: + if not self._has_next: + return None + url = self.request.build_absolute_uri() + url = replace_query_param(url, self.limit_query_param, self.limit) + return replace_query_param(url, self.offset_query_param, self.offset + self.limit) + return super().get_next_link() + + def get_previous_link(self): + # Previous link logic does not depend on the total count. + if self.count is self._SKIP_COUNT: + if self.offset <= 0: + return None + url = self.request.build_absolute_uri() + url = replace_query_param(url, self.limit_query_param, self.limit) + offset = max(0, self.offset - self.limit) + if offset == 0: + return remove_query_param(url, self.offset_query_param) + return replace_query_param(url, self.offset_query_param, offset) + return super().get_previous_link() + def get_paginated_response(self, data): model = self._get_current_model() project = self._get_project() - paginated_response = super().get_paginated_response(data=data) - paginated_response.data = add_collection_level_permissions( - user=self.request.user, response_data=paginated_response.data, model=model, project=project + count = None if self.count is self._SKIP_COUNT else self.count + response = Response( + { + "count": count, + "next": self.get_next_link(), + "previous": self.get_previous_link(), + "results": data, + } + ) + response.data = add_collection_level_permissions( + user=self.request.user, response_data=response.data, model=model, project=project ) - return paginated_response + return response + + def get_paginated_response_schema(self, schema): + paginated_schema = super().get_paginated_response_schema(schema) + # count is null when the caller passes with_counts=false, or when a + # with_counts=true request exceeds LARGE_QUERYSET_THRESHOLD. + paginated_schema["properties"]["count"]["nullable"] = True + return paginated_schema + + def _get_capped_count(self, queryset): + """ + Run a bounded COUNT(*) that stops scanning after ``LARGE_QUERYSET_THRESHOLD`` + rows. Returns the exact count when the result set is small, or the + ``_SKIP_COUNT`` sentinel when the threshold is reached so callers can + fall back gracefully. + + Django translates ``queryset[:N].count()`` into:: + + SELECT COUNT(*) FROM (SELECT … LIMIT N) sub + + which is always O(N) regardless of total table size. + """ + # Fetch one extra row beyond the threshold so we can distinguish + # "exactly N rows" (exact count returned) from "more than N rows" + # (sentinel returned to avoid the full scan). + capped = queryset[: self.LARGE_QUERYSET_THRESHOLD + 1].count() + if capped <= self.LARGE_QUERYSET_THRESHOLD: + return capped + return self._SKIP_COUNT + + def _should_skip_count(self, request) -> bool: + """Return True when the caller has explicitly opted out of the total count.""" + raw = request.query_params.get(WITH_TOTAL_COUNT_PARAM, None) + if raw is None: + return False + try: + return not BooleanField(required=False).clean(raw) + except ValidationError: + return False def _get_current_model(self): """ diff --git a/ami/main/tests.py b/ami/main/tests.py index 4bfbdc4de..62685e378 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -3782,3 +3782,109 @@ def test_list_pipelines_public_project_non_member(self): self.client.force_authenticate(user=non_member) response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class TestPaginationWithCounts(APITestCase): + """ + Verify the ``with_counts`` opt-out on list endpoints. + + Default behavior preserves DRF's count field (so existing UI code keeps + working). Callers can pass ``with_counts=false`` to skip the COUNT(*) + query and receive ``count: null``. A capped count (see + ``LARGE_QUERYSET_THRESHOLD``) caps the worst-case scan even on the + default path. + """ + + def setUp(self) -> None: + project, deployment = setup_test_project() + create_captures(deployment=deployment, num_nights=2, images_per_night=5) + self.project = project + self.user = User.objects.create_user( # type: ignore + email="pagination_test@insectai.org", + is_staff=True, + is_superuser=True, + ) + self.client.force_authenticate(user=self.user) + return super().setUp() + + def _captures_url(self, **params): + from urllib.parse import urlencode + + base = f"/api/v2/captures/?project_id={self.project.pk}" + if params: + base += "&" + urlencode(params) + return base + + def test_default_response_includes_integer_count(self): + """By default, count is an integer (preserves existing behavior).""" + response = self.client.get(self._captures_url(limit=5)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsInstance(data["count"], int) + self.assertGreater(data["count"], 0) + + def test_with_counts_true_returns_integer_count(self): + """Explicit with_counts=true also returns an integer count.""" + response = self.client.get(self._captures_url(with_counts="true", limit=5)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsInstance(data["count"], int) + self.assertGreater(data["count"], 0) + + def test_with_counts_false_returns_null_count(self): + """with_counts=false skips COUNT(*) and returns count: null.""" + response = self.client.get(self._captures_url(with_counts="false", limit=5)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIn("count", data) + self.assertIsNone(data["count"]) + self.assertIn("results", data) + + def test_with_counts_false_next_link_present_when_more_results(self): + """next link is returned even without count when more results exist.""" + total = SourceImage.objects.filter(deployment__project=self.project).count() + limit = max(1, total - 1) + response = self.client.get(self._captures_url(with_counts="false", limit=limit)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsNone(data["count"]) + self.assertIsNotNone(data["next"]) + + def test_with_counts_false_next_link_absent_on_last_page(self): + """next is None when the current page is the last page.""" + total = SourceImage.objects.filter(deployment__project=self.project).count() + response = self.client.get(self._captures_url(with_counts="false", limit=total)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsNone(data["count"]) + self.assertIsNone(data["next"]) + + def test_with_counts_false_previous_link_present_with_nonzero_offset(self): + """previous link is returned correctly without count.""" + response = self.client.get(self._captures_url(with_counts="false", limit=2, offset=2)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsNone(data["count"]) + self.assertIsNotNone(data["previous"]) + + def test_count_falls_back_to_null_when_result_set_exceeds_threshold(self): + """ + When the (default) count path is taken but the result set meets or + exceeds LARGE_QUERYSET_THRESHOLD, count is null and next/previous + links still work via the probe-based path. + """ + from unittest.mock import patch + + from ami.base.pagination import LimitOffsetPaginationWithPermissions + + # Patch the threshold to 1 so even a single row triggers the fallback. + with patch.object(LimitOffsetPaginationWithPermissions, "LARGE_QUERYSET_THRESHOLD", 1): + total = SourceImage.objects.filter(deployment__project=self.project).count() + self.assertGreater(total, 1, "Need at least 2 captures for this test") + + response = self.client.get(self._captures_url(limit=1)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsNone(data["count"], "count must be null when threshold is exceeded") + self.assertIsNotNone(data["next"], "next link must still be present") + self.assertIsNone(data["previous"])