diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index 5d14fc24b..c259e4aea 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -995,6 +995,10 @@ def save_results( event_ids = [img.event_id for img in source_images] # type: ignore update_calculated_fields_for_events(pks=event_ids) + deployment_ids = {img.deployment_id for img in source_images if img.deployment_id} + for deployment in Deployment.objects.filter(pk__in=deployment_ids): + deployment.update_calculated_fields(save=True) + total_time = time.time() - start_time job_logger.info(f"Saved results from pipeline {pipeline} in {total_time:.2f} seconds") diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 559cbe9ad..6bc55bdf1 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -10,7 +10,9 @@ from ami.base.serializers import reverse_with_params from ami.main.models import ( Classification, + Deployment, Detection, + Event, Project, SourceImage, SourceImageCollection, @@ -1404,3 +1406,110 @@ def test_update_state_returns_none_when_state_genuinely_missing(self): # Do NOT call initialize_job — the total key doesn't exist. progress = self.manager.update_state({"img1", "img2"}, "process") self.assertIsNone(progress) + + +class TestSaveResultsRefreshesDeploymentCounts(TestCase): + """save_results must refresh Deployment cached counts, not just Event counts. + + Reproduces the "Station counts for occurrences and taxa are not always + getting updated" report: prior to the fix, save_results refreshed + update_calculated_fields_for_events but never the parent Deployment, so + deployment.occurrences_count / taxa_count stayed at the pre-job value + until something else (a manual deployment.save) ran. + """ + + def setUp(self): + self.project = Project.objects.create(name="Refresh Counts Project") + self.deployment = Deployment.objects.create(name="d1", project=self.project) + event_time = datetime.datetime(2026, 4, 16, 22, 0, 0) + self.event = Event.objects.create( + project=self.project, + deployment=self.deployment, + group_by="2026-04-16", + start=event_time, + end=event_time, + ) + self.image = SourceImage.objects.create( + deployment=self.deployment, + project=self.project, + event=self.event, + timestamp=event_time, + path="refresh_counts_test.jpg", + ) + self.collection = SourceImageCollection.objects.create(project=self.project, name="c") + self.collection.images.add(self.image) + + self.pipeline = Pipeline.objects.create(name="Refresh Counts Pipeline (Random)") + self.algorithms = { + key: get_or_create_algorithm_and_category_map(val) for key, val in ALGORITHM_CHOICES.items() + } + self.pipeline.algorithms.set( + [ + self.algorithms["random-detector"], + self.algorithms["random-binary-classifier"], + self.algorithms["random-species-classifier"], + ] + ) + + self.deployment.update_calculated_fields(save=True) + self.deployment.refresh_from_db() + self.assertEqual(self.deployment.occurrences_count, 0) + self.assertEqual(self.deployment.taxa_count, 0) + + def _fake_results(self): + detector = ALGORITHM_CHOICES["random-detector"] + binary_classifier = ALGORITHM_CHOICES["random-binary-classifier"] + species_classifier = ALGORITHM_CHOICES["random-species-classifier"] + assert binary_classifier.category_map and species_classifier.category_map + + detection = DetectionResponse( + source_image_id=self.image.pk, + bbox=BoundingBox(x1=0.0, y1=0.0, x2=1.0, y2=1.0), + inference_time=0.1, + algorithm=AlgorithmReference(name=detector.name, key=detector.key), + timestamp=self.image.timestamp, + classifications=[ + ClassificationResponse( + classification=binary_classifier.category_map.labels[0], + labels=binary_classifier.category_map.labels, + scores=[0.95], + algorithm=AlgorithmReference(name=binary_classifier.name, key=binary_classifier.key), + timestamp=self.image.timestamp, + terminal=False, + ), + ClassificationResponse( + classification=species_classifier.category_map.labels[0], + labels=species_classifier.category_map.labels, + scores=[0.85], + algorithm=AlgorithmReference(name=species_classifier.name, key=species_classifier.key), + timestamp=self.image.timestamp, + terminal=True, + ), + ], + ) + return PipelineResultsResponse( + pipeline=self.pipeline.slug, + algorithms={ + detector.key: detector, + binary_classifier.key: binary_classifier, + species_classifier.key: species_classifier, + }, + total_time=0.01, + source_images=[SourceImageResponse(id=self.image.pk, url=self.image.path)], + detections=[detection], + ) + + def test_deployment_counts_refresh_after_save_results(self): + save_results(self._fake_results()) + + self.deployment.refresh_from_db() + self.assertGreater( + self.deployment.occurrences_count, + 0, + "Deployment.occurrences_count should reflect occurrences created by save_results", + ) + self.assertGreater( + self.deployment.taxa_count, + 0, + "Deployment.taxa_count should reflect taxa from occurrences created by save_results", + )