Merged
29 commits
b60eab0
merge
carlos-irreverentlabs Jan 16, 2026
644927f
Merge remote-tracking branch 'upstream/main'
carlosgjs Jan 22, 2026
218f7aa
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 3, 2026
867201d
Update ML job counts in async case
carlosgjs Feb 6, 2026
90da389
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 10, 2026
cdf57ea
Update date picker version and tweak layout logic (#1105)
annavik Feb 6, 2026
6837ad6
fix: Properly handle async job state with celery tasks (#1114)
carlosgjs Feb 7, 2026
f1cd62d
PSv2: Implement queue clean-up upon job completion (#1113)
carlosgjs Feb 7, 2026
74df9ea
fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118)
carlosgjs Feb 9, 2026
4a082d3
PSv2 cleanup: use is_complete() and dispatch_mode in job progress han…
mihow Feb 10, 2026
9d560cf
Merge branch 'main' into carlos/trackcounts
carlosgjs Feb 10, 2026
e43536b
track captures and failures
carlosgjs Feb 11, 2026
50df5f6
Update tests, CR feedback, log error images
carlosgjs Feb 11, 2026
3287fe2
CR feedback
carlosgjs Feb 11, 2026
a87b05a
fix type checking
carlosgjs Feb 11, 2026
89bf950
Merge remote-tracking branch 'origin/main' into carlos/trackcounts
mihow Feb 12, 2026
a5ff6f8
refactor: rename _get_progress to _commit_update in TaskStateManager
mihow Feb 12, 2026
337b7fc
fix: unify FAILURE_THRESHOLD and convert TaskProgress to dataclass
mihow Feb 12, 2026
8618d3c
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 13, 2026
4331dee
Merge branch 'main' into carlos/trackcounts
carlosgjs Feb 13, 2026
afee6e7
refactor: rename TaskProgress to JobStateProgress
mihow Feb 12, 2026
65d77cb
docs: update NATS todo and planning docs with session learnings
mihow Feb 13, 2026
8e8cd80
Rename TaskStateManager to AsyncJobStateManager
carlosgjs Feb 13, 2026
34af787
Merge branch 'carlos/trackcounts' of github.com:uw-ssec/antenna into …
carlosgjs Feb 13, 2026
afc4472
Track results counts in the job itself vs Redis
carlosgjs Feb 13, 2026
b6c3c6a
small simplification
carlosgjs Feb 13, 2026
b15024f
Reset counts to 0 on reset
carlosgjs Feb 13, 2026
b2e4a72
chore: remove local planning docs from PR branch
mihow Feb 17, 2026
a15ebda
docs: clarify three-layer job state architecture in docstrings
mihow Feb 17, 2026
3 changes: 2 additions & 1 deletion ami/jobs/models.py
@@ -561,7 +561,8 @@ def process_images(cls, job, images)

job.logger.info(f"All tasks completed for job {job.pk}")

FAILURE_THRESHOLD = 0.5
from ami.jobs.tasks import FAILURE_THRESHOLD

if image_count and (percent_successful < FAILURE_THRESHOLD):
job.progress.update_stage("process", status=JobState.FAILURE)
job.save()
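
Note for reviewers: a minimal, self-contained sketch of the failure-threshold logic this change unifies. FAILURE_THRESHOLD now lives in ami/jobs/tasks.py and is imported here; the simplified JobStateProgress below is a stand-in with only the two fields the check reads (the real dataclass is defined in ami/ml/orchestration/task_state.py further down).

from dataclasses import dataclass

FAILURE_THRESHOLD = 0.5  # single shared cutoff, defined once in ami.jobs.tasks

@dataclass
class JobStateProgress:  # simplified stand-in for the real dataclass
    total: int = 0
    failed: int = 0

def resolve_complete_state(progress: JobStateProgress) -> str:
    # Mirrors the async check in process_nats_pipeline_result: if more than
    # half of the images failed, the job completes as FAILURE, not SUCCESS.
    if progress.total > 0 and (progress.failed / progress.total) > FAILURE_THRESHOLD:
        return "FAILURE"
    return "SUCCESS"

assert resolve_complete_state(JobStateProgress(total=10, failed=6)) == "FAILURE"
assert resolve_complete_state(JobStateProgress(total=10, failed=5)) == "SUCCESS"  # exactly at the cutoff still succeeds
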
80 changes: 68 additions & 12 deletions ami/jobs/tasks.py
@@ -3,6 +3,7 @@
import logging
import time
from collections.abc import Callable
from typing import TYPE_CHECKING

from asgiref.sync import async_to_sync
from celery.signals import task_failure, task_postrun, task_prerun
@@ -14,7 +15,13 @@
from ami.tasks import default_soft_time_limit, default_time_limit
from config import celery_app

if TYPE_CHECKING:
from ami.jobs.models import JobState

logger = logging.getLogger(__name__)
# Minimum success rate. Jobs with fewer than this fraction of images
# processed successfully are marked as failed. Also used in MLJob.process_images().
FAILURE_THRESHOLD = 0.5


@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
@@ -59,23 +66,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
result_data: Dictionary containing the pipeline result
reply_subject: NATS reply subject for acknowledgment
"""
from ami.jobs.models import Job # avoid circular import
from ami.jobs.models import Job, JobState # avoid circular import

_, t = log_time()

# Validate with Pydantic - check for error response first
error_result = None
if "error" in result_data:
error_result = PipelineResultsError(**result_data)
processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set()
logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}")
failed_image_ids = processed_image_ids # Same as processed for errors
pipeline_result = None
else:
pipeline_result = PipelineResultsResponse(**result_data)
processed_image_ids = {str(img.id) for img in pipeline_result.source_images}
failed_image_ids = set() # No failures for successful results

state_manager = TaskStateManager(job_id)

progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id)
progress_info = state_manager.update_state(
processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
)
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
@@ -84,16 +95,31 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
raise self.retry(countdown=5, max_retries=10)

try:
_update_job_progress(job_id, "process", progress_info.percentage)
complete_state = JobState.SUCCESS
if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
complete_state = JobState.FAILURE
_update_job_progress(
job_id,
"process",
progress_info.percentage,
complete_state=complete_state,
processed=progress_info.processed,
remaining=progress_info.remaining,
failed=progress_info.failed,
)

_, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%")
job = Job.objects.get(pk=job_id)
job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}")
job.logger.info(
f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed "
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just "
"processed"
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, "
f"{len(processed_image_ids)} just processed"
)
if error_result:
job.logger.error(
f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}"
)
except Job.DoesNotExist:
# don't raise; ack instead so that we don't retry, since the job doesn't exist
logger.error(f"Job {job_id} not found")
@@ -102,6 +128,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub

try:
# Save to database (this is the slow operation)
detections_count, classifications_count, captures_count = 0, 0, 0
if pipeline_result:
# should never happen since otherwise we could not be processing results here
assert job.pipeline is not None, "Job pipeline is None"
@@ -112,18 +139,44 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
f"Saved pipeline results to database with {len(pipeline_result.detections)} detections"
f", percentage: {progress_info.percentage*100}%"
)
# Calculate detection and classification counts from this result
detections_count = len(pipeline_result.detections)
classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections)
captures_count = len(pipeline_result.source_images)

_ack_task_via_nats(reply_subject, job.logger)
# Update job stage with calculated progress
progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id)

progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
request_id=self.request.id,
detections_count=detections_count,
classifications_count=classifications_count,
captures_count=captures_count,
)

if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
_update_job_progress(job_id, "results", progress_info.percentage)

# update complete state based on latest progress info after saving results
complete_state = JobState.SUCCESS
if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
complete_state = JobState.FAILURE

_update_job_progress(
job_id,
"results",
progress_info.percentage,
complete_state=complete_state,
detections=progress_info.detections,
classifications=progress_info.classifications,
captures=progress_info.captures,
)

except Exception as e:
job.logger.error(
@@ -149,19 +202,22 @@ async def ack_task():
# Don't fail the task if ACK fails - data is already saved


def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None:
def _update_job_progress(
job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params
) -> None:
from ami.jobs.models import Job, JobState # avoid circular import

with transaction.atomic():
job = Job.objects.select_for_update().get(pk=job_id)
job.progress.update_stage(
stage,
status=JobState.SUCCESS if progress_percentage >= 1.0 else JobState.STARTED,
status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
progress=progress_percentage,
**state_params,
)
if job.progress.is_complete():
job.status = JobState.SUCCESS
job.progress.summary.status = JobState.SUCCESS
job.status = complete_state
job.progress.summary.status = complete_state
job.finished_at = datetime.datetime.now() # Use naive datetime in local time
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()
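
Note for reviewers: a usage sketch of the updated _update_job_progress signature (mirrors the calls in this file and in test_cleanup.py below; the job id and counts are made up). Callers now pass the final state explicitly, and any extra counters are forwarded to job.progress.update_stage() via **state_params.

from ami.jobs.models import JobState
from ami.jobs.tasks import _update_job_progress

_update_job_progress(
    job_id=123,                       # hypothetical job pk
    stage="results",
    progress_percentage=1.0,
    complete_state=JobState.SUCCESS,  # JobState.FAILURE when failed/total > FAILURE_THRESHOLD
    detections=42,                    # stored on the stage via **state_params
    classifications=40,
    captures=10,
)
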
103 changes: 94 additions & 9 deletions ami/ml/orchestration/task_state.py
@@ -3,15 +3,25 @@
"""

import logging
from collections import namedtuple
from dataclasses import dataclass

from django.core.cache import cache

logger = logging.getLogger(__name__)


# Define a namedtuple for a TaskProgress with the image counts
TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"])
@dataclass
class JobStateProgress:
"""Progress snapshot for a job stage tracked in Redis."""

remaining: int = 0
total: int = 0
processed: int = 0
percentage: float = 0.0
detections: int = 0
classifications: int = 0
captures: int = 0
failed: int = 0


def _lock_key(job_id: int) -> str:
@@ -39,6 +49,10 @@ def __init__(self, job_id: int):
self.job_id = job_id
self._pending_key = f"job:{job_id}:pending_images"
self._total_key = f"job:{job_id}:pending_images_total"
self._failed_key = f"job:{job_id}:failed_images"
self._detections_key = f"job:{job_id}:total_detections"
self._classifications_key = f"job:{job_id}:total_classifications"
self._captures_key = f"job:{job_id}:total_captures"

def initialize_job(self, image_ids: list[str]) -> None:
"""
@@ -50,8 +64,16 @@ def initialize_job(self, image_ids: list[str]) -> None:
for stage in self.STAGES:
cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT)

# Initialize failed images set for process stage only
cache.set(self._failed_key, set(), timeout=self.TIMEOUT)

cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT)

# Initialize detection and classification counters
cache.set(self._detections_key, 0, timeout=self.TIMEOUT)
cache.set(self._classifications_key, 0, timeout=self.TIMEOUT)
cache.set(self._captures_key, 0, timeout=self.TIMEOUT)

def _get_pending_key(self, stage: str) -> str:
return f"{self._pending_key}:{stage}"

@@ -60,12 +82,22 @@ def update_state(
processed_image_ids: set[str],
stage: str,
request_id: str,
) -> None | TaskProgress:
detections_count: int = 0,
classifications_count: int = 0,
captures_count: int = 0,
failed_image_ids: set[str] | None = None,
) -> None | JobStateProgress:
"""
Update the task state with newly processed images.

Args:
processed_image_ids: Set of image IDs that have just been processed
stage: The processing stage ("process" or "results")
request_id: Unique identifier for this processing request
detections_count: Number of detections to add to cumulative count
classifications_count: Number of classifications to add to cumulative count
captures_count: Number of captures to add to cumulative count
failed_image_ids: Set of image IDs that failed processing (optional)
"""
# Create a unique lock key for this job
lock_key = _lock_key(self.job_id)
@@ -76,7 +108,9 @@

try:
# Update progress tracking in Redis
progress_info = self._get_progress(processed_image_ids, stage)
progress_info = self._commit_update(
processed_image_ids, stage, detections_count, classifications_count, captures_count, failed_image_ids
)
return progress_info
finally:
# Always release the lock when done
@@ -86,7 +120,7 @@
cache.delete(lock_key)
logger.debug(f"Released lock for job {self.job_id}, task {request_id}")

def get_progress(self, stage: str) -> TaskProgress | None:
def get_progress(self, stage: str) -> JobStateProgress | None:
"""Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state."""
pending_images = cache.get(self._get_pending_key(stage))
total_images = cache.get(self._total_key)
@@ -95,9 +129,27 @@
remaining = len(pending_images)
processed = total_images - remaining
percentage = float(processed) / total_images if total_images > 0 else 1.0
return TaskProgress(remaining=remaining, total=total_images, processed=processed, percentage=percentage)
failed_set = cache.get(self._failed_key) or set()
return JobStateProgress(
remaining=remaining,
total=total_images,
processed=processed,
percentage=percentage,
detections=cache.get(self._detections_key, 0),
classifications=cache.get(self._classifications_key, 0),
captures=cache.get(self._captures_key, 0),
failed=len(failed_set),
)

def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None:
def _commit_update(
self,
processed_image_ids: set[str],
stage: str,
detections_count: int = 0,
classifications_count: int = 0,
captures_count: int = 0,
failed_image_ids: set[str] | None = None,
) -> JobStateProgress | None:
"""
Update pending images and return progress. Must be called under lock.

@@ -114,16 +166,45 @@ def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgre
remaining = len(remaining_images)
processed = total_images - remaining
percentage = float(processed) / total_images if total_images > 0 else 1.0

# Update cumulative detection, classification, and capture counts
current_detections = cache.get(self._detections_key, 0)
current_classifications = cache.get(self._classifications_key, 0)
current_captures = cache.get(self._captures_key, 0)

new_detections = current_detections + detections_count
new_classifications = current_classifications + classifications_count
new_captures = current_captures + captures_count

cache.set(self._detections_key, new_detections, timeout=self.TIMEOUT)
cache.set(self._classifications_key, new_classifications, timeout=self.TIMEOUT)
cache.set(self._captures_key, new_captures, timeout=self.TIMEOUT)

# Update failed images set if provided
if failed_image_ids:
existing_failed = cache.get(self._failed_key) or set()
updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates
cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT)
failed_set = updated_failed
else:
failed_set = cache.get(self._failed_key) or set()

failed_count = len(failed_set)

logger.info(
f"Pending images from Redis for job {self.job_id} {stage}: "
f"{remaining}/{total_images}: {percentage*100}%"
)

return TaskProgress(
return JobStateProgress(
remaining=remaining,
total=total_images,
processed=processed,
percentage=percentage,
detections=new_detections,
classifications=new_classifications,
captures=new_captures,
failed=failed_count,
)

def cleanup(self) -> None:
Expand All @@ -132,4 +213,8 @@ def cleanup(self) -> None:
"""
for stage in self.STAGES:
cache.delete(self._get_pending_key(stage))
cache.delete(self._failed_key)
cache.delete(self._total_key)
cache.delete(self._detections_key)
cache.delete(self._classifications_key)
cache.delete(self._captures_key)
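
Note for reviewers: a lifecycle sketch of the Redis-backed manager as changed here (class and method names are from this diff; the job id, image ids, and counts are made up). update_state() returns None when another task holds the per-job lock, which is why process_nats_pipeline_result retries in that case.

from ami.ml.orchestration.task_state import TaskStateManager

manager = TaskStateManager(job_id=123)         # hypothetical job pk
manager.initialize_job(["101", "102", "103"])  # seeds per-stage pending sets, zeroes all counters

# A worker reports one successfully processed image plus its per-result counts:
progress = manager.update_state(
    {"101"},
    stage="process",
    request_id="celery-task-uuid",             # placeholder request id
    detections_count=5,
    classifications_count=5,
    captures_count=1,
    failed_image_ids=set(),
)
if progress is not None:                       # None means the lock was not acquired
    print(progress.processed, progress.total, progress.failed, progress.detections)

manager.cleanup()                              # deletes every per-job Redis key
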
6 changes: 3 additions & 3 deletions ami/ml/orchestration/tests/test_cleanup.py
@@ -164,9 +164,9 @@ def test_cleanup_on_job_completion(self):
job = self._create_job_with_queued_images()

# Simulate job completion: complete all stages (collect, process, then results)
_update_job_progress(job.pk, stage="collect", progress_percentage=1.0)
_update_job_progress(job.pk, stage="process", progress_percentage=1.0)
_update_job_progress(job.pk, stage="results", progress_percentage=1.0)
_update_job_progress(job.pk, stage="collect", progress_percentage=1.0, complete_state=JobState.SUCCESS)
_update_job_progress(job.pk, stage="process", progress_percentage=1.0, complete_state=JobState.SUCCESS)
_update_job_progress(job.pk, stage="results", progress_percentage=1.0, complete_state=JobState.SUCCESS)

# Verify cleanup happened
self._verify_resources_cleaned(job.pk)
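
Note for reviewers: _verify_resources_cleaned is not shown in this diff. Below is a hedged sketch of what such a test-case method presumably asserts, using the key names from task_state.py; the per-stage pending keys assume the "process" and "results" stage names used elsewhere in the PR.

from django.core.cache import cache

def _verify_resources_cleaned(self, job_id: int) -> None:
    # Assumed check: every per-job Redis key created by TaskStateManager is gone.
    keys = [
        f"job:{job_id}:pending_images_total",
        f"job:{job_id}:failed_images",
        f"job:{job_id}:total_detections",
        f"job:{job_id}:total_classifications",
        f"job:{job_id}:total_captures",
    ]
    keys += [f"job:{job_id}:pending_images:{stage}" for stage in ("process", "results")]
    for key in keys:
        self.assertIsNone(cache.get(key))
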