diff --git a/ami/main/api/views.py b/ami/main/api/views.py
index 75951535a..00325bc4b 100644
--- a/ami/main/api/views.py
+++ b/ami/main/api/views.py
@@ -1501,7 +1501,11 @@ def get_queryset(self) -> QuerySet:
         qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)
         if self.action == "retrieve":
             qs = self.get_taxa_observed(
-                qs, project, include_unobserved=include_unobserved, apply_default_filters=False
+                qs,
+                project,
+                include_unobserved=include_unobserved,
+                apply_default_score_filter=True,
+                apply_default_taxa_filter=False,
             )
             qs = qs.prefetch_related(
                 Prefetch(
@@ -1519,7 +1523,12 @@ def get_queryset(self) -> QuerySet:
         return qs
 
     def get_taxa_observed(
-        self, qs: QuerySet, project: Project, include_unobserved=False, apply_default_filters=True
+        self,
+        qs: QuerySet,
+        project: Project,
+        include_unobserved=False,
+        apply_default_score_filter=True,
+        apply_default_taxa_filter=True,
     ) -> QuerySet:
         """
         If a project is passed, only return taxa that have been observed.
@@ -1537,15 +1546,21 @@ def get_taxa_observed(
         # Respects apply_defaults flag: build_occurrence_default_filters_q checks it internally
         from ami.main.models_future.filters import build_occurrence_default_filters_q
 
-        default_filters_q = build_occurrence_default_filters_q(project, self.request, occurrence_accessor="")
+        default_filters_q = build_occurrence_default_filters_q(
+            project,
+            self.request,
+            occurrence_accessor="",
+            apply_default_score_filter=apply_default_score_filter,
+            apply_default_taxa_filter=apply_default_taxa_filter,
+        )
 
         # Combine base occurrence filters with default filters
         base_filter = models.Q(
             occurrence_filters,
             determination_id=models.OuterRef("id"),
         )
-        if apply_default_filters:
-            base_filter = base_filter & default_filters_q
+
+        base_filter = base_filter & default_filters_q
 
         # Count occurrences - uses composite index (determination_id, project_id, event_id, determination_score)
         occurrences_count_subquery = models.Subquery(
diff --git a/ami/main/models.py b/ami/main/models.py
index 2de69f815..17eaa650e 100644
--- a/ami/main/models.py
+++ b/ami/main/models.py
@@ -152,16 +152,41 @@ def get_or_create_default_collection(project: "Project") -> "SourceImageCollecti
     return collection
 
 
+def get_project_default_filters():
+    """
+    Read default taxa names from Django settings (populated from environment variables)
+    and return the corresponding Taxon objects.
+    """
+    include_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_INCLUDE_TAXA))
+    exclude_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_EXCLUDE_TAXA))
+
+    return {"default_include_taxa": include_taxa, "default_exclude_taxa": exclude_taxa}
+
+
 def get_or_create_default_project(user: User) -> "Project":
     """
     Create a default project for a user.
 
-    Default related objects like devices and research sites will be created
-    when the project is saved for the first time.
 
-    If the project already exists, it will be returned without modification.
+    When a new project is created, default related objects (device, site,
+    deployment, collection, processing service) and default taxa filters are
+    initialized explicitly. ``get_or_create`` bypasses ``ProjectManager.create``,
+    so we call ``create_related_defaults`` here instead of relying on the manager.
""" - project, _created = Project.objects.get_or_create(name="Scratch Project", owner=user, create_defaults=True) - logger.info(f"Created default project for user {user}") + project, created = Project.objects.get_or_create(name="Scratch Project", owner=user) + if created: + logger.info(f"Created default project for user {user}") + Project.objects.create_related_defaults(project) + defaults = get_project_default_filters() + + if defaults["default_include_taxa"]: + project.default_filters_include_taxa.set(defaults["default_include_taxa"]) + logger.info(f"Set {len(defaults['default_include_taxa'])} default include taxa for project {project}") + if defaults["default_exclude_taxa"]: + project.default_filters_exclude_taxa.set(defaults["default_exclude_taxa"]) + logger.info(f"Set {len(defaults['default_exclude_taxa'])} default exclude taxa for project {project}") + project.save() + else: + logger.info(f"Loaded existing default project for user {user}") return project @@ -317,7 +342,7 @@ def summary_data(self): def update_related_calculated_fields(self): """ - Update calculated fields for all related events and deployments. + Update calculated fields for all related events, deployments, and source images. """ # Update events for event in self.events.all(): @@ -327,6 +352,10 @@ def update_related_calculated_fields(self): for deployment in self.deployments.all(): deployment.update_calculated_fields(save=True) + # Update source image cached detection counts using the project's default filters + # so SourceImage.detections_count stays consistent with get_detections_count(). + update_detection_counts(qs=SourceImage.objects.filter(project=self), project=self) + def save(self, *args, **kwargs): super().save(*args, **kwargs) # Add owner to members @@ -767,6 +796,23 @@ def get_first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.dat ) return (first, last) + def get_detections_count(self) -> int | None: + """ + Return detections count filtered by project default filters. + + Excludes null-bbox placeholder detections (records indicating an image + was processed and no detections were found) to stay consistent with + ``SourceImage.get_detections_count`` and ``Event.get_detections_count``. + """ + qs = Detection.objects.filter(source_image__deployment=self).exclude(NULL_DETECTIONS_FILTER) + filter_q = build_occurrence_default_filters_q( + project=self.project, + request=None, + occurrence_accessor="occurrence", + ) + + return qs.filter(filter_q).distinct().count() + def first_date(self) -> datetime.date | None: return self.first_capture_timestamp.date() if self.first_capture_timestamp else None @@ -999,7 +1045,7 @@ def update_calculated_fields(self, save=False): self.events_count = self.events.count() self.captures_count = self.data_source_total_files or self.captures.count() - self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count() + self.detections_count = self.get_detections_count() occ_qs = self.occurrences.filter(event__isnull=False).apply_default_filters( # type: ignore project=self.project, request=None, @@ -1164,7 +1210,20 @@ def get_captures_count(self) -> int: return self.captures.distinct().count() def get_detections_count(self) -> int | None: - return Detection.objects.filter(Q(source_image__event=self)).count() + """ + Return detections count filtered by project default filters. + + Excludes null-bbox placeholder detections to stay consistent with + ``SourceImage.get_detections_count`` and ``Deployment.get_detections_count``. 
+        """
+        qs = Detection.objects.filter(source_image__event=self).exclude(NULL_DETECTIONS_FILTER)
+        filter_q = build_occurrence_default_filters_q(
+            project=self.project,
+            request=None,
+            occurrence_accessor="occurrence",
+        )
+
+        return qs.filter(filter_q).distinct().count()
 
     def get_occurrences_count(self, classification_threshold: float = 0) -> int:
         """
@@ -1889,9 +1948,23 @@ def size_display(self) -> str:
         return filesizeformat(self.size)
 
     def get_detections_count(self) -> int:
-        # Detections count excludes detections without bounding boxes
-        # Detections with null bounding boxes are valid and indicates the image was successfully processed
-        return self.detections.exclude(NULL_DETECTIONS_FILTER).count()
+        """
+        Return detections count filtered by project default filters.
+
+        Excludes detections without bounding boxes — those are placeholder records
+        indicating the image was successfully processed and no detections were found.
+        """
+        qs = self.detections.exclude(NULL_DETECTIONS_FILTER)
+        project = self.project
+        if not project:
+            return qs.distinct().count()
+
+        q = build_occurrence_default_filters_q(
+            project=project,
+            request=None,
+            occurrence_accessor="occurrence",
+        )
+        return qs.filter(q).distinct().count()
 
     def get_was_processed(self, algorithm_key: str | None = None) -> bool:
         """
@@ -2069,22 +2142,34 @@ class Meta:
         ]
 
 
-def update_detection_counts(qs: models.QuerySet[SourceImage] | None = None, null_only=False) -> int:
+def update_detection_counts(
+    qs: models.QuerySet[SourceImage] | None = None,
+    null_only=False,
+    project: "Project | None" = None,
+) -> int:
     """
     Update the detection count for all source images using a bulk update query.
 
+    When ``project`` is provided, the count is filtered by that project's default
+    filters so the cached ``SourceImage.detections_count`` stays consistent with
+    ``SourceImage.get_detections_count()``.
+
     @TODO Needs testing.
     """
     qs = qs or SourceImage.objects.all()
     if null_only:
         qs = qs.filter(detections_count__isnull=True)
+    detection_qs = Detection.objects.filter(source_image_id=models.OuterRef("pk")).exclude(NULL_DETECTIONS_FILTER)
+    if project is not None:
+        filter_q = build_occurrence_default_filters_q(
+            project=project,
+            request=None,
+            occurrence_accessor="occurrence",
+        )
+        detection_qs = detection_qs.filter(filter_q)
 
     subquery = models.Subquery(
-        Detection.objects.filter(source_image_id=models.OuterRef("pk"))
-        .exclude(NULL_DETECTIONS_FILTER)
-        .values("source_image_id")
-        .annotate(count=models.Count("id"))
-        .values("count"),
+        detection_qs.values("source_image_id").annotate(count=models.Count("id")).values("count"),
         output_field=models.IntegerField(),
     )
     start_time = time.time()
diff --git a/ami/main/models_future/filters.py b/ami/main/models_future/filters.py
index 6689065c2..8b8782dce 100644
--- a/ami/main/models_future/filters.py
+++ b/ami/main/models_future/filters.py
@@ -129,6 +129,8 @@ def build_occurrence_default_filters_q(
     project: "Project | None" = None,
     request: "Request | None" = None,
     occurrence_accessor: str = "",
+    apply_default_score_filter: bool = True,
+    apply_default_taxa_filter: bool = True,
 ) -> Q:
     """
     Build a Q filter that applies default filters (score threshold + taxa) for Occurrence relationships.
@@ -194,19 +196,19 @@ def build_occurrence_default_filters_q(
         return Q()
 
     filter_q = Q()
-
-    # Build score threshold filter
-    score_threshold = get_default_classification_threshold(project, request)
-    filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)
-
-    # Build taxa inclusion/exclusion filter
-    # For taxa filtering, we need to append "__determination" to the occurrence accessor
-    prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
-    taxon_accessor = f"{prefix}determination"
-    include_taxa = project.default_filters_include_taxa.all()
-    exclude_taxa = project.default_filters_exclude_taxa.all()
-    taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
-    if taxa_q:
-        filter_q &= taxa_q
+    if apply_default_score_filter:
+        # Build score threshold filter
+        score_threshold = get_default_classification_threshold(project, request)
+        filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)
+    if apply_default_taxa_filter:
+        # Build taxa inclusion/exclusion filter
+        # For taxa filtering, we need to append "__determination" to the occurrence accessor
+        prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
+        taxon_accessor = f"{prefix}determination"
+        include_taxa = project.default_filters_include_taxa.all()
+        exclude_taxa = project.default_filters_exclude_taxa.all()
+        taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
+        if taxa_q:
+            filter_q &= taxa_q
 
     return filter_q
diff --git a/ami/main/signals.py b/ami/main/signals.py
index a81ee13b0..e36e41937 100644
--- a/ami/main/signals.py
+++ b/ami/main/signals.py
@@ -1,13 +1,16 @@
 import logging
 
 from django.contrib.auth.models import Group
+from django.db import transaction
 from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
 from django.dispatch import receiver
 from guardian.shortcuts import assign_perm
 
+from ami.main.models import Project
+from ami.main.tasks import refresh_project_cached_counts
 from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project
 
-from .models import Project, User
+from .models import User
 
 logger = logging.getLogger(__name__)
 
@@ -110,3 +113,87 @@ def delete_project_groups(sender, instance, **kwargs):
     prefix = f"{instance.pk}_"
     # Find and delete all groups that start with {project_id}_
     Group.objects.filter(name__startswith=prefix).delete()
+
+
+# ============================================================================
+# Project Default Filters Update Signals
+# ============================================================================
+# These signals handle efficient updates to calculated fields for project-related
+# objects (such as Deployments and Events) whenever a project's default filter
+# values change.
+#
+# Specifically, they trigger recalculation of cached counts when:
+# - The project's default score threshold is updated
+# - The project's default include taxa are modified
+# - The project's default exclude taxa are modified
+#
+# This ensures that cached counts (e.g., occurrences_count, taxa_count) remain
+# accurate and consistent with the active filter configuration for each project.
+# ============================================================================
+
+
+def refresh_cached_counts_for_project(project: Project):
+    """
+    Enqueue a Celery task to refresh cached counts for a project's Deployments
+    and Events after the surrounding transaction commits.
+
+    This fan-out can iterate hundreds of events and dozens of deployments, so
+    running it inline in the request/save path would block the caller. The
+    ``transaction.on_commit`` wrapper guarantees the task only runs if the
+    triggering save succeeds.
+    """
+    logger.info(f"Scheduling cached-count refresh for project {project.pk} ({project.name})")
+    transaction.on_commit(lambda: refresh_project_cached_counts.delay(project.pk))
+
+
+@receiver(pre_save, sender=Project)
+def cache_old_threshold(sender, instance, **kwargs):
+    """
+    Cache the previous default score threshold before saving the Project.
+
+    We do this because:
+    - In post_save, the instance already contains the NEW value.
+    - To detect whether the threshold actually changed, we must read the OLD
+      value from the database before the update happens.
+    - This allows us to accurately detect threshold changes and then trigger
+      recalculation of cached filtered counts (Events, Deployments, etc.).
+
+    The cached value is stored on the instance as `_old_threshold` so it can be
+    safely accessed in the post_save handler.
+    """
+    if instance.pk:
+        instance._old_threshold = Project.objects.get(pk=instance.pk).default_filters_score_threshold
+    else:
+        instance._old_threshold = None
+
+
+@receiver(post_save, sender=Project)
+def threshold_updated(sender, instance, **kwargs):
+    """
+    After saving the Project, compare the previously cached threshold with the new value.
+    If the default score threshold changed, we refresh all cached counts using the new filters.
+
+    This two-step (pre_save + post_save) pattern is required because:
+    - post_save instances already contain the updated value
+    - the old threshold would therefore be lost without caching it in pre_save
+    """
+    old_threshold = instance._old_threshold
+    new_threshold = instance.default_filters_score_threshold
+    if old_threshold is not None and old_threshold != new_threshold:
+        refresh_cached_counts_for_project(instance)
+
+
+@receiver(m2m_changed, sender=Project.default_filters_include_taxa.through)
+def include_taxa_updated(sender, instance: Project, action, **kwargs):
+    """Refresh cached counts when include taxa are modified."""
+    if action in ["post_add", "post_remove", "post_clear"]:
+        logger.info(f"Include taxa updated for project {instance.pk} (action={action})")
+        refresh_cached_counts_for_project(instance)
+
+
+@receiver(m2m_changed, sender=Project.default_filters_exclude_taxa.through)
+def exclude_taxa_updated(sender, instance: Project, action, **kwargs):
+    """Refresh cached counts when exclude taxa are modified."""
+    if action in ["post_add", "post_remove", "post_clear"]:
+        logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})")
+        refresh_cached_counts_for_project(instance)
diff --git a/ami/main/tasks.py b/ami/main/tasks.py
new file mode 100644
index 000000000..16f927a3f
--- /dev/null
+++ b/ami/main/tasks.py
@@ -0,0 +1,25 @@
+import logging
+
+from config import celery_app
+
+logger = logging.getLogger(__name__)
+
+
+@celery_app.task(ignore_result=True)
+def refresh_project_cached_counts(project_id: int) -> None:
+    """Refresh cached counts for all Events and Deployments in a project.
+
+    Dispatched from signals on ``Project.default_filters_*`` changes. The work
+    fans out to every Event and Deployment in the project, so it must not run
+    inline in the request/save path.
+    """
+    from ami.main.models import Project
+
+    try:
+        project = Project.objects.get(pk=project_id)
+    except Project.DoesNotExist:
+        logger.warning(f"Project {project_id} not found; skipping cached-count refresh")
+        return
+
+    logger.info(f"Refreshing cached counts for project {project.pk} ({project.name})")
+    project.update_related_calculated_fields()
diff --git a/ami/main/tests.py b/ami/main/tests.py
index d105b6277..3ca75d6ba 100644
--- a/ami/main/tests.py
+++ b/ami/main/tests.py
@@ -3977,3 +3977,151 @@ def test_list_pipelines_public_project_non_member(self):
         self.client.force_authenticate(user=non_member)
         response = self.client.get(url)
         self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class TestCachedCountsDefaultFilters(APITestCase):
+    """Tests for PR #1045: default-filter-aware cached counts.
+
+    Covers three behaviors introduced by the PR:
+    1. Taxa list and detail views return the same ``occurrences_count`` when
+       only the score threshold is applied (no include/exclude filters).
+    2. ``get_detections_count`` is consistent across the
+       SourceImage/Event/Deployment hierarchy and ignores null-bbox
+       placeholder detections.
+    3. The project default-filter signal only fans out a refresh task when
+       the threshold actually changes.
+    """
+
+    def setUp(self) -> None:
+        self.project, self.deployment = setup_test_project(reuse=False)
+        create_taxa(project=self.project)
+        create_captures(deployment=self.deployment, num_nights=2, images_per_night=3)
+        create_occurrences(deployment=self.deployment, num=6, determination_score=0.9)
+        return super().setUp()
+
+    def test_taxa_list_and_detail_occurrences_count_parity(self):
+        """List and detail views must agree on occurrences_count under the score threshold.
+
+        Regression for the hazard identified in PR #1045: the list view applies
+        both score and taxa filters, while the detail view bypasses the taxa
+        filter. With no include/exclude taxa configured, both endpoints must
+        still return identical counts for the same taxon.
+        """
+        self.project.default_filters_score_threshold = 0.5
+        self.project.save()
+
+        list_response = self.client.get(f"/api/v2/taxa/?project_id={self.project.pk}")
+        self.assertEqual(list_response.status_code, 200)
+
+        results = list_response.json()["results"]
+        self.assertGreater(len(results), 0, "Expected at least one observed taxon")
+
+        # Pick a taxon that actually has occurrences so the parity check is meaningful
+        taxa_with_occs = [r for r in results if r["occurrences_count"] > 0]
+        self.assertGreater(len(taxa_with_occs), 0, "Expected at least one taxon with occurrences")
+
+        for taxon_from_list in taxa_with_occs:
+            detail_response = self.client.get(f"/api/v2/taxa/{taxon_from_list['id']}/?project_id={self.project.pk}")
+            self.assertEqual(detail_response.status_code, 200)
+            detail_count = detail_response.json()["occurrences_count"]
+            self.assertEqual(
+                detail_count,
+                taxon_from_list["occurrences_count"],
+                f"Taxon {taxon_from_list['id']} has count {taxon_from_list['occurrences_count']} in list "
+                f"but {detail_count} in detail",
+            )
+
+    def test_detection_count_hierarchy_consistency(self):
+        """Deployment/Event/SourceImage detection counts must agree and skip null-bbox rows.
+
+        Regression for the fix in this PR: previously
+        ``SourceImage.get_detections_count`` excluded ``NULL_DETECTIONS_FILTER``
+        while Deployment and Event did not, causing the hierarchy to return
+        divergent numbers.
+        """
+        # Seed a null-bbox placeholder (a successful "no detections" marker) on one capture
+        first_capture = self.deployment.captures.first()
+        assert first_capture is not None
+        Detection.objects.create(
+            source_image=first_capture,
+            timestamp=first_capture.timestamp,
+            bbox=None,
+        )
+
+        images_total = sum(img.get_detections_count() or 0 for img in self.deployment.captures.all())
+        events_total = sum(event.get_detections_count() or 0 for event in self.deployment.events.all())
+        deployment_total = self.deployment.get_detections_count() or 0
+
+        self.assertEqual(images_total, events_total)
+        self.assertEqual(events_total, deployment_total)
+        self.assertGreater(deployment_total, 0, "Expected real detections to survive the null-bbox filter")
+
+        # And crucially the null-bbox placeholder must NOT be counted anywhere
+        raw_detection_count = Detection.objects.filter(source_image__deployment=self.deployment).count()
+        self.assertEqual(
+            deployment_total,
+            raw_detection_count - 1,
+            "Deployment detection count should exclude the 1 null-bbox placeholder",
+        )
+
+    def test_refresh_signal_only_fires_on_threshold_change(self):
+        """Saving a Project without changing the threshold must not enqueue a refresh."""
+        from unittest.mock import patch
+
+        with patch("ami.main.signals.refresh_project_cached_counts.delay") as mock_delay:
+            # Save without changing threshold — should NOT enqueue
+            with self.captureOnCommitCallbacks(execute=True):
+                self.project.description = "unrelated edit"
+                self.project.save()
+            self.assertEqual(
+                mock_delay.call_count,
+                0,
+                "Saving a Project without a threshold change should not enqueue a refresh task",
+            )
+
+            # Save with a changed threshold — should enqueue exactly once
+            with self.captureOnCommitCallbacks(execute=True):
+                self.project.default_filters_score_threshold = 0.77
+                self.project.save()
+            self.assertEqual(
+                mock_delay.call_count,
+                1,
+                "Changing the threshold should enqueue exactly one refresh task",
+            )
+            mock_delay.assert_called_with(self.project.pk)
+
+    def test_source_image_cached_counts_refresh_on_threshold_change(self):
+        """SourceImage.detections_count cache must track filter-aware get_detections_count().
+
+        Regression for CodeRabbit review on PR #1045: Project.update_related_calculated_fields
+        previously updated Event and Deployment cached counts but left
+        SourceImage.detections_count unchanged, so the cached field would diverge
+        from the filter-aware getter after a default-filters change.
+        """
+        self.project.default_filters_score_threshold = 0.5
+        self.project.save()
+        self.project.update_related_calculated_fields()
+
+        for image in self.deployment.captures.all():
+            image.refresh_from_db()
+            self.assertEqual(
+                image.detections_count,
+                image.get_detections_count(),
+                f"SourceImage {image.pk} cache ({image.detections_count}) differs from "
+                f"filter-aware count ({image.get_detections_count()}) at threshold 0.5",
+            )
+
+        # Raise the threshold above the seeded determination_score of 0.9 — every
+        # occurrence-linked detection should now drop out of the filtered count.
+        self.project.default_filters_score_threshold = 0.95
+        self.project.save()
+        self.project.update_related_calculated_fields()
+
+        for image in self.deployment.captures.all():
+            image.refresh_from_db()
+            self.assertEqual(
+                image.detections_count,
+                image.get_detections_count(),
+                f"SourceImage {image.pk} cache stale after raising threshold: "
+                f"cache={image.detections_count}, fresh={image.get_detections_count()}",
+            )
diff --git a/config/settings/base.py b/config/settings/base.py
index e3ca14047..be6f3fada 100644
--- a/config/settings/base.py
+++ b/config/settings/base.py
@@ -548,3 +548,6 @@ def _celery_result_backend_url(redis_url):
     "DEFAULT_PROCESSING_SERVICE_ENDPOINT", default=None  # type: ignore[no-untyped-call]
 )
 DEFAULT_PIPELINES_ENABLED = env.list("DEFAULT_PIPELINES_ENABLED", default=None)  # type: ignore[no-untyped-call]
+# Default taxa filters
+DEFAULT_INCLUDE_TAXA = env.list("DEFAULT_INCLUDE_TAXA", default=[])  # type: ignore[no-untyped-call]
+DEFAULT_EXCLUDE_TAXA = env.list("DEFAULT_EXCLUDE_TAXA", default=[])  # type: ignore[no-untyped-call]
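Reviewer note (a sketch, not part of the patch): how the new pieces above are expected to be wired together. The taxon names and the ``project``/``request`` variables below are hypothetical; ``env.list()`` from django-environ splits comma-separated strings, so the new settings would typically be set as environment variables such as ``DEFAULT_INCLUDE_TAXA=Lepidoptera`` and ``DEFAULT_EXCLUDE_TAXA=Formicidae,Vespidae``.

    # Hypothetical usage sketch; assumes `project` and `request` are already in scope.
    from ami.main.models import get_project_default_filters
    from ami.main.models_future.filters import build_occurrence_default_filters_q

    # Resolve the env-configured taxa names into Taxon rows (names must already exist in the DB).
    defaults = get_project_default_filters()

    # Score threshold only, matching the taxa "retrieve" branch in views.py above;
    # list views leave both flags at their defaults (score + include/exclude taxa).
    score_only_q = build_occurrence_default_filters_q(
        project=project,
        request=request,
        occurrence_accessor="",
        apply_default_score_filter=True,
        apply_default_taxa_filter=False,
    )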