Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions ami/main/checks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Data integrity checks for the main app.

Each module under this package defines one or more integrity check pairs:
a ``get_*`` function returning a queryset of affected rows, and a
``reconcile_*`` function that attempts to repair them. Both are exported
from this package so callers can compose individual checks from
management commands, post-job hooks, and periodic Celery tasks.
"""

from ami.main.checks.occurrences import (
IntegrityCheckResult,
get_occurrences_missing_determination,
reconcile_missing_determinations,
)

# Public API of the checks package: exactly the names re-exported from
# ami.main.checks.occurrences by the import statement in this module.
__all__ = [
    "IntegrityCheckResult",
    "get_occurrences_missing_determination",
    "reconcile_missing_determinations",
]
99 changes: 99 additions & 0 deletions ami/main/checks/occurrences.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Integrity checks for Occurrence records."""

import dataclasses
import logging

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class IntegrityCheckResult:
    """Mutable tally describing the outcome of one integrity-check run.

    All counters start at zero and are incremented in place by the
    reconcile routine that owns the instance.
    """

    # Number of rows the check matched.
    checked: int = dataclasses.field(default=0)
    # Number of matched rows for which a repair was applied.
    fixed: int = dataclasses.field(default=0)
    # Number of matched rows that could not be repaired (or raised while trying).
    unfixable: int = dataclasses.field(default=0)


def get_occurrences_missing_determination(
    project_id: int | None = None,
    job_id: int | None = None,
):
    """Return occurrences that have classifications but no determination set.

    Occurrences without any classifications are excluded because they
    legitimately have no determination yet.

    Args:
        project_id: When given, only occurrences belonging to this project
            are returned.
        job_id: When given, the result is limited to the job's project and,
            if the job has a pipeline, to occurrences that have at least one
            classification produced by one of that pipeline's algorithms.

    Returns:
        A distinct, unevaluated ``Occurrence`` queryset.

    Raises:
        Job.DoesNotExist: If ``job_id`` does not match an existing job.
    """
    from ami.main.models import Occurrence

    qs = Occurrence.objects.filter(
        determination__isnull=True,
        detections__classifications__isnull=False,
    ).distinct()

    if project_id is not None:
        qs = qs.filter(project_id=project_id)

    if job_id is not None:
        from ami.jobs.models import Job

        job = Job.objects.get(pk=job_id)
        # Always restrict to the job's project, even when the job has no
        # pipeline. Otherwise a job-scoped call would silently fall back to
        # every project's occurrences.
        qs = qs.filter(project_id=job.project_id)
        if job.pipeline_id:
            # Kept in its own filter() call: "has a classification from one
            # of the pipeline's algorithms" is an independent condition from
            # the "has any classification" check above.
            qs = qs.filter(
                detections__classifications__algorithm__in=job.pipeline.algorithms.all(),
            )

    return qs


def reconcile_missing_determinations(
    project_id: int | None = None,
    job_id: int | None = None,
    occurrence_ids: list[int] | None = None,
    dry_run: bool = True,
) -> IntegrityCheckResult:
    """Find occurrences missing determinations and repair them.

    Re-runs ``update_occurrence_determination`` on each affected row so the
    best available identification or prediction is promoted to the
    determination field. Occurrences that can't be resolved (e.g. no viable
    prediction) are counted as ``unfixable``.

    Args:
        project_id: Optional project scope, forwarded to
            ``get_occurrences_missing_determination``.
        job_id: Optional job scope, forwarded to
            ``get_occurrences_missing_determination``.
        occurrence_ids: When not ``None``, further restricts the affected
            rows to these primary keys. It narrows the project/job scope
            rather than replacing it; an empty list matches nothing.
        dry_run: When ``True`` (the default) only count the affected rows;
            nothing is modified.

    Returns:
        An ``IntegrityCheckResult``. ``checked`` is the number of affected
        rows found; ``fixed`` and ``unfixable`` stay at zero in a dry run.
        An exception raised while repairing a single row is logged and that
        row is counted as ``unfixable`` so the remaining rows still run.
    """
    from ami.main.models import update_occurrence_determination

    # Single source of truth for what "missing a determination" means; the
    # explicit id list only narrows that queryset.
    occurrences = get_occurrences_missing_determination(
        project_id=project_id,
        job_id=job_id,
    )
    if occurrence_ids is not None:
        occurrences = occurrences.filter(pk__in=occurrence_ids)

    result = IntegrityCheckResult(checked=occurrences.count())

    if result.checked == 0 or dry_run:
        return result

    logger.info("Found %d occurrences missing determination", result.checked)

    for occurrence in occurrences.iterator():
        try:
            updated = update_occurrence_determination(occurrence, current_determination=None, save=True)
        except Exception:
            # Deliberate best-effort: one bad row must not abort the batch.
            result.unfixable += 1
            logger.exception("Error reconciling occurrence %s", occurrence.pk)
        else:
            if updated:
                result.fixed += 1
            else:
                result.unfixable += 1

    logger.info(
        "Integrity check reconciliation: %d fixed, %d unfixable out of %d checked",
        result.fixed,
        result.unfixable,
        result.checked,
    )
    return result
40 changes: 40 additions & 0 deletions ami/main/management/commands/check_data_integrity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse

from django.core.management.base import BaseCommand

from ami.main.checks import reconcile_missing_determinations


class Command(BaseCommand):
    """Report occurrences missing a determination and optionally repair them.

    Runs in dry-run mode unless ``--no-dry-run`` is passed.
    """

    help = "Run data integrity checks and optionally fix issues."

    def add_arguments(self, parser):
        """Register the ``--dry-run``/``--no-dry-run``, ``--project`` and ``--job`` options."""
        # BooleanOptionalAction also creates the ``--no-dry-run`` negation.
        parser.add_argument(
            "--dry-run",
            action=argparse.BooleanOptionalAction,
            default=True,
            help="Report issues without fixing them (default). Pass --no-dry-run to apply fixes.",
        )
        parser.add_argument("--project", type=int, help="Limit to a specific project ID")
        parser.add_argument("--job", type=int, help="Limit to occurrences related to a specific job ID")

    def handle(self, *args, **options):
        """Run the reconcile check with the parsed options and print a summary."""
        is_dry_run = options["dry_run"]
        out = self.stdout

        if is_dry_run:
            out.write("DRY RUN — no changes will be made. Pass --no-dry-run to apply fixes.\n")

        scope = {
            "project_id": options.get("project"),
            "job_id": options.get("job"),
        }
        summary = reconcile_missing_determinations(dry_run=is_dry_run, **scope)

        out.write(f"Occurrences missing determination: {summary.checked}")
        # The per-outcome lines are only printed for non-zero counters.
        if summary.fixed:
            out.write(self.style.SUCCESS(f" Fixed: {summary.fixed}"))
        if summary.unfixable:
            out.write(self.style.WARNING(f" Unfixable: {summary.unfixable}"))

        if not summary.checked:
            out.write(self.style.SUCCESS("No issues found."))
        elif is_dry_run:
            out.write(self.style.NOTICE("Run with --no-dry-run to apply fixes."))
29 changes: 29 additions & 0 deletions ami/main/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging

from ami.tasks import default_soft_time_limit, default_time_limit
from config import celery_app

logger = logging.getLogger(__name__)


@celery_app.task(soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
def check_data_integrity():
    """Run the occurrence integrity check with fixes enabled.

    Intended to be scheduled through django_celery_beat in the Django admin:
        Task: ami.main.tasks.check_data_integrity

    Returns:
        A dict with the ``checked``, ``fixed`` and ``unfixable`` counters
        of the reconcile run.
    """
    from ami.main.checks import reconcile_missing_determinations

    outcome = reconcile_missing_determinations(dry_run=False)
    # Plain dict of the counters; used for both the log line and the return value.
    summary = {
        "checked": outcome.checked,
        "fixed": outcome.fixed,
        "unfixable": outcome.unfixable,
    }
    logger.info(
        "Data integrity check: %d checked, %d fixed, %d unfixable",
        summary["checked"],
        summary["fixed"],
        summary["unfixable"],
    )
    return summary
126 changes: 126 additions & 0 deletions ami/main/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3782,3 +3782,129 @@ def test_list_pipelines_public_project_non_member(self):
self.client.force_authenticate(user=non_member)
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)


class TestIntegrityChecks(TestCase):
    """Exercise the get/reconcile helpers exported by ``ami.main.checks``."""

    def setUp(self):
        """Build one project with three fixture occurrences that all have a determination."""
        from ami.tests.fixtures.main import create_captures

        self.project, self.deployment = setup_test_project(reuse=False)
        create_captures(deployment=self.deployment, num_nights=1, images_per_night=3)
        group_images_into_events(deployment=self.deployment)
        create_taxa(project=self.project)
        create_occurrences(deployment=self.deployment, num=3)

        # Sanity check: the fixtures must start out healthy.
        determined = Occurrence.objects.filter(project=self.project, determination__isnull=False)
        self.assertEqual(determined.count(), 3)

    def _occurrence_pks(self, project, limit=None) -> list[int]:
        """Return the primary keys of ``project``'s occurrences, optionally only the first ``limit``."""
        pks = Occurrence.objects.filter(project=project).values_list("pk", flat=True)
        if limit is not None:
            pks = pks[:limit]
        return list(pks)

    def _still_missing(self, occurrence_pks):
        """Return the queryset of the given occurrences whose determination is still NULL."""
        return Occurrence.objects.filter(pk__in=occurrence_pks, determination__isnull=True)

    def _null_out_determinations(self, occurrence_pks: list[int]) -> None:
        """Simulate the partial-save bug by clearing determinations with a queryset UPDATE."""
        targets = Occurrence.objects.filter(pk__in=occurrence_pks)
        targets.update(determination=None, determination_score=None)

    def test_get_missing_determination_finds_only_affected_rows(self):
        from ami.main.checks import get_occurrences_missing_determination

        broken = self._occurrence_pks(self.project, limit=2)
        self._null_out_determinations(broken)

        found = get_occurrences_missing_determination(project_id=self.project.pk)

        self.assertEqual(set(found.values_list("pk", flat=True)), set(broken))

    def test_get_missing_determination_excludes_occurrences_without_classifications(self):
        from ami.main.checks import get_occurrences_missing_determination

        # A bare occurrence has neither detections nor classifications.
        bare = Occurrence.objects.create(project=self.project, deployment=self.deployment)

        found = get_occurrences_missing_determination(project_id=self.project.pk)
        found_pks = list(found.values_list("pk", flat=True))

        self.assertNotIn(bare.pk, found_pks)

    def test_reconcile_dry_run_reports_without_saving(self):
        from ami.main.checks import reconcile_missing_determinations

        broken = self._occurrence_pks(self.project, limit=2)
        self._null_out_determinations(broken)

        outcome = reconcile_missing_determinations(project_id=self.project.pk, dry_run=True)

        self.assertEqual(outcome.checked, 2)
        self.assertEqual(outcome.fixed, 0)
        self.assertEqual(outcome.unfixable, 0)
        self.assertEqual(
            self._still_missing(broken).count(),
            2,
            "dry_run must not modify the database",
        )

    def test_reconcile_fixes_missing_determinations(self):
        from ami.main.checks import reconcile_missing_determinations

        broken = self._occurrence_pks(self.project, limit=2)
        self._null_out_determinations(broken)

        outcome = reconcile_missing_determinations(project_id=self.project.pk, dry_run=False)

        self.assertEqual(outcome.checked, 2)
        self.assertEqual(outcome.fixed, 2)
        self.assertEqual(outcome.unfixable, 0)
        self.assertFalse(
            self._still_missing(broken).exists(),
            "reconcile should have populated determination from best prediction",
        )

    def test_reconcile_scoped_by_project_ignores_other_projects(self):
        from ami.main.checks import reconcile_missing_determinations
        from ami.tests.fixtures.main import create_captures

        # Second, independent project with two occurrences of its own.
        other_project, other_deployment = setup_test_project(reuse=False)
        create_captures(deployment=other_deployment, num_nights=1, images_per_night=3)
        group_images_into_events(deployment=other_deployment)
        create_taxa(project=other_project)
        create_occurrences(deployment=other_deployment, num=2)

        other_broken = self._occurrence_pks(other_project)
        mine_broken = self._occurrence_pks(self.project, limit=1)
        self._null_out_determinations(other_broken + mine_broken)

        outcome = reconcile_missing_determinations(project_id=self.project.pk, dry_run=False)

        self.assertEqual(outcome.checked, 1)
        self.assertEqual(outcome.fixed, 1)
        self.assertTrue(
            self._still_missing(other_broken).exists(),
            "other project's occurrences must not be touched",
        )

    def test_reconcile_scoped_by_occurrence_ids(self):
        from ami.main.checks import reconcile_missing_determinations

        all_pks = self._occurrence_pks(self.project)
        self._null_out_determinations(all_pks)
        selected, untouched = all_pks[:1], all_pks[1:]

        outcome = reconcile_missing_determinations(occurrence_ids=selected, dry_run=False)

        self.assertEqual(outcome.checked, 1)
        self.assertEqual(outcome.fixed, 1)
        self.assertEqual(
            self._still_missing(untouched).count(),
            len(all_pks) - 1,
            "occurrences not listed must not be fixed",
        )

    def test_reconcile_no_issues_returns_zero_checked(self):
        from ami.main.checks import reconcile_missing_determinations

        outcome = reconcile_missing_determinations(project_id=self.project.pk, dry_run=False)

        self.assertEqual(outcome.checked, 0)
        self.assertEqual(outcome.fixed, 0)
        self.assertEqual(outcome.unfixable, 0)
Loading