diff --git a/ami/main/checks/__init__.py b/ami/main/checks/__init__.py new file mode 100644 index 000000000..5391761f0 --- /dev/null +++ b/ami/main/checks/__init__.py @@ -0,0 +1,20 @@ +"""Data integrity checks for the main app. + +Each module under this package defines one or more integrity check pairs: +a ``get_*`` function returning a queryset of affected rows, and a +``reconcile_*`` function that attempts to repair them. Both are exported +from this package so callers can compose individual checks from +management commands, post-job hooks, and periodic Celery tasks. +""" + +from ami.main.checks.occurrences import ( + IntegrityCheckResult, + get_occurrences_missing_determination, + reconcile_missing_determinations, +) + +__all__ = [ + "IntegrityCheckResult", + "get_occurrences_missing_determination", + "reconcile_missing_determinations", +] diff --git a/ami/main/checks/occurrences.py b/ami/main/checks/occurrences.py new file mode 100644 index 000000000..d64b0ac72 --- /dev/null +++ b/ami/main/checks/occurrences.py @@ -0,0 +1,99 @@ +"""Integrity checks for Occurrence records.""" + +import dataclasses +import logging + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class IntegrityCheckResult: + checked: int = 0 + fixed: int = 0 + unfixable: int = 0 + + +def get_occurrences_missing_determination( + project_id: int | None = None, + job_id: int | None = None, +): + """Return occurrences that have classifications but no determination set. + + Occurrences without any classifications are excluded because they + legitimately have no determination yet. + """ + from ami.main.models import Occurrence + + qs = Occurrence.objects.filter( + determination__isnull=True, + detections__classifications__isnull=False, + ).distinct() + + if project_id is not None: + qs = qs.filter(project_id=project_id) + + if job_id is not None: + from ami.jobs.models import Job + + job = Job.objects.get(pk=job_id) + if job.pipeline_id: + qs = qs.filter( + detections__classifications__algorithm__in=job.pipeline.algorithms.all(), + project_id=job.project_id, + ) + + return qs + + +def reconcile_missing_determinations( + project_id: int | None = None, + job_id: int | None = None, + occurrence_ids: list[int] | None = None, + dry_run: bool = True, +) -> IntegrityCheckResult: + """Find occurrences missing determinations and repair them. + + Re-runs ``update_occurrence_determination`` on each affected row so the + best available identification or prediction is promoted to the + determination field. Occurrences that can't be resolved (e.g. no viable + prediction) are counted as ``unfixable``. + """ + from ami.main.models import Occurrence, update_occurrence_determination + + if occurrence_ids is not None: + occurrences = Occurrence.objects.filter( + pk__in=occurrence_ids, + determination__isnull=True, + detections__classifications__isnull=False, + ).distinct() + else: + occurrences = get_occurrences_missing_determination( + project_id=project_id, + job_id=job_id, + ) + + result = IntegrityCheckResult(checked=occurrences.count()) + + if result.checked == 0 or dry_run: + return result + + logger.info("Found %d occurrences missing determination", result.checked) + + for occurrence in occurrences.iterator(): + try: + updated = update_occurrence_determination(occurrence, current_determination=None, save=True) + if updated: + result.fixed += 1 + else: + result.unfixable += 1 + except Exception: + result.unfixable += 1 + logger.exception("Error reconciling occurrence %s", occurrence.pk) + + logger.info( + "Integrity check reconciliation: %d fixed, %d unfixable out of %d checked", + result.fixed, + result.unfixable, + result.checked, + ) + return result diff --git a/ami/main/management/commands/check_data_integrity.py b/ami/main/management/commands/check_data_integrity.py new file mode 100644 index 000000000..fa4de5b8b --- /dev/null +++ b/ami/main/management/commands/check_data_integrity.py @@ -0,0 +1,40 @@ +import argparse + +from django.core.management.base import BaseCommand + +from ami.main.checks import reconcile_missing_determinations + + +class Command(BaseCommand): + help = "Run data integrity checks and optionally fix issues." + + def add_arguments(self, parser): + parser.add_argument( + "--dry-run", + action=argparse.BooleanOptionalAction, + default=True, + help="Report issues without fixing them (default). Pass --no-dry-run to apply fixes.", + ) + parser.add_argument("--project", type=int, help="Limit to a specific project ID") + parser.add_argument("--job", type=int, help="Limit to occurrences related to a specific job ID") + + def handle(self, *args, **options): + dry_run = options["dry_run"] + if dry_run: + self.stdout.write("DRY RUN — no changes will be made. Pass --no-dry-run to apply fixes.\n") + + result = reconcile_missing_determinations( + project_id=options.get("project"), + job_id=options.get("job"), + dry_run=dry_run, + ) + + self.stdout.write(f"Occurrences missing determination: {result.checked}") + if result.fixed: + self.stdout.write(self.style.SUCCESS(f" Fixed: {result.fixed}")) + if result.unfixable: + self.stdout.write(self.style.WARNING(f" Unfixable: {result.unfixable}")) + if result.checked == 0: + self.stdout.write(self.style.SUCCESS("No issues found.")) + elif dry_run: + self.stdout.write(self.style.NOTICE("Run with --no-dry-run to apply fixes.")) diff --git a/ami/main/tasks.py b/ami/main/tasks.py new file mode 100644 index 000000000..d5242efdf --- /dev/null +++ b/ami/main/tasks.py @@ -0,0 +1,29 @@ +import logging + +from ami.tasks import default_soft_time_limit, default_time_limit +from config import celery_app + +logger = logging.getLogger(__name__) + + +@celery_app.task(soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) +def check_data_integrity(): + """Periodic integrity check for occurrence data. + + Schedule via django_celery_beat in the Django admin: + Task: ami.main.tasks.check_data_integrity + """ + from ami.main.checks import reconcile_missing_determinations + + result = reconcile_missing_determinations(dry_run=False) + logger.info( + "Data integrity check: %d checked, %d fixed, %d unfixable", + result.checked, + result.fixed, + result.unfixable, + ) + return { + "checked": result.checked, + "fixed": result.fixed, + "unfixable": result.unfixable, + } diff --git a/ami/main/tests.py b/ami/main/tests.py index 4bfbdc4de..22507d19e 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -3782,3 +3782,129 @@ def test_list_pipelines_public_project_non_member(self): self.client.force_authenticate(user=non_member) response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class TestIntegrityChecks(TestCase): + """Tests for the ami.main.checks integrity check framework.""" + + def setUp(self): + from ami.tests.fixtures.main import create_captures + + self.project, self.deployment = setup_test_project(reuse=False) + create_captures(deployment=self.deployment, num_nights=1, images_per_night=3) + group_images_into_events(deployment=self.deployment) + create_taxa(project=self.project) + create_occurrences(deployment=self.deployment, num=3) + + # Every fixture occurrence should start with a determination set. + self.assertEqual( + Occurrence.objects.filter(project=self.project, determination__isnull=False).count(), + 3, + ) + + def _null_out_determinations(self, occurrence_pks: list[int]) -> None: + """Simulate the partial-save bug by clearing determinations via raw UPDATE.""" + Occurrence.objects.filter(pk__in=occurrence_pks).update( + determination=None, + determination_score=None, + ) + + def test_get_missing_determination_finds_only_affected_rows(self): + from ami.main.checks import get_occurrences_missing_determination + + broken_pks = list(Occurrence.objects.filter(project=self.project).values_list("pk", flat=True)[:2]) + self._null_out_determinations(broken_pks) + + qs = get_occurrences_missing_determination(project_id=self.project.pk) + + self.assertEqual(set(qs.values_list("pk", flat=True)), set(broken_pks)) + + def test_get_missing_determination_excludes_occurrences_without_classifications(self): + from ami.main.checks import get_occurrences_missing_determination + + empty_occurrence = Occurrence.objects.create(project=self.project, deployment=self.deployment) + + qs = get_occurrences_missing_determination(project_id=self.project.pk) + + self.assertNotIn(empty_occurrence.pk, list(qs.values_list("pk", flat=True))) + + def test_reconcile_dry_run_reports_without_saving(self): + from ami.main.checks import reconcile_missing_determinations + + broken_pks = list(Occurrence.objects.filter(project=self.project).values_list("pk", flat=True)[:2]) + self._null_out_determinations(broken_pks) + + result = reconcile_missing_determinations(project_id=self.project.pk, dry_run=True) + + self.assertEqual(result.checked, 2) + self.assertEqual(result.fixed, 0) + self.assertEqual(result.unfixable, 0) + self.assertEqual( + Occurrence.objects.filter(pk__in=broken_pks, determination__isnull=True).count(), + 2, + "dry_run must not modify the database", + ) + + def test_reconcile_fixes_missing_determinations(self): + from ami.main.checks import reconcile_missing_determinations + + broken_pks = list(Occurrence.objects.filter(project=self.project).values_list("pk", flat=True)[:2]) + self._null_out_determinations(broken_pks) + + result = reconcile_missing_determinations(project_id=self.project.pk, dry_run=False) + + self.assertEqual(result.checked, 2) + self.assertEqual(result.fixed, 2) + self.assertEqual(result.unfixable, 0) + self.assertFalse( + Occurrence.objects.filter(pk__in=broken_pks, determination__isnull=True).exists(), + "reconcile should have populated determination from best prediction", + ) + + def test_reconcile_scoped_by_project_ignores_other_projects(self): + from ami.main.checks import reconcile_missing_determinations + from ami.tests.fixtures.main import create_captures + + other_project, other_deployment = setup_test_project(reuse=False) + create_captures(deployment=other_deployment, num_nights=1, images_per_night=3) + group_images_into_events(deployment=other_deployment) + create_taxa(project=other_project) + create_occurrences(deployment=other_deployment, num=2) + + other_broken = list(Occurrence.objects.filter(project=other_project).values_list("pk", flat=True)) + mine_broken = list(Occurrence.objects.filter(project=self.project).values_list("pk", flat=True)[:1]) + self._null_out_determinations(other_broken + mine_broken) + + result = reconcile_missing_determinations(project_id=self.project.pk, dry_run=False) + + self.assertEqual(result.checked, 1) + self.assertEqual(result.fixed, 1) + self.assertTrue( + Occurrence.objects.filter(pk__in=other_broken, determination__isnull=True).exists(), + "other project's occurrences must not be touched", + ) + + def test_reconcile_scoped_by_occurrence_ids(self): + from ami.main.checks import reconcile_missing_determinations + + all_pks = list(Occurrence.objects.filter(project=self.project).values_list("pk", flat=True)) + self._null_out_determinations(all_pks) + + result = reconcile_missing_determinations(occurrence_ids=all_pks[:1], dry_run=False) + + self.assertEqual(result.checked, 1) + self.assertEqual(result.fixed, 1) + self.assertEqual( + Occurrence.objects.filter(pk__in=all_pks[1:], determination__isnull=True).count(), + len(all_pks) - 1, + "occurrences not listed must not be fixed", + ) + + def test_reconcile_no_issues_returns_zero_checked(self): + from ami.main.checks import reconcile_missing_determinations + + result = reconcile_missing_determinations(project_id=self.project.pk, dry_run=False) + + self.assertEqual(result.checked, 0) + self.assertEqual(result.fixed, 0) + self.assertEqual(result.unfixable, 0)