Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,11 @@ def get_queryset(self) -> QuerySet:
qs = self.get_taxa_observed(qs, project, include_unobserved=include_unobserved)
if self.action == "retrieve":
qs = self.get_taxa_observed(
qs, project, include_unobserved=include_unobserved, apply_default_filters=False
qs,
project,
include_unobserved=include_unobserved,
apply_default_score_filter=True,
apply_default_taxa_filter=False,
)
qs = qs.prefetch_related(
Prefetch(
Expand All @@ -1493,7 +1497,12 @@ def get_queryset(self) -> QuerySet:
return qs

def get_taxa_observed(
self, qs: QuerySet, project: Project, include_unobserved=False, apply_default_filters=True
self,
qs: QuerySet,
project: Project,
include_unobserved=False,
apply_default_score_filter=True,
apply_default_taxa_filter=True,
) -> QuerySet:
"""
If a project is passed, only return taxa that have been observed.
Expand All @@ -1511,15 +1520,21 @@ def get_taxa_observed(
# Respects apply_defaults flag: build_occurrence_default_filters_q checks it internally
from ami.main.models_future.filters import build_occurrence_default_filters_q

default_filters_q = build_occurrence_default_filters_q(project, self.request, occurrence_accessor="")
default_filters_q = build_occurrence_default_filters_q(
project,
self.request,
occurrence_accessor="",
apply_default_score_filter=apply_default_score_filter,
apply_default_taxa_filter=apply_default_taxa_filter,
)

# Combine base occurrence filters with default filters
base_filter = models.Q(
occurrence_filters,
determination_id=models.OuterRef("id"),
)
if apply_default_filters:
base_filter = base_filter & default_filters_q

base_filter = base_filter & default_filters_q

# Count occurrences - uses composite index (determination_id, project_id, event_id, determination_score)
occurrences_count_subquery = models.Subquery(
Expand Down
66 changes: 61 additions & 5 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def get_or_create_default_collection(project: "Project") -> "SourceImageCollecti
return collection


def get_project_default_filters():
"""
Read default taxa names from Django settings (read from environment variables)
and return corresponding Taxon objects.
"""
include_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_INCLUDE_TAXA))
exclude_taxa = list(Taxon.objects.filter(name__in=settings.DEFAULT_EXCLUDE_TAXA))

return {"default_include_taxa": include_taxa, "default_exclude_taxa": exclude_taxa}


def get_or_create_default_project(user: User) -> "Project":
"""
Create a default project for a user.
Expand All @@ -158,8 +169,20 @@ def get_or_create_default_project(user: User) -> "Project":
when the project is saved for the first time.
If the project already exists, it will be returned without modification.
"""
project, _created = Project.objects.get_or_create(name="Scratch Project", owner=user, create_defaults=True)
logger.info(f"Created default project for user {user}")
project, created = Project.objects.get_or_create(name="Scratch Project", owner=user)
if created:
logger.info(f"Created default project for user {user}")
defaults = get_project_default_filters()

if defaults["default_include_taxa"]:
project.default_filters_include_taxa.set(defaults["default_include_taxa"])
logger.info(f"Set {len(defaults['default_include_taxa'])} default include taxa for project {project}")
if defaults["default_exclude_taxa"]:
project.default_filters_exclude_taxa.set(defaults["default_exclude_taxa"])
logger.info(f"Set {len(defaults['default_exclude_taxa'])} default exclude taxa for project {project}")
project.save()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
logger.info(f"Loaded existing default project for user {user}")
return project


Expand Down Expand Up @@ -678,6 +701,18 @@ def get_first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.dat
)
return (first, last)

def get_detections_count(self) -> int | None:
"""Return detections count filtered by project default filters"""

qs = Detection.objects.filter(source_image__deployment=self)
filter_q = build_occurrence_default_filters_q(
project=self.project,
request=None,
occurrence_accessor="occurrence",
)

return qs.filter(filter_q).distinct().count()

def first_date(self) -> datetime.date | None:
return self.first_capture_timestamp.date() if self.first_capture_timestamp else None

Expand Down Expand Up @@ -883,7 +918,7 @@ def update_calculated_fields(self, save=False):

self.events_count = self.events.count()
self.captures_count = self.data_source_total_files or self.captures.count()
self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count()
self.detections_count = self.get_detections_count()
occ_qs = self.occurrences.filter(event__isnull=False).apply_default_filters( # type: ignore
project=self.project,
request=None,
Expand Down Expand Up @@ -1048,7 +1083,15 @@ def get_captures_count(self) -> int:
return self.captures.distinct().count()

def get_detections_count(self) -> int | None:
return Detection.objects.filter(Q(source_image__event=self)).count()
"""Return detections count filtered by project default filters"""
qs = Detection.objects.filter(source_image__event=self)
filter_q = build_occurrence_default_filters_q(
project=self.project,
request=None,
occurrence_accessor="occurrence",
)

return qs.filter(filter_q).distinct().count()

def get_occurrences_count(self, classification_threshold: float = 0) -> int:
"""
Expand Down Expand Up @@ -1753,7 +1796,20 @@ def size_display(self) -> str:
return filesizeformat(self.size)

def get_detections_count(self) -> int:
return self.detections.distinct().count()
"""
Return detections count filtered by project default filters.
"""
project = self.project
if not project:
return self.detections.distinct().count()

q = build_occurrence_default_filters_q(
project=project,
request=None,
occurrence_accessor="occurrence",
)

return self.detections.filter(q).distinct().count()

def get_base_url(self) -> str | None:
"""
Expand Down
30 changes: 16 additions & 14 deletions ami/main/models_future/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def build_occurrence_default_filters_q(
project: "Project | None" = None,
request: "Request | None" = None,
occurrence_accessor: str = "",
apply_default_score_filter: bool = True,
apply_default_taxa_filter: bool = True,
) -> Q:
"""
Build a Q filter that applies default filters (score threshold + taxa) for Occurrence relationships.
Expand Down Expand Up @@ -194,19 +196,19 @@ def build_occurrence_default_filters_q(
return Q()

filter_q = Q()

# Build score threshold filter
score_threshold = get_default_classification_threshold(project, request)
filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)

# Build taxa inclusion/exclusion filter
# For taxa filtering, we need to append "__determination" to the occurrence accessor
prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
taxon_accessor = f"{prefix}determination"
include_taxa = project.default_filters_include_taxa.all()
exclude_taxa = project.default_filters_exclude_taxa.all()
taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
if taxa_q:
filter_q &= taxa_q
if apply_default_score_filter:
# Build score threshold filter
score_threshold = get_default_classification_threshold(project, request)
filter_q &= build_occurrence_score_threshold_q(score_threshold, occurrence_accessor)
if apply_default_taxa_filter:
# Build taxa inclusion/exclusion filter
# For taxa filtering, we need to append "__determination" to the occurrence accessor
prefix = f"{occurrence_accessor}__" if occurrence_accessor else ""
taxon_accessor = f"{prefix}determination"
include_taxa = project.default_filters_include_taxa.all()
exclude_taxa = project.default_filters_exclude_taxa.all()
taxa_q = build_taxa_recursive_filter_q(include_taxa, exclude_taxa, taxon_accessor)
if taxa_q:
filter_q &= taxa_q

return filter_q
81 changes: 80 additions & 1 deletion ami/main/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from django.dispatch import receiver
from guardian.shortcuts import assign_perm

from ami.main.models import Project
from ami.users.roles import BasicMember, ProjectManager, create_roles_for_project

from .models import Project, User
from .models import User

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -110,3 +111,81 @@ def delete_project_groups(sender, instance, **kwargs):
prefix = f"{instance.pk}_"
# Find and delete all groups that start with {project_id}_
Group.objects.filter(name__startswith=prefix).delete()


# ============================================================================
# Project Default Filters Update Signals
# ============================================================================
# These signals handle efficient updates to calculated fields for project-related
# objects (such as Deployments and Events) whenever a project's default filter
# values change.
#
# Specifically, they trigger recalculation of cached counts when:
# - The project's default score threshold is updated
# - The project's default include taxa are modified
# - The project's default exclude taxa are modified
#
# This ensures that cached counts (e.g., occurrences_count, taxa_count) remain
# accurate and consistent with the active filter configuration for each project.
# ============================================================================


def refresh_cached_counts_for_project(project: Project):
"""
Refresh cached counts for Deployments and Events belonging to a project.
"""
logger.info(f"Refreshing cached counts for project {project.pk} ({project.name})")
project.update_related_calculated_fields()


@receiver(pre_save, sender=Project)
def cache_old_threshold(sender, instance, **kwargs):
"""
Cache the previous default score threshold before saving the Project.

We do this because:
- In post_save, the instance already contains the NEW value.
- To detect whether the threshold actually changed, we must read the OLD
value from the database before the update happens.
- This allows us to accurately detect threshold changes and then trigger
recalculation of cached filtered counts (Events, Deployments, etc.).

The cached value is stored on the instance as `_old_threshold` so it can be
safely accessed in the post_save handler.
"""
if instance.pk:
instance._old_threshold = Project.objects.get(pk=instance.pk).default_filters_score_threshold
else:
instance._old_threshold = None


@receiver(post_save, sender=Project)
def threshold_updated(sender, instance, **kwargs):
"""
After saving the Project, compare the previously cached threshold with the new value.
If the default score threshold changed, we refresh all cached counts using the new filters.

This two-step (pre_save + post_save) pattern is required because:
- post_save instances already contain the updated value
- so the old threshold would be lost without caching it in pre_save
"""
old_threshold = instance._old_threshold
new_threshold = instance.default_filters_score_threshold
if old_threshold is not None and old_threshold != new_threshold:
refresh_cached_counts_for_project(instance)


@receiver(m2m_changed, sender=Project.default_filters_include_taxa.through)
def include_taxa_updated(sender, instance: Project, action, **kwargs):
"""Refresh cached counts when include taxa are modified."""
if action in ["post_add", "post_remove", "post_clear"]:
logger.info(f"Include taxa updated for project {instance.pk} (action={action})")
refresh_cached_counts_for_project(instance)


@receiver(m2m_changed, sender=Project.default_filters_exclude_taxa.through)
def exclude_taxa_updated(sender, instance: Project, action, **kwargs):
"""Refresh cached counts when exclude taxa are modified."""
if action in ["post_add", "post_remove", "post_clear"]:
logger.info(f"Exclude taxa updated for project {instance.pk} (action={action})")
refresh_cached_counts_for_project(instance)
3 changes: 3 additions & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,6 @@
"DEFAULT_PROCESSING_SERVICE_ENDPOINT", default=None # type: ignore[no-untyped-call]
)
DEFAULT_PIPELINES_ENABLED = env.list("DEFAULT_PIPELINES_ENABLED", default=None) # type: ignore[no-untyped-call]
# Default taxa filters
DEFAULT_INCLUDE_TAXA = env.list("DEFAULT_INCLUDE_TAXA", default=[]) # type: ignore[no-untyped-call]
DEFAULT_EXCLUDE_TAXA = env.list("DEFAULT_EXCLUDE_TAXA", default=[]) # type: ignore[no-untyped-call]