diff --git a/.agents/AGENTS.md b/.agents/AGENTS.md index 1b6ac558e..746223e55 100644 --- a/.agents/AGENTS.md +++ b/.agents/AGENTS.md @@ -276,7 +276,8 @@ Processing services are FastAPI applications that implement the AMI ML API contr **Health Checks:** - Cached status with 3 retries and exponential backoff (0s, 2s, 4s) - Celery Beat task runs periodic checks (`ami.ml.tasks.check_processing_services_online`) -- Status stored in `ProcessingService.last_checked_live` boolean field +- Status stored in `ProcessingService.last_seen_live` boolean field +- Async/pull-mode services update status via `mark_seen()` when they register pipelines - UI shows red/green indicator based on cached status Location: `processing_services/` directory contains example implementations diff --git a/.agents/DATABASE_SCHEMA.md b/.agents/DATABASE_SCHEMA.md index 2a83bdb26..fdbbca832 100644 --- a/.agents/DATABASE_SCHEMA.md +++ b/.agents/DATABASE_SCHEMA.md @@ -255,8 +255,9 @@ erDiagram bigint id PK string name string endpoint_url - boolean last_checked_live - float last_checked_latency + datetime last_seen + boolean last_seen_live + float last_seen_latency } ProjectPipelineConfig { diff --git a/ami/jobs/tests/test_tasks.py b/ami/jobs/tests/test_tasks.py index 25e609244..daf1b6ae6 100644 --- a/ami/jobs/tests/test_tasks.py +++ b/ami/jobs/tests/test_tasks.py @@ -17,7 +17,8 @@ from ami.jobs.models import Job, JobDispatchMode, JobState, MLJob from ami.jobs.tasks import process_nats_pipeline_result from ami.main.models import Detection, Project, SourceImage, SourceImageCollection -from ami.ml.models import Pipeline +from ami.ml.models import Algorithm, Pipeline +from ami.ml.models.algorithm import AlgorithmTaskType from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse from ami.users.models import User @@ -180,6 +181,15 @@ def test_process_nats_pipeline_result_mixed_results(self, 
mock_manager_class): """ mock_manager = self._setup_mock_nats(mock_manager_class) + # Create detection algorithm for the pipeline + detection_algorithm = Algorithm.objects.create( + name="test-detector", + key="test-detector", + task_type=AlgorithmTaskType.LOCALIZATION, + ) + # Update pipeline to include detection algorithm + self.pipeline.algorithms.add(detection_algorithm) + # For this test, we just want to verify progress tracking works with mixed results # We'll skip checking final job completion status since that depends on all stages diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 6d0626f23..832e15f30 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -30,6 +30,30 @@ logger = logging.getLogger(__name__) +def _mark_pipeline_pull_services_seen(job: "Job") -> None: + """ + Record a heartbeat for async (pull-mode) processing services linked to the job's pipeline. + + Called on every task-fetch and result-submit request so that the worker's polling activity + keeps last_seen/last_seen_live current. The periodic check_processing_services_online task + will mark services offline if this heartbeat stops arriving within PROCESSING_SERVICE_LAST_SEEN_MAX. + + IMPORTANT: This marks ALL async services on the pipeline within this project as live, not just + the specific service that made the request. If multiple async services share the same pipeline + within a project, a single worker polling will keep all of them appearing online. + Once application-token auth is available (PR #1117), this should be scoped to the individual + calling service instead. 
+ """ + import datetime + + if not job.pipeline_id: + return + job.pipeline.processing_services.async_services().filter(projects=job.project_id).update( + last_seen=datetime.datetime.now(), + last_seen_live=True, + ) + + class JobFilterSet(filters.FilterSet): """Custom filterset to enable pipeline name filtering.""" @@ -245,6 +269,9 @@ def tasks(self, request, pk=None): if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") + # Record heartbeat for async processing services on this pipeline + _mark_pipeline_pull_services_seen(job) + # Get tasks from NATS JetStream from ami.ml.orchestration.nats_queue import TaskQueueManager @@ -272,6 +299,9 @@ def result(self, request, pk=None): job = self.get_object() + # Record heartbeat for async processing services on this pipeline + _mark_pipeline_pull_services_seen(job) + # Validate request data is a list if isinstance(request.data, list): results = request.data diff --git a/ami/main/admin.py b/ami/main/admin.py index c6170b153..1ffc6d5fc 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -265,6 +265,7 @@ class SourceImageAdmin(AdminBase): "checksum", "checksum_algorithm", "created_at", + "get_was_processed", ) list_filter = ( @@ -281,7 +282,12 @@ class SourceImageAdmin(AdminBase): ) def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: - return super().get_queryset(request).select_related("event", "deployment", "deployment__data_source") + return ( + super() + .get_queryset(request) + .select_related("event", "deployment", "deployment__data_source") + .with_was_processed() # avoids N+1 from get_was_processed in list_display + ) class ClassificationInline(admin.TabularInline): diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index 6d0d93762..0caa7b3e6 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -1246,6 +1246,7 @@ class Meta: "source_images", "source_images_count", "source_images_with_detections_count", + 
"source_images_processed_count", "occurrences_count", "taxa_count", "description", @@ -1547,6 +1548,7 @@ class EventTimelineIntervalSerializer(serializers.Serializer): captures_count = serializers.IntegerField() detections_count = serializers.IntegerField() detections_avg = serializers.IntegerField() + was_processed = serializers.BooleanField() class EventTimelineMetaSerializer(serializers.Serializer): diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 18536e7d5..75951535a 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -36,6 +36,7 @@ from ami.utils.storages import ConnectionTestResult from ..models import ( + NULL_DETECTIONS_FILTER, Classification, Deployment, Detection, @@ -378,7 +379,7 @@ def timeline(self, request, pk=None): ) resolution = datetime.timedelta(minutes=resolution_minutes) - qs = SourceImage.objects.filter(event=event) + qs = SourceImage.objects.filter(event=event).with_was_processed() # type: ignore # Bulk update all source images where detections_count is null update_detection_counts(qs=qs, null_only=True) @@ -404,7 +405,7 @@ def timeline(self, request, pk=None): source_images = list( qs.filter(timestamp__range=(start_time, end_time)) .order_by("timestamp") - .values("id", "timestamp", "detections_count") + .values("id", "timestamp", "detections_count", "was_processed") ) timeline = [] @@ -421,6 +422,7 @@ def timeline(self, request, pk=None): "captures_count": 0, "detections_count": 0, "detection_counts": [], + "was_processed": False, } while image_index < len(source_images) and source_images[image_index]["timestamp"] <= interval_end: @@ -432,6 +434,9 @@ def timeline(self, request, pk=None): interval_data["detection_counts"] += [image["detections_count"]] if image["detections_count"] >= max(interval_data["detection_counts"]): interval_data["top_capture"] = SourceImage(pk=image["id"]) + # Track if any image in this interval was processed + if image["was_processed"]: + interval_data["was_processed"] = True 
image_index += 1 # Set a meaningful average detection count to display for the interval @@ -602,7 +607,7 @@ def prefetch_detections(self, queryset: QuerySet, project: Project | None = None score = get_default_classification_threshold(project, self.request) prefetch_queryset = ( - Detection.objects.all() + Detection.objects.exclude(NULL_DETECTIONS_FILTER) .annotate( determination_score=models.Max("occurrence__detections__classifications__score"), # Store whether this occurrence should be included based on default filters @@ -709,6 +714,7 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin): SourceImageCollection.objects.all() .with_source_images_count() # type: ignore .with_source_images_with_detections_count() + .with_source_images_processed_count() .prefetch_related("jobs") ) serializer_class = SourceImageCollectionSerializer @@ -724,6 +730,7 @@ class SourceImageCollectionViewSet(DefaultViewSet, ProjectMixin): "method", "source_images_count", "source_images_with_detections_count", + "source_images_processed_count", "occurrences_count", ] @@ -898,7 +905,7 @@ class DetectionViewSet(DefaultViewSet, ProjectMixin): API endpoint that allows detections to be viewed or edited. """ - queryset = Detection.objects.all().select_related("source_image", "detection_algorithm") + queryset = Detection.objects.exclude(NULL_DETECTIONS_FILTER).select_related("source_image", "detection_algorithm") serializer_class = DetectionSerializer filterset_fields = ["source_image", "detection_algorithm", "source_image__project"] ordering_fields = ["created_at", "updated_at", "detection_score", "timestamp"] diff --git a/ami/main/integrity.py b/ami/main/integrity.py new file mode 100644 index 000000000..09bf9cf0c --- /dev/null +++ b/ami/main/integrity.py @@ -0,0 +1,106 @@ +""" +Data integrity checks for the main app. + +Functions here can be called from management commands, post-job hooks, +or periodic Celery tasks. 
+""" + +import dataclasses +import logging + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ReconcileResult: + checked: int = 0 + fixed: int = 0 + unfixable: int = 0 + + +def get_occurrences_missing_determination( + project_id: int | None = None, + job_id: int | None = None, +): + """ + Return occurrences that have detections with classifications but no determination set. + + Occurrences with no classifications at all are excluded (they legitimately have no + determination). + """ + from ami.main.models import Occurrence + + qs = Occurrence.objects.filter( + determination__isnull=True, + detections__classifications__isnull=False, + ).distinct() + + if project_id is not None: + qs = qs.filter(project_id=project_id) + + if job_id is not None: + from ami.jobs.models import Job + + job = Job.objects.get(pk=job_id) + if job.pipeline: + qs = qs.filter( + detections__classifications__algorithm__in=job.pipeline.algorithms.all(), + project_id=job.project_id, + ) + + return qs + + +def reconcile_missing_determinations( + project_id: int | None = None, + job_id: int | None = None, + occurrence_ids: list[int] | None = None, + dry_run: bool = False, +) -> ReconcileResult: + """ + Find occurrences missing determinations and attempt to fix them by re-running + update_occurrence_determination. 
+ """ + from ami.main.models import update_occurrence_determination + + if occurrence_ids is not None: + from ami.main.models import Occurrence + + occurrences = ( + Occurrence.objects.filter( + pk__in=occurrence_ids, + determination__isnull=True, + detections__classifications__isnull=False, + ) + .distinct() + .select_related("determination") + ) + else: + occurrences = get_occurrences_missing_determination( + project_id=project_id, + job_id=job_id, + ).select_related("determination") + + result = ReconcileResult(checked=occurrences.count()) + + if result.checked == 0 or dry_run: + return result + + logger.info(f"Found {result.checked} occurrences missing determination") + + for occurrence in occurrences.iterator(): + try: + updated = update_occurrence_determination(occurrence, current_determination=None, save=True) + if updated: + result.fixed += 1 + else: + result.unfixable += 1 + except Exception: + result.unfixable += 1 + logger.exception(f"Error reconciling occurrence {occurrence.pk}") + + logger.info( + f"Reconciliation complete: {result.fixed} fixed, {result.unfixable} unfixable " + f"out of {result.checked} checked" + ) + return result diff --git a/ami/main/management/commands/check_data_integrity.py b/ami/main/management/commands/check_data_integrity.py new file mode 100644 index 000000000..279d1d7b7 --- /dev/null +++ b/ami/main/management/commands/check_data_integrity.py @@ -0,0 +1,35 @@ +import logging + +from django.core.management.base import BaseCommand + +from ami.main.integrity import reconcile_missing_determinations + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Find and fix occurrences missing determinations." 
+ + def add_arguments(self, parser): + parser.add_argument("--dry-run", action="store_true", help="Report issues without fixing them") + parser.add_argument("--project", type=int, help="Limit to a specific project ID") + parser.add_argument("--job", type=int, help="Limit to occurrences related to a specific job ID") + + def handle(self, *args, **options): + dry_run = options["dry_run"] + if dry_run: + self.stdout.write("DRY RUN — no changes will be made\n") + + result = reconcile_missing_determinations( + project_id=options.get("project"), + job_id=options.get("job"), + dry_run=dry_run, + ) + + self.stdout.write(f"Checked: {result.checked}") + if result.fixed: + self.stdout.write(self.style.SUCCESS(f"Fixed: {result.fixed}")) + if result.unfixable: + self.stdout.write(self.style.WARNING(f"Unfixable: {result.unfixable}")) + if result.checked == 0: + self.stdout.write(self.style.SUCCESS("No issues found.")) diff --git a/ami/main/models.py b/ami/main/models.py index 0bad68531..2dc5e13bd 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -85,6 +85,8 @@ class TaxonRank(OrderedEnum): ] ) +NULL_DETECTIONS_FILTER = Q(bbox__isnull=True) | Q(bbox=[]) + def get_media_url(path: str) -> str: """ @@ -1775,6 +1777,19 @@ def with_taxa_count(self, project: Project | None = None, request=None): taxa_count=Coalesce(models.Subquery(taxa_subquery, output_field=models.IntegerField()), 0) ) + def with_was_processed(self): + """ + Annotate each SourceImage with a boolean `was_processed` indicating + whether any detections exist for that image. + + This mirrors `SourceImage.get_was_processed()` but as a queryset + annotation for efficient bulk queries. + """ + # @TODO: this returns a was processed status for any algorithm. Once the session detail view supports + # filtering by algorithm, this should be updated to return was_processed for the selected algorithm.
+ processed_exists = models.Exists(Detection.objects.filter(source_image_id=models.OuterRef("pk"))) + return self.annotate(was_processed=processed_exists) + class SourceImageManager(models.Manager.from_queryset(SourceImageQuerySet)): pass @@ -1874,7 +1889,29 @@ def size_display(self) -> str: return filesizeformat(self.size) def get_detections_count(self) -> int: - return self.detections.distinct().count() + # Detections count excludes detections without bounding boxes + # Detections with null bounding boxes are valid and indicates the image was successfully processed + return self.detections.exclude(NULL_DETECTIONS_FILTER).count() + + def get_was_processed(self, algorithm_key: str | None = None) -> bool: + """ + Return True if this image has been processed by any algorithm (or a specific one). + + Uses the ``was_processed`` annotation when available (set by + ``SourceImageQuerySet.with_was_processed()``). Falls back to a DB query otherwise. + + Do not call in bulk without the annotation — use ``with_was_processed()`` + on the queryset instead to avoid N+1 queries. + + :param algorithm_key: If provided, only detections from this algorithm are checked. + The annotation does not filter by algorithm; per-algorithm + checks always use a DB query. 
+ """ + if algorithm_key is None and hasattr(self, "was_processed"): + return self.was_processed # type: ignore[return-value] + if algorithm_key: + return self.detections.filter(detection_algorithm__key=algorithm_key).exists() + return self.detections.exists() def get_base_url(self) -> str | None: """ @@ -2044,6 +2081,7 @@ def update_detection_counts(qs: models.QuerySet[SourceImage] | None = None, null subquery = models.Subquery( Detection.objects.filter(source_image_id=models.OuterRef("pk")) + .exclude(NULL_DETECTIONS_FILTER) .values("source_image_id") .annotate(count=models.Count("id")) .values("count"), @@ -2514,6 +2552,15 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) +class DetectionQuerySet(BaseQuerySet): + def null_detections(self): + return self.filter(NULL_DETECTIONS_FILTER) + + +class DetectionManager(models.Manager.from_queryset(DetectionQuerySet)): + pass + + @final class Detection(BaseModel): """An object detected in an image""" @@ -2582,6 +2629,8 @@ class Detection(BaseModel): source_image_id: int detection_algorithm_id: int + objects = DetectionManager() + def get_bbox(self): if self.bbox: return BoundingBox( @@ -3752,7 +3801,18 @@ def with_source_images_count(self): def with_source_images_with_detections_count(self): return self.annotate( source_images_with_detections_count=models.Count( - "images", filter=models.Q(images__detections__isnull=False), distinct=True + "images", + filter=(~models.Q(images__detections__bbox__isnull=True) & ~models.Q(images__detections__bbox=[])), + distinct=True, + ) + ) + + def with_source_images_processed_count(self): + return self.annotate( + source_images_processed_count=models.Count( + "images", + filter=models.Q(images__detections__isnull=False), + distinct=True, ) ) @@ -3863,7 +3923,10 @@ def source_images_count(self) -> int | None: def source_images_with_detections_count(self) -> int | None: # This should always be pre-populated using queryset annotations - # return 
self.images.filter(detections__isnull=False).count() + return None + + def source_images_processed_count(self) -> int | None: + # This should always be pre-populated using queryset annotations return None def occurrences_count(self) -> int | None: diff --git a/ami/main/tasks.py b/ami/main/tasks.py new file mode 100644 index 000000000..0abf778b4 --- /dev/null +++ b/ami/main/tasks.py @@ -0,0 +1,21 @@ +import logging + +from config import celery_app + +logger = logging.getLogger(__name__) + + +@celery_app.task(soft_time_limit=300, time_limit=360) +def check_data_integrity(): + """ + Periodic task to find and fix occurrences missing determinations. + + Register via django_celery_beat in Django admin: + Task: ami.main.tasks.check_data_integrity + Schedule: e.g. every 24 hours + """ + from ami.main.integrity import reconcile_missing_determinations + + result = reconcile_missing_determinations() + logger.info(f"Data integrity check: {result.checked} checked, {result.fixed} fixed, {result.unfixable} unfixable") + return {"checked": result.checked, "fixed": result.fixed, "unfixable": result.unfixable} diff --git a/ami/ml/management/commands/check_dead_letter_queue.py b/ami/ml/management/commands/check_dead_letter_queue.py new file mode 100644 index 000000000..8fcec9986 --- /dev/null +++ b/ami/ml/management/commands/check_dead_letter_queue.py @@ -0,0 +1,48 @@ +""" +Management command to check dead letter queue messages for a job. 
+ +Usage: + python manage.py check_dead_letter_queue + +Example: + python manage.py check_dead_letter_queue 123 +""" + +from asgiref.sync import async_to_sync +from django.core.management.base import BaseCommand, CommandError + +from ami.ml.orchestration.nats_queue import TaskQueueManager + + +class Command(BaseCommand): + help = "Check dead letter queue messages for a job ID" + + def add_arguments(self, parser): + parser.add_argument( + "job_id", + type=int, + help="Job ID to check for dead letter queue messages", + ) + + def handle(self, *args, **options): + job_id = options["job_id"] + + try: + dead_letter_ids = async_to_sync(self._check_dead_letter_queue)(job_id) + + if dead_letter_ids: + self.stdout.write( + self.style.WARNING(f"Found {len(dead_letter_ids)} dead letter image(s) for job {job_id}:") + ) + for image_id in dead_letter_ids: + self.stdout.write(f" - Image ID: {image_id}") + else: + self.stdout.write(self.style.SUCCESS(f"No dead letter images found for job {job_id}")) + + except Exception as e: + raise CommandError(f"Failed to check dead letter queue: {e}") + + async def _check_dead_letter_queue(self, job_id: int) -> list[str]: + """Check for dead letter queue messages using TaskQueueManager.""" + async with TaskQueueManager() as manager: + return await manager.get_dead_letter_image_ids(job_id) diff --git a/ami/ml/migrations/0027_rename_last_checked_to_last_seen.py b/ami/ml/migrations/0027_rename_last_checked_to_last_seen.py new file mode 100644 index 000000000..4f14eee7c --- /dev/null +++ b/ami/ml/migrations/0027_rename_last_checked_to_last_seen.py @@ -0,0 +1,26 @@ +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("ml", "0026_make_processing_service_endpoint_url_nullable"), + ] + + operations = [ + migrations.RenameField( + model_name="processingservice", + old_name="last_checked", + new_name="last_seen", + ), + migrations.RenameField( + model_name="processingservice", + 
old_name="last_checked_live", + new_name="last_seen_live", + ), + migrations.RenameField( + model_name="processingservice", + old_name="last_checked_latency", + new_name="last_seen_latency", + ), + ] diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index f76822e3a..aa2b9baba 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -23,6 +23,7 @@ from ami.base.models import BaseModel, BaseQuerySet from ami.base.schemas import ConfigurableStage, default_stages from ami.main.models import ( + NULL_DETECTIONS_FILTER, Classification, Deployment, Detection, @@ -63,10 +64,21 @@ def filter_processed_images( ) -> typing.Iterable[SourceImage]: """ Return only images that need to be processed by a given pipeline. - An image needs processing if: - 1. It has no detections from the pipeline's detection algorithm - or - 2. It has detections but they don't have classifications from all the pipeline's classification algorithms + + Each image is checked against its existing detections from this pipeline's algorithms: + + YIELD (needs processing): + 1. No existing detections at all — image has never been processed by this pipeline + 2. Has real detections without classifications — detector ran but classifier didn't + 3. Has real detections with classifications, but not from all pipeline classifiers — + e.g. a new classifier was added to the pipeline since last run + + SKIP (already processed): + 4. Only null detections exist (bbox=None) — pipeline ran but found nothing + 5. Real detections exist and are fully classified by all pipeline classifiers + + Null detections are sentinels that mark an image as "processed, nothing found." + They are excluded from classification checks so they don't trigger reprocessing. 
""" pipeline_algorithms = pipeline.algorithms.all() @@ -84,8 +96,12 @@ def filter_processed_images( task_logger.debug(f"Image {image} needs processing: has no existing detections from pipeline's detector") # If there are no existing detections from this pipeline, send the image yield image - elif existing_detections.filter(classifications__isnull=True).exists(): - # Check if there are detections with no classifications + elif not existing_detections.exclude(NULL_DETECTIONS_FILTER).exists(): # type: ignore + # All detections for this image are null (processed but nothing found) — skip + task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!") + continue + elif existing_detections.exclude(NULL_DETECTIONS_FILTER).filter(classifications__isnull=True).exists(): + # Check if any real detections (non-null) have no classifications task_logger.debug( f"Image {image} needs processing: has existing detections with no classifications " "from pipeline {pipeline}" @@ -402,26 +418,54 @@ def get_or_create_detection( :param detection_resp: A DetectionResponse object :param algorithms_known: A dictionary of algorithms registered in the pipeline, keyed by the algorithm key - :param created_objects: A list to store created objects - :return: A tuple of the Detection object and a boolean indicating whether it was created + + For real detections (bbox is not None), the lookup is algorithm-agnostic — the same + bounding box on the same image is the same physical detection regardless of which algorithm + found it; duplicates are avoided this way. + + For null detections (bbox=None), the lookup is algorithm-specific — null is a sentinel value + (not a physical detection), so each algorithm gets its own null detection. This ensures + get_was_processed(algorithm_key=...) returns correct per-algorithm processed status. 
""" - serialized_bbox = list(detection_resp.bbox.dict().values()) + if detection_resp.bbox is not None: + serialized_bbox = list(detection_resp.bbox.dict().values()) + else: + serialized_bbox = None detection_repr = f"Detection {detection_resp.source_image_id} {serialized_bbox}" assert str(detection_resp.source_image_id) == str( source_image.pk ), f"Detection belongs to a different source image: {detection_repr}" - existing_detection = Detection.objects.filter( - source_image=source_image, - bbox=serialized_bbox, - ).first() + if serialized_bbox is None: + # Null detection: algorithm-specific lookup so different pipelines don't share sentinels. + # Use bbox__isnull=True because JSONField filter(bbox=None) matches JSON null literal, + # not SQL NULL which is what Detection(bbox=None) stores. + assert detection_resp.algorithm, f"No detection algorithm was specified for detection {detection_repr}" + try: + detection_algo = algorithms_known[detection_resp.algorithm.key] + except KeyError as err: + raise PipelineNotConfigured( + f"Detection algorithm {detection_resp.algorithm.key} is not a known algorithm. " + "The processing service must declare it in the /info endpoint. " + f"Known algorithms: {list(algorithms_known.keys())}" + ) from err + existing_detection = Detection.objects.filter( + source_image=source_image, + bbox__isnull=True, + detection_algorithm=detection_algo, + ).first() + else: + # Real detection: algorithm-agnostic — same bbox = same physical detection + existing_detection = Detection.objects.filter( + source_image=source_image, + bbox=serialized_bbox, + ).first() # A detection may have a pre-existing crop image URL or not. # If not, a new one will be created in a periodic background task. if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"): - # Ensure that the crop image URL is not empty or only a slash. None is fine. 
crop_url = detection_resp.crop_image_url else: crop_url = None @@ -434,15 +478,17 @@ def get_or_create_detection( detection = existing_detection else: - assert detection_resp.algorithm, f"No detection algorithm was specified for detection {detection_repr}" - try: - detection_algo = algorithms_known[detection_resp.algorithm.key] - except KeyError: - raise PipelineNotConfigured( - f"Detection algorithm {detection_resp.algorithm.key} is not a known algorithm. " - "The processing service must declare it in the /info endpoint. " - f"Known algorithms: {list(algorithms_known.keys())}" - ) + # Resolve algorithm for creation (null detections already resolved above) + if serialized_bbox is not None: + assert detection_resp.algorithm, f"No detection algorithm was specified for detection {detection_repr}" + try: + detection_algo = algorithms_known[detection_resp.algorithm.key] + except KeyError as err: + raise PipelineNotConfigured( + f"Detection algorithm {detection_resp.algorithm.key} is not a known algorithm. " + "The processing service must declare it in the /info endpoint. " + f"Known algorithms: {list(algorithms_known.keys())}" + ) from err new_detection = Detection( source_image=source_image, @@ -485,6 +531,7 @@ def create_detections( existing_detections: list[Detection] = [] new_detections: list[Detection] = [] + for detection_resp in detections: source_image = source_image_map.get(detection_resp.source_image_id) if not source_image: @@ -810,6 +857,37 @@ class PipelineSaveResults: total_time: float +def create_null_detections_for_undetected_images( + results: PipelineResultsResponse, + detection_algorithm: Algorithm, + logger: logging.Logger = logger, +) -> list[DetectionResponse]: + """ + Create null DetectionResponse objects (empty bbox) for images that have no detections. 
+ + :param results: The PipelineResultsResponse from the processing service + :param detection_algorithm: The pipeline's detection Algorithm used to attribute the null detections + + :return: List of DetectionResponse objects with null bbox + """ + source_images_with_detections = {detection.source_image_id for detection in results.detections} + null_detections_to_add = [] + detection_algorithm_reference = AlgorithmReference(name=detection_algorithm.name, key=detection_algorithm.key) + + for source_img in results.source_images: + if source_img.id not in source_images_with_detections: + null_detections_to_add.append( + DetectionResponse( + source_image_id=source_img.id, + bbox=None, + algorithm=detection_algorithm_reference, + timestamp=now(), + ) + ) + + return null_detections_to_add + + @celery_app.task(soft_time_limit=60 * 4, time_limit=60 * 5) def save_results( results: PipelineResultsResponse | None = None, @@ -857,6 +935,13 @@ def save_results( ) algorithms_known: dict[str, Algorithm] = {algo.key: algo for algo in pipeline.algorithms.all()} + try: + detection_algorithm = pipeline.algorithms.get(task_type__in=Algorithm.detection_task_types) + except Algorithm.DoesNotExist: + raise ValueError("Pipeline does not have a detection algorithm") + except Algorithm.MultipleObjectsReturned: + raise NotImplementedError("Multiple detection algorithms per pipeline are not supported") + job_logger.info(f"Algorithms registered for pipeline: \n{', '.join(algorithms_known.keys())}") if results.algorithms: @@ -866,6 +951,15 @@ "Algorithms and category maps must be registered before processing, using /info endpoint."
) + # Ensure all images have detections + # if not, add a NULL detection (empty bbox) to the results + null_detections = create_null_detections_for_undetected_images( + results=results, + detection_algorithm=detection_algorithm, + logger=job_logger, + ) + results.detections = results.detections + null_detections + detections = create_detections( detections=results.detections, algorithms_known=algorithms_known, @@ -886,6 +980,16 @@ def save_results( logger=job_logger, ) + # Check for occurrences that ended up without a determination despite having classifications. + # This can happen if the bulk_update in create_and_update_occurrences_for_detections partially fails. + from ami.main.integrity import reconcile_missing_determinations + + occurrence_ids = [d.occurrence_id for d in detections if d.occurrence_id] + if occurrence_ids: + result = reconcile_missing_determinations(occurrence_ids=occurrence_ids) + if result.fixed or result.unfixable: + job_logger.warning(f"Post-save reconciliation: {result.fixed} fixed, {result.unfixable} unfixable") + # Update precalculated counts on source images and events source_images = list(source_images) logger.info(f"Updating calculated fields for {len(source_images)} source images") @@ -949,7 +1053,7 @@ def online(self, project: Project) -> PipelineQuerySet: """ return self.filter( processing_services__projects=project, - processing_services__last_checked_live=True, + processing_services__last_seen_live=True, ).distinct() @@ -1048,7 +1152,7 @@ def collect_images( def choose_processing_service_for_pipeline( self, job_id: int | None, pipeline_name: str, project_id: int ) -> ProcessingService: - # @TODO use the cached `last_checked_latency` and a max age to avoid checking every time + # @TODO use the cached `last_seen_latency` and a max age to avoid checking every time job = None task_logger = logger @@ -1067,32 +1171,31 @@ def choose_processing_service_for_pipeline( # check the status of all processing services and pick the one with the 
lowest latency lowest_latency = float("inf") - processing_services_online = False + processing_service_lowest_latency = None for processing_service in processing_services: - if processing_service.last_checked_live: - processing_services_online = True - if ( - processing_service.last_checked_latency - and processing_service.last_checked_latency < lowest_latency - ): - lowest_latency = processing_service.last_checked_latency - # pick the processing service that has lowest latency + if processing_service.last_seen_live: + if processing_service.last_seen_latency and processing_service.last_seen_latency < lowest_latency: + lowest_latency = processing_service.last_seen_latency + processing_service_lowest_latency = processing_service + elif processing_service_lowest_latency is None: + # Online but no latency data (e.g. async/pull-mode service) — use as fallback processing_service_lowest_latency = processing_service - # if all offline then throw error - if not processing_services_online: + if processing_service_lowest_latency is None: msg = f'No processing services are online for the pipeline "{pipeline_name}".' task_logger.error(msg) - raise Exception(msg) - else: + + if lowest_latency < float("inf"): task_logger.info( f"Using processing service with latency {round(lowest_latency, 4)}: " f"{processing_service_lowest_latency}" ) + else: + task_logger.info(f"Using processing service (no latency data): {processing_service_lowest_latency}") - return processing_service_lowest_latency + return processing_service_lowest_latency def process_images( self, diff --git a/ami/ml/models/processing_service.py b/ami/ml/models/processing_service.py index ec7516d39..bad3dd147 100644 --- a/ami/ml/models/processing_service.py +++ b/ami/ml/models/processing_service.py @@ -22,8 +22,34 @@ logger = logging.getLogger(__name__) +# Max age of last_seen before a pull-mode (no-endpoint) service is considered offline. +# Pull-mode workers poll every ~5s, so 60s gives 12x buffer for transient failures. 
+PROCESSING_SERVICE_LAST_SEEN_MAX = datetime.timedelta(seconds=60) -class ProcessingServiceManager(models.Manager.from_queryset(BaseQuerySet)): + +class ProcessingServiceQuerySet(BaseQuerySet): + def async_services(self) -> "ProcessingServiceQuerySet": + """ + Filter to pull-mode (async) processing services — those with no endpoint URL. + + These correspond to jobs with dispatch_mode=ASYNC_API. Instead of Antenna calling + out to them, they poll Antenna for tasks and push results back. Their liveness is + tracked via heartbeats from mark_seen() rather than active health checks. + """ + return self.filter(models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")) + + def sync_services(self) -> "ProcessingServiceQuerySet": + """ + Filter to push-mode (sync) processing services — those with a configured endpoint URL. + + These correspond to jobs with dispatch_mode=SYNC_API. Antenna actively calls their + /readyz and /process endpoints. Their liveness is tracked by the periodic + check_processing_services_online Celery task. 
+ """ + return self.exclude(models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")) + + +class ProcessingServiceManager(models.Manager.from_queryset(ProcessingServiceQuerySet)): """Custom manager for ProcessingService to handle specific queries.""" def create(self, **kwargs) -> "ProcessingService": @@ -41,12 +67,21 @@ class ProcessingService(BaseModel): projects = models.ManyToManyField("main.Project", related_name="processing_services", blank=True) endpoint_url = models.CharField(max_length=1024, null=True, blank=True) pipelines = models.ManyToManyField("ml.Pipeline", related_name="processing_services", blank=True) - last_checked = models.DateTimeField(null=True) - last_checked_live = models.BooleanField(null=True) - last_checked_latency = models.FloatField(null=True) + last_seen = models.DateTimeField(null=True) + last_seen_live = models.BooleanField(null=True) + last_seen_latency = models.FloatField(null=True) objects = ProcessingServiceManager() + @property + def is_async(self) -> bool: + """ + True if this is a pull-mode (async) service with no endpoint URL, corresponding to + jobs with dispatch_mode=ASYNC_API. False for push-mode services with a configured + endpoint, corresponding to jobs with dispatch_mode=SYNC_API. + """ + return not self.endpoint_url + def __str__(self): endpoint_display = self.endpoint_url or "async" return f'#{self.pk} "{self.name}" ({endpoint_display})' @@ -139,10 +174,27 @@ def create_pipelines( algorithms_created=algorithms_created, ) + def mark_seen(self, live: bool = True) -> None: + """ + Record that we heard from this processing service. + Used by async/pull-mode services that don't have an endpoint to check. + """ + self.last_seen = datetime.datetime.now() + self.last_seen_live = live + self.save(update_fields=["last_seen", "last_seen_live"]) + def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: """ Check the status of the processing service. 
- This is a simple health check that pings the /readyz endpoint of the service. + + This check has two behaviors depending on the version of the processing service: + + If the service is a v2/pull-mode/async service with no endpoint URL, this will derive the status + from the last_seen heartbeat timestamp. If the last_seen timestamp is recent (within 60s), + the service is considered live. No requests are made by this method. + + If the service is a v1/push-mode/interactive service with an endpoint URL, this method will ping the + /readyz endpoint to check if it's live. Uses urllib3 Retry with exponential backoff to handle cold starts and transient failures. The timeout is set to 90s per attempt to accommodate serverless cold starts, especially for @@ -150,18 +202,28 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: connection errors are handled gracefully. Args: - timeout: Request timeout in seconds per attempt (default: 90s for serverless cold starts) + timeout: Request timeout in seconds per attempt (default: 90s for serverless cold starts). Only applies \ + to services with an endpoint URL. 
""" - # If no endpoint URL is configured, return a no-op response - if self.endpoint_url is None: + # If no endpoint URL is configured, the derive status from last registration heartbeat + if not self.endpoint_url: + is_live = bool( + self.last_seen + and self.last_seen_live + and (datetime.datetime.now() - self.last_seen) < PROCESSING_SERVICE_LAST_SEEN_MAX + ) + if not is_live and self.last_seen_live: + # Heartbeat has expired — mark stale + self.last_seen_live = False + self.save(update_fields=["last_seen_live"]) + pipeline_names = list(self.pipelines.values_list("name", flat=True)) if is_live else [] return ProcessingServiceStatusResponse( - timestamp=datetime.datetime.now(), - request_successful=False, - server_live=None, - pipelines_online=[], + timestamp=self.last_seen or datetime.datetime.now(), + request_successful=is_live, + server_live=is_live, + pipelines_online=pipeline_names, pipeline_configs=[], - endpoint_url=self.endpoint_url, - error="No endpoint URL configured - service operates in pull mode", + endpoint_url=None, latency=0.0, ) @@ -171,7 +233,7 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: pipeline_configs = [] pipelines_online = [] timestamp = datetime.datetime.now() - self.last_checked = timestamp + self.last_seen = timestamp resp = None # Create session with retry logic for connection errors and timeouts @@ -184,23 +246,23 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: try: resp = session.get(ready_check_url, timeout=timeout) resp.raise_for_status() - self.last_checked_live = True + self.last_seen_live = True except requests.exceptions.RequestException as e: error = f"Error connecting to {ready_check_url}: {e}" logger.error(error) - self.last_checked_live = False + self.last_seen_live = False finally: latency = time.time() - start_time - self.last_checked_latency = latency + self.last_seen_latency = latency self.save( update_fields=[ - "last_checked", - "last_checked_live", - 
"last_checked_latency", + "last_seen", + "last_seen_live", + "last_seen_latency", ] ) - if self.last_checked_live: + if self.last_seen_live: # The specific pipeline statuses are not required for the status response # but the intention is to show which ones are loaded into memory and ready to use. # @TODO: this may be overkill, but it is displayed in the UI now. @@ -214,7 +276,7 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse: response = ProcessingServiceStatusResponse( timestamp=timestamp, request_successful=resp.ok if resp else False, - server_live=self.last_checked_live, + server_live=self.last_seen_live, pipelines_online=pipelines_online, pipeline_configs=pipeline_configs, endpoint_url=self.endpoint_url, @@ -229,7 +291,7 @@ def get_pipeline_configs(self, timeout=6): Get the pipeline configurations from the processing service. This can be a long response as it includes the full category map for each algorithm. """ - if self.endpoint_url is None: + if not self.endpoint_url: return [] info_url = urljoin(self.endpoint_url, "info") diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 884676637..b6e9af254 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -43,6 +43,7 @@ async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]: TASK_TTR = getattr(settings, "NATS_TASK_TTR", 30) # Visibility timeout in seconds (configurable) +ADVISORY_STREAM_NAME = "advisories" # Shared stream for max delivery advisories across all jobs class TaskQueueManager: @@ -72,6 +73,15 @@ def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = No async def __aenter__(self): """Create connection on enter.""" self.nc, self.js = await get_connection(self.nats_url) + + try: + await self._setup_advisory_stream() + except BaseException: + if self.nc and not self.nc.is_closed: + await self.nc.close() + self.nc = None + self.js = None + raise return self async 
def __aexit__(self, exc_type, exc_val, exc_tb): @@ -95,7 +105,7 @@ def _get_consumer_name(self, job_id: int) -> str: """Get consumer name from job_id.""" return f"job-{job_id}-consumer" - async def _stream_exists(self, job_id: int) -> bool: + async def _job_stream_exists(self, job_id: int) -> bool: """Check if stream exists for the given job. Only catches NotFoundError (→ False). TimeoutError propagates deliberately @@ -106,6 +116,10 @@ async def _stream_exists(self, job_id: int) -> bool: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") stream_name = self._get_stream_name(job_id) + return await self._stream_exists(stream_name) + + async def _stream_exists(self, stream_name: str) -> bool: + """Check if a stream with the given name exists.""" try: await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT) return True @@ -117,7 +131,7 @@ async def _ensure_stream(self, job_id: int): if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") - if not await self._stream_exists(job_id): + if not await self._job_stream_exists(job_id): stream_name = self._get_stream_name(job_id) subject = self._get_subject(job_id) logger.warning(f"Stream {stream_name} does not exist") @@ -218,7 +232,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") try: - if not await self._stream_exists(job_id): + if not await self._job_stream_exists(job_id): logger.debug(f"Stream for job '{job_id}' does not exist when reserving task") return [] @@ -231,7 +245,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li try: msgs = await psub.fetch(count, timeout=timeout) - except nats.errors.TimeoutError: + except (asyncio.TimeoutError, nats.errors.TimeoutError): logger.debug(f"No tasks available in stream for job '{job_id}'") return [] finally: @@ -250,7 +264,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li logger.debug(f"No tasks reserved from stream for job '{job_id}'") return tasks - except asyncio.TimeoutError: + except (asyncio.TimeoutError, nats.errors.TimeoutError): raise # NATS unreachable — propagate so the view can return an appropriate error except Exception as e: logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}") @@ -271,6 +285,7 @@ async def acknowledge_task(self, reply_subject: str) -> bool: try: await self.nc.publish(reply_subject, b"+ACK") + await self.nc.flush() logger.debug(f"Acknowledged task with reply subject {reply_subject}") return True except Exception as e: @@ -330,9 +345,134 @@ async def delete_stream(self, job_id: int) -> bool: logger.error(f"Failed to delete stream for job '{job_id}': {e}") return False + async def _setup_advisory_stream(self): + """Ensure the shared advisory stream exists to capture max-delivery events. + + Called on every __aenter__ so that advisories are captured from the moment + any TaskQueueManager connection is opened, not just when the DLQ is first read. 
+ """ + if not await self._stream_exists(ADVISORY_STREAM_NAME): + await asyncio.wait_for( + self.js.add_stream( + name=ADVISORY_STREAM_NAME, + subjects=["$JS.EVENT.ADVISORY.>"], + max_age=3600, # Keep advisories for 1 hour + ), + timeout=NATS_JETSTREAM_TIMEOUT, + ) + logger.info("Advisory stream created") + + def _get_dlq_consumer_name(self, job_id: int) -> str: + """Get the durable consumer name for dead letter queue advisory tracking.""" + return f"job-{job_id}-dlq" + + async def get_dead_letter_image_ids(self, job_id: int, n: int = 10) -> list[str]: + """ + Get image IDs from dead letter queue (messages that exceeded max delivery attempts). + + Pulls from persistent advisory stream to find failed messages, then looks up image IDs. + Uses a durable consumer so acknowledged advisories are not re-delivered on subsequent calls. + + Args: + job_id: The job ID (integer primary key) + n: Maximum number of image IDs to return (default: 10) + + Returns: + List of image IDs that failed to process after max retry attempts + """ + if self.nc is None or self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + dlq_consumer_name = self._get_dlq_consumer_name(job_id) + dead_letter_ids = [] + + subject_filter = f"$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.{stream_name}.{consumer_name}" + + # Use a durable consumer so ACKs persist across calls — ephemeral consumers + # are deleted on unsubscribe, discarding all ACK tracking and causing every + # advisory to be re-delivered on the next call. 
+ psub = await self.js.pull_subscribe(subject_filter, durable=dlq_consumer_name, stream=ADVISORY_STREAM_NAME) + + try: + msgs = await psub.fetch(n, timeout=1.0) + + for msg in msgs: + advisory_data = json.loads(msg.data.decode()) + + # Get the stream sequence of the failed message + if "stream_seq" in advisory_data: + stream_seq = advisory_data["stream_seq"] + + # Look up the actual message by sequence to get task ID + try: + job_msg = await self.js.get_msg(stream_name, stream_seq) + + if job_msg and job_msg.data: + task_data = json.loads(job_msg.data.decode()) + + if "image_id" in task_data: + dead_letter_ids.append(str(task_data["image_id"])) + else: + logger.warning(f"No image_id found in task data: {task_data}") + except Exception as e: + logger.warning(f"Could not retrieve message {stream_seq} from {stream_name}: {e}") + # The message might have been discarded after max_deliver exceeded + else: + logger.warning(f"No stream_seq in advisory data: {advisory_data}") + + # Acknowledge even if we couldn't find the stream_seq or image_id so it doesn't get re-delivered + # it shouldn't happen since stream_seq is part of the `io.nats.jetstream.advisory.v1.max_deliver` + # schema and all our messages have an image_id + await msg.ack() + logger.info( + f"Acknowledged advisory message for stream_seq {advisory_data.get('stream_seq', 'unknown')}" + ) + + # Flush to ensure all ACKs are written to the socket before unsubscribing. + # msg.ack() only queues a publish in the client buffer; without flush() the + # ACKs can be silently dropped when the subscription is torn down. + await self.nc.flush() + + except (asyncio.TimeoutError, nats.errors.TimeoutError): + logger.info(f"No advisory messages found for job {job_id}") + finally: + await psub.unsubscribe() + + return dead_letter_ids[:n] + + async def delete_dlq_consumer(self, job_id: int) -> bool: + """ + Delete the durable DLQ advisory consumer for a job. 
+ + Args: + job_id: The job ID (integer primary key) + + Returns: + bool: True if successful, False otherwise + """ + if self.js is None: + raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") + + dlq_consumer_name = self._get_dlq_consumer_name(job_id) + try: + await asyncio.wait_for( + self.js.delete_consumer(ADVISORY_STREAM_NAME, dlq_consumer_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) + logger.info(f"Deleted DLQ consumer {dlq_consumer_name} for job '{job_id}'") + return True + except nats.js.errors.NotFoundError: + logger.debug(f"DLQ consumer {dlq_consumer_name} for job '{job_id}' not found when attempting to delete") + return True # Consider it a success if the consumer is already gone + except Exception as e: + logger.warning(f"Failed to delete DLQ consumer for job '{job_id}': {e}") + return False + async def cleanup_job_resources(self, job_id: int) -> bool: """ - Clean up all NATS resources (consumer and stream) for a job. + Clean up all NATS resources (consumer, stream, and DLQ advisory consumer) for a job. This should be called when a job completes or is cancelled. 
@@ -342,8 +482,9 @@ async def cleanup_job_resources(self, job_id: int) -> bool: Returns: bool: True if successful, False otherwise """ - # Delete consumer first, then stream + # Delete consumer first, then stream, then the durable DLQ advisory consumer consumer_deleted = await self.delete_consumer(job_id) stream_deleted = await self.delete_stream(job_id) + dlq_consumer_deleted = await self.delete_dlq_consumer(job_id) - return consumer_deleted and stream_deleted + return consumer_deleted and stream_deleted and dlq_consumer_deleted diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index cf3514bce..da47f3429 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -1,11 +1,13 @@ """Unit tests for TaskQueueManager.""" +import json import unittest from unittest.mock import AsyncMock, MagicMock, patch import nats +import nats.errors -from ami.ml.orchestration.nats_queue import TaskQueueManager +from ami.ml.orchestration.nats_queue import ADVISORY_STREAM_NAME, TaskQueueManager from ami.ml.schemas import PipelineProcessingTask @@ -25,6 +27,7 @@ def _create_mock_nats_connection(self): nc = MagicMock() nc.is_closed = False nc.close = AsyncMock() + nc.flush = AsyncMock() js = MagicMock() js.stream_info = AsyncMock() @@ -60,7 +63,8 @@ async def test_publish_task_creates_stream_and_consumer(self): async with TaskQueueManager() as manager: await manager.publish_task(456, sample_task) - js.add_stream.assert_called_once() + # add_stream called twice: advisory stream in __aenter__ + job stream in _ensure_stream + self.assertEqual(js.add_stream.call_count, 2) self.assertIn("job_456", str(js.add_stream.call_args)) js.add_consumer.assert_called_once() @@ -153,7 +157,8 @@ async def test_cleanup_job_resources(self): result = await manager.cleanup_job_resources(123) self.assertTrue(result) - js.delete_consumer.assert_called_once() + # delete_consumer called twice: job 
consumer + DLQ advisory consumer + self.assertEqual(js.delete_consumer.call_count, 2) js.delete_stream.assert_called_once() async def test_naming_conventions(self): @@ -177,3 +182,56 @@ async def test_operations_without_connection_raise_error(self): with self.assertRaisesRegex(RuntimeError, "Connection is not open"): await manager.delete_stream(123) + + async def test_get_dead_letter_image_ids_returns_image_ids(self): + """Test that advisory messages are resolved to image IDs correctly.""" + nc, js = self._create_mock_nats_connection() + js.get_msg = AsyncMock() + + def make_advisory(seq): + m = MagicMock() + m.data = json.dumps({"stream_seq": seq}).encode() + m.ack = AsyncMock() + return m + + def make_job_msg(image_id): + m = MagicMock() + m.data = json.dumps({"image_id": image_id}).encode() + return m + + advisories = [make_advisory(1), make_advisory(2)] + js.get_msg.side_effect = [make_job_msg("img-1"), make_job_msg("img-2")] + + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=advisories) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + result = await manager.get_dead_letter_image_ids(123, n=10) + + self.assertEqual(result, ["img-1", "img-2"]) + js.pull_subscribe.assert_called_once_with( + "$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.job_123.job-123-consumer", + durable="job-123-dlq", + stream=ADVISORY_STREAM_NAME, + ) + mock_psub.fetch.assert_called_once_with(10, timeout=1.0) + mock_psub.unsubscribe.assert_called_once() + + async def test_get_dead_letter_image_ids_no_messages(self): + """Test that a fetch timeout returns an empty list and still unsubscribes.""" + nc, js = self._create_mock_nats_connection() + + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError) + mock_psub.unsubscribe = AsyncMock() + 
js.pull_subscribe = AsyncMock(return_value=mock_psub) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + result = await manager.get_dead_letter_image_ids(123) + + self.assertEqual(result, []) + mock_psub.unsubscribe.assert_called_once() diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index f63e6e1a1..7449c59e6 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -163,14 +163,14 @@ class Config: class DetectionRequest(pydantic.BaseModel): source_image: SourceImageRequest # the 'original' image - bbox: BoundingBox + bbox: BoundingBox | None = None crop_image_url: str | None = None algorithm: AlgorithmReference class DetectionResponse(pydantic.BaseModel): source_image_id: str - bbox: BoundingBox + bbox: BoundingBox | None = None inference_time: float | None = None algorithm: AlgorithmReference timestamp: datetime.datetime diff --git a/ami/ml/serializers.py b/ami/ml/serializers.py index 6c5782c8f..1711e19e1 100644 --- a/ami/ml/serializers.py +++ b/ami/ml/serializers.py @@ -2,6 +2,7 @@ from rest_framework import serializers from ami.main.api.serializers import DefaultSerializer, MinimalNestedModelSerializer +from ami.main.models import Project from .models.algorithm import Algorithm, AlgorithmCategoryMap from .models.pipeline import Pipeline, PipelineStage @@ -66,6 +67,8 @@ class Meta: class ProcessingServiceNestedSerializer(DefaultSerializer): + is_async = serializers.BooleanField(read_only=True) + class Meta: model = ProcessingService fields = [ @@ -73,8 +76,9 @@ class Meta: "id", "details", "endpoint_url", - "last_checked", - "last_checked_live", + "is_async", + "last_seen", + "last_seen_live", "created_at", "updated_at", ] @@ -134,6 +138,12 @@ class Meta: class ProcessingServiceSerializer(DefaultSerializer): pipelines = PipelineNestedSerializer(many=True, read_only=True) projects = serializers.SerializerMethodField() + is_async = 
serializers.BooleanField(read_only=True) + project = serializers.PrimaryKeyRelatedField( + write_only=True, + queryset=Project.objects.all(), + required=False, + ) class Meta: model = ProcessingService @@ -144,11 +154,13 @@ class Meta: "description", "projects", "endpoint_url", + "is_async", "pipelines", "created_at", "updated_at", - "last_checked", - "last_checked_live", + "last_seen", + "last_seen_live", + "project", ] def get_projects(self, obj): diff --git a/ami/ml/tasks.py b/ami/ml/tasks.py index 68e9603bd..3bfa458de 100644 --- a/ami/ml/tasks.py +++ b/ami/ml/tasks.py @@ -95,25 +95,49 @@ def remove_duplicate_classifications(project_id: int | None = None, dry_run: boo return num_deleted -@celery_app.task(soft_time_limit=10, time_limit=20) +# Timeout per sync service in the periodic beat task. Shorter than the default (90s for +# cold-start waits) since a missed check just waits for the next beat cycle. +# Worst case: 4 attempts (initial + 3 retries) × 8s timeout + backoff (0 + 2 + 4) = 38s per service. +_BEAT_STATUS_TIMEOUT = 8 + +# Discard queued copies that built up while the worker was unavailable — the next +# beat firing will pick things up fresh. Beat schedule is every 5 minutes. +_BEAT_TASK_EXPIRES = 240 + + +@celery_app.task(soft_time_limit=120, time_limit=150, expires=_BEAT_TASK_EXPIRES) def check_processing_services_online(): """ - Check the status of all v1 synchronous processing services and update the last_seen field. - We will update last_seen for asynchronous services when we receive a request from them. - - @TODO make this async to check all services in parallel + Check the status of all processing services and update last_seen/last_seen_live fields. + + - Async services (no endpoint URL): heartbeat is updated by mark_seen() on registration + and by _mark_pipeline_pull_services_seen() on task polling. This task marks them offline + if last_seen has exceeded PROCESSING_SERVICE_LAST_SEEN_MAX. 
Runs first so it always + executes even if a slow sync check hits the time limit. + - Sync services (endpoint URL set): checked sequentially with a short per-request timeout. + Safe to skip a cycle — the next beat firing will catch up. """ - from ami.ml.models import ProcessingService + import datetime - logger.info("Checking which synchronous processing services are online.") + from ami.ml.models.processing_service import PROCESSING_SERVICE_LAST_SEEN_MAX, ProcessingService - services = ProcessingService.objects.exclude(endpoint_url__isnull=True).exclude(endpoint_url__exact="").all() + logger.info("Checking which processing services are online.") - for service in services: - logger.info(f"Checking service {service}") + # Async services first — fast DB-only operation, must not be blocked by sync checks + stale_cutoff = datetime.datetime.now() - PROCESSING_SERVICE_LAST_SEEN_MAX + stale = ProcessingService.objects.async_services().filter(last_seen_live=True, last_seen__lt=stale_cutoff) + count = stale.count() + if count: + logger.info( + f"Marking {count} async service(s) offline (no heartbeat within {PROCESSING_SERVICE_LAST_SEEN_MAX})." 
+ ) + stale.update(last_seen_live=False) + + for service in ProcessingService.objects.sync_services(): + logger.info(f"Checking push-mode service {service}") try: - status_response = service.get_status() + status_response = service.get_status(timeout=_BEAT_STATUS_TIMEOUT) logger.debug(status_response) - except Exception as e: - logger.error(f"Error checking service {service}: {e}") + except Exception: + logger.exception("Error checking service %s", service) continue diff --git a/ami/ml/tests.py b/ami/ml/tests.py index bbc648d6e..36ba5b5f7 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -136,9 +136,9 @@ def test_create_processing_service_without_endpoint_url(self): # Check that endpoint_url is null self.assertIsNone(data["instance"]["endpoint_url"]) - # Check that status indicates no endpoint configured + # Check that status indicates service is not yet live (no heartbeat received) self.assertFalse(data["status"]["request_successful"]) - self.assertIn("No endpoint URL configured", data["status"]["error"]) + self.assertFalse(data["status"]["server_live"]) self.assertIsNone(data["status"]["endpoint_url"]) def test_get_status_with_null_endpoint_url(self): @@ -149,10 +149,8 @@ def test_get_status_with_null_endpoint_url(self): status = service.get_status() self.assertFalse(status.request_successful) - self.assertIsNone(status.server_live) + self.assertFalse(status.server_live) # No heartbeat received yet = not live self.assertIsNone(status.endpoint_url) - self.assertIsNotNone(status.error) - self.assertIn("No endpoint URL configured", (status.error or "")) self.assertEqual(status.pipelines_online, []) def test_get_pipeline_configs_with_null_endpoint_url(self): @@ -164,6 +162,118 @@ def test_get_pipeline_configs_with_null_endpoint_url(self): self.assertEqual(configs, []) +class TestProcessingServiceLastSeen(TestCase): + """Test the last_seen, last_seen_live, and last_seen_latency fields.""" + + def setUp(self): + self.project = Project.objects.create(name="Last Seen 
Test Project") + + def test_mark_seen_sets_fields(self): + """Test that mark_seen() sets last_seen and last_seen_live.""" + service = ProcessingService.objects.create(name="Async Worker", endpoint_url=None) + service.projects.add(self.project) + + self.assertIsNone(service.last_seen) + self.assertIsNone(service.last_seen_live) + + service.mark_seen(live=True) + service.refresh_from_db() + + self.assertIsNotNone(service.last_seen) + self.assertTrue(service.last_seen_live) + + def test_mark_seen_offline(self): + """Test that mark_seen(live=False) sets last_seen_live to False.""" + service = ProcessingService.objects.create(name="Async Worker Offline", endpoint_url=None) + + service.mark_seen(live=False) + service.refresh_from_db() + + self.assertIsNotNone(service.last_seen) + self.assertFalse(service.last_seen_live) + + def test_get_status_updates_last_seen_for_sync_service(self): + """Test that get_status() updates last_seen fields for sync services (even if endpoint is unreachable).""" + service = ProcessingService.objects.create(name="Sync Service", endpoint_url="http://nonexistent-host:9999") + service.projects.add(self.project) + + # get_status should update the fields even for unreachable endpoints + service.get_status(timeout=1) + service.refresh_from_db() + + self.assertIsNotNone(service.last_seen) + self.assertFalse(service.last_seen_live) # unreachable = not live + self.assertIsNotNone(service.last_seen_latency) + + def test_model_has_last_seen_fields(self): + """Test that ProcessingService model has last_seen fields and not last_checked.""" + service = ProcessingService.objects.create(name="Field Test Service", endpoint_url=None) + service.mark_seen(live=True) + service.refresh_from_db() + + # Verify new fields exist + self.assertTrue(hasattr(service, "last_seen")) + self.assertTrue(hasattr(service, "last_seen_live")) + self.assertTrue(hasattr(service, "last_seen_latency")) + + # Verify old fields don't exist + self.assertFalse(hasattr(service, 
"last_checked")) + self.assertFalse(hasattr(service, "last_checked_live")) + self.assertFalse(hasattr(service, "last_checked_latency")) + + +class TestProjectPipelineRegistrationUpdatesLastSeen(APITestCase): + """Test that async pipeline registration updates last_seen on the processing service.""" + + def setUp(self): + from ami.users.roles import ProjectManager, create_roles_for_project + + self.user = User.objects.create_user(email="lastseen@example.com") # type: ignore + self.project = Project.objects.create(name="Last Seen Project", owner=self.user, create_defaults=False) + create_roles_for_project(self.project) + ProjectManager.assign_user(self.user, self.project) + + def test_pipeline_registration_marks_service_as_seen(self): + """Test that POSTing to the pipeline registration endpoint marks the service as last_seen_live.""" + url = f"/api/v2/projects/{self.project.pk}/pipelines/" + payload = { + "processing_service_name": "AsyncTestWorker", + "pipelines": [], + } + + self.client.force_authenticate(user=self.user) + response = self.client.post(url, payload, format="json") + self.assertEqual(response.status_code, 201) + + service = ProcessingService.objects.get(name="AsyncTestWorker") + self.assertIsNotNone(service.last_seen) + self.assertTrue(service.last_seen_live) + + def test_repeated_registration_updates_last_seen(self): + """Test that re-registering updates the last_seen timestamp.""" + url = f"/api/v2/projects/{self.project.pk}/pipelines/" + payload = { + "processing_service_name": "AsyncTestWorkerRepeat", + "pipelines": [], + } + + self.client.force_authenticate(user=self.user) + + # First registration + self.client.post(url, payload, format="json") + service = ProcessingService.objects.get(name="AsyncTestWorkerRepeat") + first_seen = service.last_seen + + # Second registration + self.client.post(url, payload, format="json") + service.refresh_from_db() + second_seen = service.last_seen + + self.assertIsNotNone(first_seen) + 
self.assertIsNotNone(second_seen) + self.assertGreaterEqual(second_seen, first_seen) + + class TestPipelineWithProcessingService(TestCase): def test_run_pipeline_with_errors_from_processing_service(self): """ @@ -735,6 +845,181 @@ def test_project_pipeline_config(self): final_config = self.pipeline.get_config(self.project.pk) self.assertEqual(final_config["test_param"], "project_value") + def test_image_with_null_detection(self): + """ + Test saving results for a pipeline that returns null detections for some images. + """ + image = self.test_images[0] + results = self.fake_pipeline_results([image], self.pipeline) + + # Manually change the results for a single image to a list of empty detections + results.detections = [] + + save_results(results) + + image.save() + self.assertEqual(image.get_detections_count(), 0) # detections_count should exclude null detections + total_num_detections = image.detections.distinct().count() + self.assertEqual(total_num_detections, 1) + + was_processed = image.get_was_processed() + self.assertEqual(was_processed, True) + + # Also test filtering by algorithm + was_processed = image.get_was_processed(algorithm_key="random-detector") + self.assertEqual(was_processed, True) + + def test_filter_processed_images_skips_null_only_image(self): + """ + An image with only null detections (processed, nothing found) should be + skipped by filter_processed_images — it doesn't need reprocessing. 
+ """ + from ami.ml.models.pipeline import filter_processed_images + + image = self.test_images[0] + detector = self.algorithms["random-detector"] + + # Simulate a previous run that found nothing: create a null detection + Detection.objects.create( + source_image=image, + detection_algorithm=detector, + bbox=None, + ) + + result = list(filter_processed_images([image], self.pipeline)) + self.assertEqual(result, [], "Image with only null detections should be skipped") + + def test_filter_processed_images_yields_image_with_null_and_real_unclassified(self): + """ + An image with BOTH a null detection AND a real detection lacking classifications + should NOT be skipped — the real detection still needs to be classified. + """ + from ami.ml.models.pipeline import filter_processed_images + + image = self.test_images[0] + detector = self.algorithms["random-detector"] + + # Null detection from a prior empty run + Detection.objects.create( + source_image=image, + detection_algorithm=detector, + bbox=None, + ) + # Real detection with no classification yet + Detection.objects.create( + source_image=image, + detection_algorithm=detector, + bbox=[0.1, 0.2, 0.3, 0.4], + ) + + result = list(filter_processed_images([image], self.pipeline)) + self.assertEqual(result, [image], "Image with real unclassified detections should be yielded") + + def test_filter_processed_images_skips_null_and_fully_classified(self): + """ + An image with a null detection AND a real detection that is fully classified + by all pipeline algorithms should be skipped — it's fully processed. 
+ """ + from ami.ml.models.pipeline import filter_processed_images + + image = self.test_images[0] + detector = self.algorithms["random-detector"] + binary_classifier = self.algorithms["random-binary-classifier"] + species_classifier = self.algorithms["random-species-classifier"] + + # Null detection from a prior empty run + Detection.objects.create( + source_image=image, + detection_algorithm=detector, + bbox=None, + ) + # Real detection with classifications from all pipeline algorithms + real_det = Detection.objects.create( + source_image=image, + detection_algorithm=detector, + bbox=[0.1, 0.2, 0.3, 0.4], + ) + taxon = Taxon.objects.create(name="Test Species Filtered") + Classification.objects.create( + detection=real_det, + taxon=taxon, + algorithm=binary_classifier, + score=0.9, + timestamp=datetime.datetime.now(), + ) + Classification.objects.create( + detection=real_det, + taxon=taxon, + algorithm=species_classifier, + score=0.8, + timestamp=datetime.datetime.now(), + ) + + result = list(filter_processed_images([image], self.pipeline)) + self.assertEqual(result, [], "Fully classified image with null detection should be skipped") + + def test_null_detections_are_algorithm_specific(self): + """ + Null detections from different pipelines/algorithms should not be shared. + Each algorithm's null detection is tracked separately so that + get_was_processed(algorithm_key=...) returns the correct per-algorithm status. 
+ """ + from ami.ml.models.pipeline import save_results + + image = self.test_images[0] + + # Pipeline 1 processes image, finds nothing + results_1 = self.fake_pipeline_results([image], self.pipeline) + results_1.detections = [] + save_results(results_1) + + # Create a second pipeline with a DIFFERENT detector algorithm + detector_2, _ = Algorithm.objects.get_or_create( + key="constant-detector", + defaults={"name": "Constant Detector", "task_type": "detection"}, + ) + pipeline_2 = Pipeline.objects.create(name="Test Pipeline 2 Null Detect") + pipeline_2.algorithms.set([detector_2]) + + # Pipeline 2 processes the same image, also finds nothing + results_2 = self.fake_pipeline_results([image], pipeline_2) + results_2.detections = [] + save_results(results_2) + + # Both algorithms should independently mark the image as processed + detector_1_key = self.algorithms["random-detector"].key + self.assertTrue(image.get_was_processed(algorithm_key=detector_1_key)) + self.assertTrue( + image.get_was_processed(algorithm_key="constant-detector"), + "Pipeline 2's null detection should be created separately", + ) + + # Each pipeline must have its own null detection in the DB + null_detections = image.detections.filter(bbox__isnull=True) + self.assertEqual(null_detections.count(), 2, "Each pipeline should have its own null detection") + + def test_null_detection_deduplication_same_pipeline(self): + """ + Running the same pipeline twice on the same image should not create + duplicate null detections — the second run reuses the existing one. 
+ """ + from ami.ml.models.pipeline import save_results + + image = self.test_images[0] + + # Run pipeline twice, both with no detections + results_1 = self.fake_pipeline_results([image], self.pipeline) + results_1.detections = [] + save_results(results_1) + + results_2 = self.fake_pipeline_results([image], self.pipeline) + results_2.detections = [] + save_results(results_2) + + # Should still be exactly one null detection + null_detections = image.detections.filter(bbox__isnull=True) + self.assertEqual(null_detections.count(), 1, "Same pipeline should not create duplicate null detections") + class TestAlgorithmCategoryMaps(TestCase): def setUp(self): diff --git a/ami/ml/views.py b/ami/ml/views.py index b3272f567..58832a10b 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -277,4 +277,7 @@ def create(self, request, *args, **kwargs): projects=Project.objects.filter(pk=project.pk), ) + # Record that we heard from this async processing service + processing_service.mark_seen(live=True) + return Response(response.dict(), status=status.HTTP_201_CREATED) diff --git a/ui/src/data-services/models/capture-set.ts b/ui/src/data-services/models/capture-set.ts index f56c8af2e..3605a0f9b 100644 --- a/ui/src/data-services/models/capture-set.ts +++ b/ui/src/data-services/models/capture-set.ts @@ -75,6 +75,10 @@ export class CaptureSet extends Entity { return this._data.source_images_with_detections_count } + get numImagesProcessed(): number | undefined { + return this._data.source_images_processed_count + } + get numImagesWithDetectionsLabel(): string { const pct = this.numImagesWithDetections && this.numImages @@ -86,6 +90,16 @@ export class CaptureSet extends Entity { )}%)` } + get numImagesProcessedLabel(): string { + const numProcessed = this.numImagesProcessed ?? 0 + const pct = + this.numImages && this.numImages > 0 + ? 
(numProcessed / this.numImages) * 100 + : 0 + + return `${numProcessed.toLocaleString()} (${pct.toFixed(0)}%)` + } + get numJobs(): number | undefined { return this._data.jobs?.length } diff --git a/ui/src/data-services/models/occurrence-details.ts b/ui/src/data-services/models/occurrence-details.ts index 90654984f..6d4d6bf63 100644 --- a/ui/src/data-services/models/occurrence-details.ts +++ b/ui/src/data-services/models/occurrence-details.ts @@ -56,7 +56,7 @@ export class OccurrenceDetails extends Occurrence { .map((i: any) => { const taxon = new Taxon(i.taxon) const overridden = i.withdrawn - const applied = taxon.id === this.determinationTaxon.id + const applied = !!this.determinationTaxon && taxon.id === this.determinationTaxon.id const identification: HumanIdentification = { id: `${i.id}`, @@ -82,8 +82,8 @@ export class OccurrenceDetails extends Occurrence { .sort(sortByDate) .map((p: any) => { const taxon = new Taxon(p.taxon) - const overridden = taxon.id !== this.determinationTaxon.id - const applied = taxon.id === this.determinationTaxon.id + const overridden = !this.determinationTaxon || taxon.id !== this.determinationTaxon.id + const applied = !!this.determinationTaxon && taxon.id === this.determinationTaxon.id const prediction: MachinePrediction = { id: `${p.id}`, diff --git a/ui/src/data-services/models/occurrence.ts b/ui/src/data-services/models/occurrence.ts index 8482c5bc9..40b31c814 100644 --- a/ui/src/data-services/models/occurrence.ts +++ b/ui/src/data-services/models/occurrence.ts @@ -8,13 +8,15 @@ export type ServerOccurrence = any // TODO: Update this type export class Occurrence { protected readonly _occurrence: ServerOccurrence - private readonly _determinationTaxon: Taxon + private readonly _determinationTaxon: Taxon | undefined private readonly _images: { src: string }[] = [] public constructor(occurrence: ServerOccurrence) { this._occurrence = occurrence - this._determinationTaxon = new Taxon(occurrence.determination_details.taxon) + 
this._determinationTaxon = occurrence.determination_details?.taxon + ? new Taxon(occurrence.determination_details.taxon) + : undefined this._images = occurrence.detection_images .filter((src: string) => !!src.length) @@ -49,8 +51,10 @@ export class Occurrence { return this._occurrence.deployment?.name } - get determinationId(): string { - return `${this._occurrence.determination.id}` + get determinationId(): string | undefined { + return this._occurrence.determination + ? `${this._occurrence.determination.id}` + : undefined } get determinationIdentificationId(): string | undefined { @@ -70,7 +74,7 @@ export class Occurrence { } get determinationScore(): number | undefined { - const score = this._occurrence.determination_details.score + const score = this._occurrence.determination_details?.score if (score || score === 0) { return score @@ -89,17 +93,17 @@ export class Occurrence { return undefined } - get determinationTaxon(): Taxon { + get determinationTaxon(): Taxon | undefined { return this._determinationTaxon } get determinationVerified(): boolean { - return !!this._occurrence.determination_details.identification + return !!this._occurrence.determination_details?.identification } get determinationVerifiedBy() { const verifiedBy = - this._occurrence.determination_details.identification?.user + this._occurrence.determination_details?.identification?.user return verifiedBy ? { @@ -116,7 +120,8 @@ export class Occurrence { } get displayName(): string { - return `${this.determinationTaxon.name} #${this.id}` + const name = this.determinationTaxon?.name ?? 
'Unknown' + return `${name} #${this.id}` } get firstAppearanceTimestamp(): string { @@ -180,6 +185,10 @@ export class Occurrence { return false } + if (!this.determinationTaxon) { + return false + } + return ( identificationTaxonId === this.determinationTaxon.id && identificationUserId === userId diff --git a/ui/src/data-services/models/pipeline.ts b/ui/src/data-services/models/pipeline.ts index 86fc0c755..1a43de78b 100644 --- a/ui/src/data-services/models/pipeline.ts +++ b/ui/src/data-services/models/pipeline.ts @@ -103,7 +103,7 @@ export class Pipeline { (service: any) => new ProcessingService(service) ) for (const processingService of processingServices) { - if (processingService.lastCheckedLive) { + if (processingService.lastSeenLive || processingService.isAsync) { return { online: true, service: processingService } } } @@ -115,7 +115,7 @@ export class Pipeline { const processingServices = this._pipeline.processing_services let total_online = 0 for (const processingService of processingServices) { - if (processingService.last_checked_live) { + if (processingService.last_seen_live) { total_online += 1 } } @@ -123,22 +123,23 @@ export class Pipeline { return total_online + '/' + processingServices.length } - get processingServicesOnlineLastChecked(): string | undefined { + get processingServicesOnlineLastSeen(): string | undefined { const processingServices = this._pipeline.processing_services if (!processingServices.length) { return undefined } - const last_checked_times = [] - for (const processingService of processingServices) { - last_checked_times.push( - new Date(processingService.last_checked).getTime() - ) + const last_seen_times = processingServices + .filter((s: any) => s.last_seen != null) + .map((s: any) => new Date(s.last_seen).getTime()) + + if (!last_seen_times.length) { + return undefined } return getFormatedDateTimeString({ - date: new Date(Math.max(...last_checked_times)), + date: new Date(Math.max(...last_seen_times)), }) } diff --git 
a/ui/src/data-services/models/processing-service.ts b/ui/src/data-services/models/processing-service.ts index 4f92f9116..836baba6c 100644 --- a/ui/src/data-services/models/processing-service.ts +++ b/ui/src/data-services/models/processing-service.ts @@ -7,6 +7,7 @@ export type ServerProcessingService = any // TODO: Update this type export const SERVER_PROCESSING_SERVICE_STATUS_CODES = [ 'OFFLINE', 'ONLINE', + 'UNKNOWN', ] as const export type ServerProcessingServiceStatusCode = @@ -15,6 +16,7 @@ export type ServerProcessingServiceStatusCode = export enum ProcessingServiceStatusType { Success, Error, + Unknown, } export class ProcessingService extends Entity { @@ -50,8 +52,13 @@ export class ProcessingService extends Entity { return `${this._processingService.name}` } - get endpointUrl(): string { - return `${this._processingService.endpoint_url}` + get endpointUrl(): string | undefined { + const url = this._processingService.endpoint_url + return url && url.trim().length > 0 ? url : undefined + } + + get isAsync(): boolean { + return this._processingService.is_async ?? false } get description(): string { @@ -68,18 +75,18 @@ export class ProcessingService extends Entity { }) } - get lastChecked(): string | undefined { - if (!this._processingService.last_checked) { + get lastSeen(): string | undefined { + if (!this._processingService.last_seen) { return undefined } return getFormatedDateTimeString({ - date: new Date(this._processingService.last_checked), + date: new Date(this._processingService.last_seen), }) } - get lastCheckedLive(): boolean { - return this._processingService.last_checked_live + get lastSeenLive(): boolean { + return this._processingService.last_seen_live ?? false } get numPiplinesAdded(): number { @@ -92,7 +99,10 @@ export class ProcessingService extends Entity { type: ProcessingServiceStatusType color: string } { - const status_code = this.lastCheckedLive ? 
'ONLINE' : 'OFFLINE' + if (this.isAsync) { + return ProcessingService.getStatusInfo('UNKNOWN') + } + const status_code = this.lastSeenLive ? 'ONLINE' : 'OFFLINE' return ProcessingService.getStatusInfo(status_code) } @@ -103,11 +113,13 @@ export class ProcessingService extends Entity { const type = { OFFLINE: ProcessingServiceStatusType.Error, ONLINE: ProcessingServiceStatusType.Success, + UNKNOWN: ProcessingServiceStatusType.Unknown, }[code] const color = { [ProcessingServiceStatusType.Error]: '#ef4444', // color-destructive-500, [ProcessingServiceStatusType.Success]: '#09af8a', // color-success-500 + [ProcessingServiceStatusType.Unknown]: '#9ca3af', // gray-400 }[type] return { diff --git a/ui/src/pages/occurrence-details/identification-card/machine-prediction.tsx b/ui/src/pages/occurrence-details/identification-card/machine-prediction.tsx index 11a1d824c..879d745a3 100644 --- a/ui/src/pages/occurrence-details/identification-card/machine-prediction.tsx +++ b/ui/src/pages/occurrence-details/identification-card/machine-prediction.tsx @@ -111,7 +111,7 @@ export const MachinePrediction = ({ isLoading={isLoading} /> {topN?.map(({ score, taxon }) => { - const applied = taxon.id === occurrence.determinationTaxon.id + const applied = !!occurrence.determinationTaxon && taxon.id === occurrence.determinationTaxon.id return (
- - navigate( - getAppRoute({ - to: APP_ROUTES.TAXON_DETAILS({ - projectId: projectId as string, - taxonId: id, - }), - }) - ) - } - size="lg" - taxon={occurrence.determinationTaxon} - /> + {occurrence.determinationTaxon ? ( + + navigate( + getAppRoute({ + to: APP_ROUTES.TAXON_DETAILS({ + projectId: projectId as string, + taxonId: id, + }), + }) + ) + } + size="lg" + taxon={occurrence.determinationTaxon} + /> + ) : ( + {translate(STRING.UNKNOWN)} + )}
{occurrence.determinationScore !== undefined ? ( ) : null} - {canUpdate && ( + {canUpdate && occurrence.determinationTaxon && ( <> occurrence.id)} - occurrenceTaxa={occurrences.map( - (occurrence) => occurrence.determinationTaxon - )} + occurrenceTaxa={occurrences + .map((occurrence) => occurrence.determinationTaxon) + .filter((taxon): taxon is Taxon => !!taxon)} />
) @@ -69,13 +70,14 @@ const Agree = ({ return !agreed }) + .filter((occurrence) => !!occurrence.determinationTaxon) .map((occurrence) => ({ agreeWith: { identificationId: occurrence.determinationIdentificationId, predictionId: occurrence.determinationPredictionId, }, occurrenceId: occurrence.id, - taxonId: occurrence.determinationTaxon.id, + taxonId: occurrence.determinationTaxon!.id, })), [occurrences] ) diff --git a/ui/src/pages/occurrences/occurrence-columns.tsx b/ui/src/pages/occurrences/occurrence-columns.tsx index fc98d9b1a..23d1c0ac6 100644 --- a/ui/src/pages/occurrences/occurrence-columns.tsx +++ b/ui/src/pages/occurrences/occurrence-columns.tsx @@ -195,9 +195,13 @@ const TaxonCell = ({
- + {item.determinationTaxon ? ( + + ) : ( + {translate(STRING.UNKNOWN)} + )} - {showQuickActions && canUpdate && ( + {showQuickActions && canUpdate && item.determinationTaxon && (
- taxon.rank === 'GENUS' || - taxon.rank === 'SPECIES' || - taxon.rank === 'SUBSPECIES' +export const isGenusOrBelow = (taxon?: Taxon) => + taxon?.rank === 'GENUS' || + taxon?.rank === 'SPECIES' || + taxon?.rank === 'SUBSPECIES' export const OccurrenceGallery = ({ error, @@ -158,7 +158,7 @@ export const OccurrenceGallery = ({ } )} > - {item.determinationTaxon.name} + {item.determinationTaxon?.name ?? translate(STRING.UNKNOWN)}
@@ -181,7 +181,7 @@ export const OccurrenceGallery = ({ /> ) : null} - {!isSelecting && canUpdate && ( + {!isSelecting && canUpdate && item.determinationTaxon && ( <> diff --git a/ui/src/pages/project/capture-sets/capture-set-columns.tsx b/ui/src/pages/project/capture-sets/capture-set-columns.tsx index 2172046c0..faf9c8f14 100644 --- a/ui/src/pages/project/capture-sets/capture-set-columns.tsx +++ b/ui/src/pages/project/capture-sets/capture-set-columns.tsx @@ -104,6 +104,16 @@ export const columns: (projectId: string) => TableColumn[] = ( ), }, + { + id: 'total-processed-captures', + name: translate(STRING.FIELD_LABEL_TOTAL_PROCESSED_CAPTURES), + styles: { + textAlign: TextAlign.Right, + }, + renderCell: (item: CaptureSet) => ( + + ), + }, { id: 'occurrences', name: translate(STRING.FIELD_LABEL_OCCURRENCES), diff --git a/ui/src/pages/project/capture-sets/capture-sets.tsx b/ui/src/pages/project/capture-sets/capture-sets.tsx index 3f7c1f8d2..a029137ef 100644 --- a/ui/src/pages/project/capture-sets/capture-sets.tsx +++ b/ui/src/pages/project/capture-sets/capture-sets.tsx @@ -28,6 +28,7 @@ export const CaptureSets = () => { settings: true, captures: true, 'captures-with-detections': true, + 'total-processed-captures': true, status: true, } ) diff --git a/ui/src/pages/project/pipelines/pipelines-columns.tsx b/ui/src/pages/project/pipelines/pipelines-columns.tsx index eda9f6f58..1cee4bd7a 100644 --- a/ui/src/pages/project/pipelines/pipelines-columns.tsx +++ b/ui/src/pages/project/pipelines/pipelines-columns.tsx @@ -52,11 +52,11 @@ export const columns: ( ), }, { - id: 'processing-services-online-last-checked', - name: 'Status last checked', - sortField: 'processing_services_online_last_checked', + id: 'processing-services-online-last-seen', + name: 'Status last seen', + sortField: 'processing_services_online_last_seen', renderCell: (item: Pipeline) => ( - + ), }, { diff --git a/ui/src/pages/project/processing-services/processing-services-columns.tsx 
b/ui/src/pages/project/processing-services/processing-services-columns.tsx index 9c5331a5e..8b6da16d7 100644 --- a/ui/src/pages/project/processing-services/processing-services-columns.tsx +++ b/ui/src/pages/project/processing-services/processing-services-columns.tsx @@ -53,7 +53,11 @@ export const columns: ( renderCell: (item: ProcessingService) => ( ), diff --git a/ui/src/utils/language.ts b/ui/src/utils/language.ts index 20d5ad0ea..d2cd62ab5 100644 --- a/ui/src/utils/language.ts +++ b/ui/src/utils/language.ts @@ -109,7 +109,7 @@ export enum STRING { FIELD_LABEL_JOB, FIELD_LABEL_JOBS, FIELD_LABEL_KEY, - FIELD_LABEL_LAST_CHECKED, + FIELD_LABEL_LAST_SEEN, FIELD_LABEL_LAST_DATE, FIELD_LABEL_LAST_SYNCED, FIELD_LABEL_LATITUDE, @@ -153,6 +153,7 @@ export enum STRING { FIELD_LABEL_TIME, FIELD_LABEL_TIMESTAMP, FIELD_LABEL_TOTAL_FILES, + FIELD_LABEL_TOTAL_PROCESSED_CAPTURES, FIELD_LABEL_TOTAL_RECORDS, FIELD_LABEL_TOTAL_SIZE, FIELD_LABEL_TRAINING_IMAGES, @@ -400,7 +401,7 @@ const ENGLISH_STRINGS: { [key in STRING]: string } = { [STRING.FIELD_LABEL_JOB]: 'Job', [STRING.FIELD_LABEL_JOBS]: 'Jobs', [STRING.FIELD_LABEL_KEY]: 'Key', - [STRING.FIELD_LABEL_LAST_CHECKED]: 'Last checked', + [STRING.FIELD_LABEL_LAST_SEEN]: 'Last seen', [STRING.FIELD_LABEL_LAST_DATE]: 'Last date', [STRING.FIELD_LABEL_LAST_SYNCED]: 'Last synced with data source', [STRING.FIELD_LABEL_LATITUDE]: 'Latitude', @@ -444,6 +445,7 @@ const ENGLISH_STRINGS: { [key in STRING]: string } = { [STRING.FIELD_LABEL_TIME]: 'Local time', [STRING.FIELD_LABEL_TIMESTAMP]: 'Timestamp', [STRING.FIELD_LABEL_TOTAL_FILES]: 'Total files', + [STRING.FIELD_LABEL_TOTAL_PROCESSED_CAPTURES]: 'Processed captures', [STRING.FIELD_LABEL_TOTAL_RECORDS]: 'Total records', [STRING.FIELD_LABEL_TOTAL_SIZE]: 'Total size', [STRING.FIELD_LABEL_TRAINING_IMAGES]: 'Reference images',