Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b60eab0
merge
carlos-irreverentlabs Jan 16, 2026
644927f
Merge remote-tracking branch 'upstream/main'
carlosgjs Jan 22, 2026
218f7aa
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 3, 2026
867201d
Update ML job counts in async case
carlosgjs Feb 6, 2026
90da389
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 10, 2026
cdf57ea
Update date picker version and tweak layout logic (#1105)
annavik Feb 6, 2026
6837ad6
fix: Properly handle async job state with celery tasks (#1114)
carlosgjs Feb 7, 2026
f1cd62d
PSv2: Implement queue clean-up upon job completion (#1113)
carlosgjs Feb 7, 2026
74df9ea
fix: PSv2: Workers should not try to fetch tasks from v1 jobs (#1118)
carlosgjs Feb 9, 2026
4a082d3
PSv2 cleanup: use is_complete() and dispatch_mode in job progress han…
mihow Feb 10, 2026
9d560cf
Merge branch 'main' into carlos/trackcounts
carlosgjs Feb 10, 2026
e43536b
track captures and failures
carlosgjs Feb 11, 2026
50df5f6
Update tests, CR feedback, log error images
carlosgjs Feb 11, 2026
3287fe2
CR feedback
carlosgjs Feb 11, 2026
a87b05a
fix type checking
carlosgjs Feb 11, 2026
89bf950
Merge remote-tracking branch 'origin/main' into carlos/trackcounts
mihow Feb 12, 2026
a5ff6f8
refactor: rename _get_progress to _commit_update in TaskStateManager
mihow Feb 12, 2026
337b7fc
fix: unify FAILURE_THRESHOLD and convert TaskProgress to dataclass
mihow Feb 12, 2026
8618d3c
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 13, 2026
4331dee
Merge branch 'main' into carlos/trackcounts
carlosgjs Feb 13, 2026
afee6e7
refactor: rename TaskProgress to JobStateProgress
mihow Feb 12, 2026
65d77cb
docs: update NATS todo and planning docs with session learnings
mihow Feb 13, 2026
8e8cd80
Rename TaskStateManager to AsyncJobStateManager
carlosgjs Feb 13, 2026
34af787
Merge branch 'carlos/trackcounts' of github.com:uw-ssec/antenna into …
carlosgjs Feb 13, 2026
afc4472
Track results counts in the job itself vs Redis
carlosgjs Feb 13, 2026
b6c3c6a
small simplification
carlosgjs Feb 13, 2026
b15024f
Reset counts to 0 on reset
carlosgjs Feb 13, 2026
b2e4a72
chore: remove local planning docs from PR branch
mihow Feb 17, 2026
a15ebda
docs: clarify three-layer job state architecture in docstrings
mihow Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 57 additions & 11 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import time
from collections.abc import Callable
from typing import Any

from asgiref.sync import async_to_sync
from celery.signals import task_failure, task_postrun, task_prerun
Expand Down Expand Up @@ -59,23 +60,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
result_data: Dictionary containing the pipeline result
reply_subject: NATS reply subject for acknowledgment
"""
from ami.jobs.models import Job # avoid circular import
from ami.jobs.models import Job, JobState # avoid circular import

_, t = log_time()

# Validate with Pydantic - check for error response first
if "error" in result_data:
error_result = PipelineResultsError(**result_data)
processed_image_ids = {str(error_result.image_id)} if error_result.image_id else set()
failed_image_ids = processed_image_ids # Same as processed for errors
logger.error(f"Pipeline returned error for job {job_id}, image {error_result.image_id}: {error_result.error}")
pipeline_result = None
else:
pipeline_result = PipelineResultsResponse(**result_data)
processed_image_ids = {str(img.id) for img in pipeline_result.source_images}
failed_image_ids = set() # No failures for successful results

state_manager = TaskStateManager(job_id)

progress_info = state_manager.update_state(processed_image_ids, stage="process", request_id=self.request.id)
progress_info = state_manager.update_state(
processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
)
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
Expand All @@ -84,15 +89,27 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
raise self.retry(countdown=5, max_retries=10)

try:
_update_job_progress(job_id, "process", progress_info.percentage)
FAILURE_THRESHOLD = 0.5
complete_state = JobState.SUCCESS
if (progress_info.failed / progress_info.total) >= FAILURE_THRESHOLD:
complete_state = JobState.FAILURE
Comment thread
carlosgjs marked this conversation as resolved.
Outdated
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
_update_job_progress(
job_id,
"process",
progress_info.percentage,
complete_state=complete_state,
processed=progress_info.processed,
remaining=progress_info.remaining,
failed=progress_info.failed,
)

_, t = t(f"TIME: Updated job {job_id} progress in PROCESS stage progress to {progress_info.percentage*100}%")
job = Job.objects.get(pk=job_id)
job.logger.info(f"Processing pipeline result for job {job_id}, reply_subject: {reply_subject}")
job.logger.info(
f" Job {job_id} progress: {progress_info.processed}/{progress_info.total} images processed "
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {len(processed_image_ids)} just "
"processed"
f"({progress_info.percentage*100}%), {progress_info.remaining} remaining, {progress_info.failed} failed, "
f"{len(processed_image_ids)} just processed"
)
except Job.DoesNotExist:
# don't raise and ack so that we don't retry since the job doesn't exist
Expand All @@ -102,6 +119,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub

try:
# Save to database (this is the slow operation)
detections_count, classifications_count, captures_count = 0, 0, 0
if pipeline_result:
# should never happen since otherwise we could not be processing results here
assert job.pipeline is not None, "Job pipeline is None"
Expand All @@ -112,18 +130,43 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
f"Saved pipeline results to database with {len(pipeline_result.detections)} detections"
f", percentage: {progress_info.percentage*100}%"
)
# Calculate detection and classification counts from this result
detections_count = len(pipeline_result.detections) if pipeline_result else 0
classifications_count = (
sum(len(detection.classifications) for detection in pipeline_result.detections)
if pipeline_result
else 0
)
captures_count = len(pipeline_result.source_images)

_ack_task_via_nats(reply_subject, job.logger)
# Update job stage with calculated progress
progress_info = state_manager.update_state(processed_image_ids, stage="results", request_id=self.request.id)

progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
request_id=self.request.id,
detections_count=detections_count,
classifications_count=classifications_count,
captures_count=captures_count,
)

if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
_update_job_progress(job_id, "results", progress_info.percentage)

_update_job_progress(
job_id,
"results",
progress_info.percentage,
complete_state=complete_state,
detections=progress_info.detections,
classifications=progress_info.classifications,
captures=progress_info.captures,
)
Comment thread
carlosgjs marked this conversation as resolved.

except Exception as e:
job.logger.error(
Expand All @@ -149,19 +192,22 @@ async def ack_task():
# Don't fail the task if ACK fails - data is already saved


def _update_job_progress(job_id: int, stage: str, progress_percentage: float) -> None:
def _update_job_progress(
    job_id: int, stage: str, progress_percentage: float, complete_state: Any = None, **state_params
) -> None:
    """
    Persist a stage's progress onto the Job row inside a transaction.

    Args:
        job_id: Primary key of the Job to update.
        stage: Name of the pipeline stage being updated (e.g. "process", "results").
        progress_percentage: Progress as a float in [0.0, 1.0].
        complete_state: JobState to apply once the stage reaches 100% (and, when the
            whole job is complete, to the job itself). Defaults to ``JobState.SUCCESS``
            when not supplied, preserving the behaviour of callers that only report
            progress. Typed ``Any`` because JobState must be imported lazily.
        **state_params: Extra per-stage counters (e.g. ``processed``, ``failed``,
            ``detections``) forwarded unchanged to ``progress.update_stage``.
    """
    from ami.jobs.models import Job, JobState  # avoid circular import

    if complete_state is None:
        # Backward-compatible default: callers that don't track failures still
        # get the original "finish as SUCCESS" behaviour.
        complete_state = JobState.SUCCESS

    with transaction.atomic():
        # Row lock so concurrent result tasks don't clobber each other's progress.
        job = Job.objects.select_for_update().get(pk=job_id)
        job.progress.update_stage(
            stage,
            status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
            progress=progress_percentage,
            **state_params,
        )
        if job.progress.is_complete():
            job.status = complete_state
            job.progress.summary.status = complete_state
            job.finished_at = datetime.datetime.now()  # Use naive datetime in local time
        job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
        job.save()
Expand Down
78 changes: 75 additions & 3 deletions ami/ml/orchestration/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@


# Define a namedtuple, TaskProgress, holding a job's image counts and result counters
TaskProgress = namedtuple("TaskProgress", ["remaining", "total", "processed", "percentage"])
# Immutable progress snapshot: image counts plus cumulative result counters.
TaskProgress = namedtuple(
    "TaskProgress",
    "remaining total processed percentage detections classifications captures failed",
)


class TaskStateManager:
Expand All @@ -35,6 +38,10 @@ def __init__(self, job_id: int):
self.job_id = job_id
self._pending_key = f"job:{job_id}:pending_images"
self._total_key = f"job:{job_id}:pending_images_total"
self._failed_key = f"job:{job_id}:failed_images"
self._detections_key = f"job:{job_id}:total_detections"
self._classifications_key = f"job:{job_id}:total_classifications"
self._captures_key = f"job:{job_id}:total_captures"

def initialize_job(self, image_ids: list[str]) -> None:
"""
Expand All @@ -46,8 +53,16 @@ def initialize_job(self, image_ids: list[str]) -> None:
for stage in self.STAGES:
cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT)

# Initialize failed images set for process stage only
cache.set(self._failed_key, set(), timeout=self.TIMEOUT)

cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT)

# Initialize detection and classification counters
cache.set(self._detections_key, 0, timeout=self.TIMEOUT)
cache.set(self._classifications_key, 0, timeout=self.TIMEOUT)
cache.set(self._captures_key, 0, timeout=self.TIMEOUT)

def _get_pending_key(self, stage: str) -> str:
return f"{self._pending_key}:{stage}"

Expand All @@ -56,12 +71,22 @@ def update_state(
processed_image_ids: set[str],
stage: str,
request_id: str,
detections_count: int = 0,
classifications_count: int = 0,
captures_count: int = 0,
failed_image_ids: set[str] | None = None,
) -> None | TaskProgress:
"""
Update the task state with newly processed images.

Args:
processed_image_ids: Set of image IDs that have just been processed
stage: The processing stage ("process" or "results")
request_id: Unique identifier for this processing request
detections_count: Number of detections to add to cumulative count
classifications_count: Number of classifications to add to cumulative count
captures_count: Number of captures to add to cumulative count
failed_image_ids: Set of image IDs that failed processing (optional)
"""
# Create a unique lock key for this job
lock_key = f"job:{self.job_id}:process_results_lock"
Expand All @@ -72,7 +97,9 @@ def update_state(

try:
# Update progress tracking in Redis
progress_info = self._get_progress(processed_image_ids, stage)
progress_info = self._get_progress(
processed_image_ids, stage, detections_count, classifications_count, captures_count, failed_image_ids
)
return progress_info
finally:
# Always release the lock when done
Expand All @@ -82,7 +109,15 @@ def update_state(
cache.delete(lock_key)
logger.debug(f"Released lock for job {self.job_id}, task {request_id}")

def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgress | None:
def _get_progress(
self,
processed_image_ids: set[str],
stage: str,
detections_count: int = 0,
classifications_count: int = 0,
captures_count: int = 0,
failed_image_ids: set[str] | None = None,
) -> TaskProgress | None:
"""
Get current progress information for the job.

Expand All @@ -92,6 +127,10 @@ def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgre
- total: Total number of images (or None if not tracked)
- processed: Number of images processed (or None if not tracked)
- percentage: Progress as float 0.0-1.0 (or None if not tracked)
- detections: Cumulative count of detections
- classifications: Cumulative count of classifications
- captures: Cumulative count of captures
- failed: Number of unique failed images
"""
pending_images = cache.get(self._get_pending_key(stage))
total_images = cache.get(self._total_key)
Expand All @@ -104,6 +143,31 @@ def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgre
remaining = len(remaining_images)
processed = total_images - remaining
percentage = float(processed) / total_images if total_images > 0 else 1.0

# Update cumulative detection, classification, and capture counts
current_detections = cache.get(self._detections_key, 0)
current_classifications = cache.get(self._classifications_key, 0)
current_captures = cache.get(self._captures_key, 0)

new_detections = current_detections + detections_count
new_classifications = current_classifications + classifications_count
new_captures = current_captures + captures_count

cache.set(self._detections_key, new_detections, timeout=self.TIMEOUT)
cache.set(self._classifications_key, new_classifications, timeout=self.TIMEOUT)
cache.set(self._captures_key, new_captures, timeout=self.TIMEOUT)

# Update failed images set if provided
if failed_image_ids:
existing_failed = cache.get(self._failed_key) or set()
updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates
cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT)
failed_set = updated_failed
else:
failed_set = cache.get(self._failed_key) or set()

failed_count = len(failed_set)

logger.info(
f"Pending images from Redis for job {self.job_id} {stage}: "
f"{remaining}/{total_images}: {percentage*100}%"
Expand All @@ -114,6 +178,10 @@ def _get_progress(self, processed_image_ids: set[str], stage: str) -> TaskProgre
total=total_images,
processed=processed,
percentage=percentage,
detections=new_detections,
classifications=new_classifications,
captures=new_captures,
failed=failed_count,
)

def cleanup(self) -> None:
    """
    Delete every cached tracking entry for this job: the per-stage pending-image
    sets plus the failed-image set and the total/detections/classifications/captures
    counters.
    """
    keys_to_clear = [self._get_pending_key(stage) for stage in self.STAGES]
    keys_to_clear += [
        self._failed_key,
        self._total_key,
        self._detections_key,
        self._classifications_key,
        self._captures_key,
    ]
    for key in keys_to_clear:
        cache.delete(key)
Loading
Loading