diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py
index d8ff89d39..ad3e18ca8 100644
--- a/ami/jobs/tasks.py
+++ b/ami/jobs/tasks.py
@@ -192,10 +192,10 @@ def _fail_job(job_id: int, reason: str) -> None:
         job.save(update_fields=["status", "progress", "finished_at"])
         job.logger.error(f"Job {job_id} marked as FAILURE: {reason}")
-        cleanup_async_job_resources(job.pk, job.logger)
+        cleanup_async_job_resources(job.pk)
     except Job.DoesNotExist:
         logger.error(f"Cannot fail job {job_id}: not found")
-        cleanup_async_job_resources(job_id, logger)
+        cleanup_async_job_resources(job_id)
 
 
 def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
@@ -423,7 +423,7 @@ def cleanup_async_job_if_needed(job) -> None:
     # import here to avoid circular imports
     from ami.ml.orchestration.jobs import cleanup_async_job_resources
 
-    cleanup_async_job_resources(job.pk, job.logger)
+    cleanup_async_job_resources(job.pk)
 
 
 @task_prerun.connect(sender=run_job)
diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py
index 95c763b1b..79f787e28 100644
--- a/ami/ml/orchestration/jobs.py
+++ b/ami/ml/orchestration/jobs.py
@@ -11,7 +11,7 @@
 logger = logging.getLogger(__name__)
 
 
-def cleanup_async_job_resources(job_id: int, _logger: logging.Logger) -> bool:
+def cleanup_async_job_resources(job_id: int) -> bool:
     """
     Clean up NATS JetStream and Redis resources for a completed job.
 
@@ -21,12 +21,26 @@ def cleanup_async_job_resources(job_id: int, _logger: logging.Logger) -> bool:
     Cleanup failures are logged but don't fail the job - data is already saved.
 
+    Resolves the job (and its per-job logger) internally so callers only need
+    to pass the ``job_id`` — matches the pattern used by ``save_results`` in
+    ``ami/jobs/tasks.py``. If the ``Job`` row is gone (e.g. the
+    ``Job.DoesNotExist`` path in ``_fail_job``), the function falls back to
+    the module logger and TaskQueueManager's module-logger path.
+
     Args:
-        job_id: The Job ID (integer primary key)
-        _logger: Logger to use for logging cleanup results
+        job_id: The Job ID (integer primary key).
 
     Returns:
         bool: True if both cleanups succeeded, False otherwise
     """
+    # Resolve the logger up front: job.logger when the Job exists, module
+    # logger otherwise. Matches the pattern used by save_results.
+    job: Job | None = None
+    try:
+        job = Job.objects.get(pk=job_id)
+    except Job.DoesNotExist:
+        pass
+    job_logger: logging.Logger = job.logger if job else logger
+
     redis_success = False
     nats_success = False
 
@@ -34,24 +48,26 @@ def cleanup_async_job_resources(job_id: int, _logger: logging.Logger) -> bool:
     try:
         state_manager = AsyncJobStateManager(job_id)
         state_manager.cleanup()
-        _logger.info(f"Cleaned up Redis state for job {job_id}")
+        job_logger.info(f"Cleaned up Redis state for job {job_id}")
         redis_success = True
     except Exception as e:
-        _logger.error(f"Error cleaning up Redis state for job {job_id}: {e}")
+        job_logger.error(f"Error cleaning up Redis state for job {job_id}: {e}")
 
-    # Cleanup NATS resources
+    # Cleanup NATS resources. Only forward a real per-job logger to
+    # TaskQueueManager — passing the module logger would mirror cleanup
+    # lifecycle lines into an unrelated logger.
     async def cleanup():
-        async with TaskQueueManager() as manager:
+        async with TaskQueueManager(job_logger=job.logger if job else None) as manager:
             return await manager.cleanup_job_resources(job_id)
 
     try:
         nats_success = async_to_sync(cleanup)()
         if nats_success:
-            _logger.info(f"Cleaned up NATS resources for job {job_id}")
+            job_logger.info(f"Cleaned up NATS resources for job {job_id}")
         else:
-            _logger.warning(f"Failed to clean up NATS resources for job {job_id}")
+            job_logger.warning(f"Failed to clean up NATS resources for job {job_id}")
     except Exception as e:
-        _logger.error(f"Error cleaning up NATS resources for job {job_id}: {e}")
+        job_logger.error(f"Error cleaning up NATS resources for job {job_id}: {e}")
 
     return redis_success and nats_success
 
@@ -97,16 +113,29 @@ async def queue_all_images():
         successful_queues = 0
         failed_queues = 0
 
-        async with TaskQueueManager() as manager:
+        # Pass job.logger so stream/consumer setup, per-image debug lines, and
+        # publish failures all appear in the UI job log (not just the module
+        # logger). All log calls inside this block go through manager.log_async
+        # so module + job logger stay in sync with one consistent API — and
+        # the sync_to_async bridge for JobLogHandler's ORM save lives in one
+        # place instead of being re-implemented at every call site.
+        async with TaskQueueManager(job_logger=job.logger) as manager:
             for image_pk, task in tasks:
                 try:
-                    logger.info(f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}")
+                    await manager.log_async(
+                        logging.DEBUG,
+                        f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}",
+                    )
                     success = await manager.publish_task(
                         job_id=job.pk,
                         data=task,
                     )
                 except Exception as e:
-                    logger.error(f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}")
+                    await manager.log_async(
+                        logging.ERROR,
+                        f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}",
+                        exc_info=True,
+                    )
                     success = False
 
                 if success:
diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py
index b6e9af254..43d9d65e5 100644
--- a/ami/ml/orchestration/nats_queue.py
+++ b/ami/ml/orchestration/nats_queue.py
@@ -15,6 +15,7 @@
 import logging
 
 import nats
+from asgiref.sync import sync_to_async
 from django.conf import settings
 from nats.js import JetStreamContext
 from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
@@ -54,21 +55,137 @@ class TaskQueueManager:
         nats_url: NATS server URL. Falls back to settings.NATS_URL, then "nats://nats:4222".
         max_ack_pending: Max unacknowledged messages per consumer. Falls back to settings.NATS_MAX_ACK_PENDING, then 1000.
+        job_logger: Optional per-job logger. When set, lifecycle events (stream /
+            consumer create or reuse, cleanup stats, publish failures) are mirrored
+            to this logger in addition to the module logger, so they appear in the
+            job's own log stream as seen from the UI. Per-message and per-poll
+            events stay on the module logger only to avoid drowning large jobs.
 
     Use as an async context manager:
-        async with TaskQueueManager() as manager:
+        async with TaskQueueManager(job_logger=job.logger) as manager:
             await manager.publish_task(123, {'data': 'value'})
             tasks = await manager.reserve_tasks(123, count=64)
             await manager.acknowledge_task(tasks[0].reply_subject)
     """
 
-    def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = None):
+    def __init__(
+        self,
+        nats_url: str | None = None,
+        max_ack_pending: int | None = None,
+        job_logger: logging.Logger | None = None,
+    ):
         self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222")
         self.max_ack_pending = (
             max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 1000)
         )
+        self.job_logger = job_logger
         self.nc: nats.NATS | None = None
         self.js: JetStreamContext | None = None
+        # Dedupe lifecycle log lines per manager session so a job that publishes
+        # hundreds of tasks doesn't emit hundreds of "reusing stream" messages.
+        self._streams_logged: set[int] = set()
+        self._consumers_logged: set[int] = set()
+
+    async def log_async(self, level: int, msg: str, *, exc_info: bool = False) -> None:
+        """Log to both the module logger and the job logger (if set).
+
+        Named ``log_async`` (not ``log``) to flag at every call site that this
+        is the async fan-out helper, distinct from stdlib ``Logger.log`` —
+        callers must ``await`` it. Use this from any async context where the
+        line should appear in both ops dashboards and the job's UI log.
+
+        Module logger fires synchronously (ops dashboards / stdout / New Relic
+        are unaffected). The job logger call is bridged through
+        ``sync_to_async`` because Django's ``JobLogHandler`` does an ORM
+        ``refresh_from_db`` + ``save`` on every emit — calling that directly
+        from the event loop raises ``SynchronousOnlyOperation`` and the log
+        line is silently dropped. The bridge offloads the handler work to a
+        thread so the line actually lands in ``job.logs.stdout``.
+
+        Pass ``exc_info=True`` inside an ``except`` block to capture the
+        traceback on both loggers (same semantics as stdlib ``Logger.log``).
+
+        Exceptions from the job logger are swallowed so logging a lifecycle
+        event never breaks the actual NATS operation.
+
+        Gated by ``isEnabledFor`` up front so a disabled level returns
+        immediately without paying for the ``sync_to_async`` round-trip.
+        Matters most at DEBUG during large queues — stdlib ``Logger.log``
+        does the same level check internally before formatting a message;
+        we have to do it explicitly here because the job-logger mirror
+        happens through ``sync_to_async`` (ThreadPoolExecutor submit), which
+        would otherwise fire once per image even when the handler is about
+        to drop the record.
+
+        FUTURE: this currently mirrors granular per-job lifecycle (stream /
+        consumer create+reuse, per-image debug, forensic stats) to BOTH the
+        module logger and the job logger. The longer-term preference is to
+        route — granular lifecycle stays on ``job.logger`` only (matching
+        ``ami.jobs.tasks.save_results`` and friends, where ``job.logger`` has
+        ``propagate=False`` and never reaches stdout / NR), with the module
+        logger reserved for true ops signals (connection failures, NATS-side
+        errors). Kept symmetric for now because async ML processing is still
+        being stabilized and the extra stdout visibility is helping us
+        debug. Once we trust the per-job UI log as the canonical place to
+        inspect a job, switch ``log_async`` to route-not-mirror at INFO/DEBUG
+        and only auto-mirror at WARNING+ (so true error signals still always
+        reach ops dashboards).
+        """
+        module_enabled = logger.isEnabledFor(level)
+        job_enabled = (
+            self.job_logger is not None and self.job_logger is not logger and self.job_logger.isEnabledFor(level)
+        )
+        if not module_enabled and not job_enabled:
+            return
+        if module_enabled:
+            logger.log(level, msg, exc_info=exc_info)
+        if job_enabled:
+            try:
+                await sync_to_async(self.job_logger.log)(level, msg, exc_info=exc_info)
+            except Exception as e:
+                logger.warning(f"Failed to mirror log to job logger: {e}")
+
+    @staticmethod
+    def _format_consumer_config(info) -> str:
+        """Format ConsumerInfo config into a compact creation-time string.
+
+        Reads the actual config from the ConsumerInfo returned by
+        ``add_consumer`` or ``consumer_info``, so the log always reflects
+        what the server accepted rather than what we requested.
+        """
+        cfg = info.config
+        if cfg is None:
+            return "config=?"
+
+        def _val(v):
+            """Unwrap enum .value if present, pass through scalars."""
+            return v.value if hasattr(v, "value") else v
+
+        return (
+            f"max_deliver={_val(cfg.max_deliver) if cfg.max_deliver is not None else '?'}, "
+            f"ack_wait={_val(cfg.ack_wait) if cfg.ack_wait is not None else '?'}s, "
+            f"max_ack_pending={_val(cfg.max_ack_pending) if cfg.max_ack_pending is not None else '?'}, "
+            f"deliver_policy={_val(cfg.deliver_policy) if cfg.deliver_policy is not None else '?'}, "
+            f"ack_policy={_val(cfg.ack_policy) if cfg.ack_policy is not None else '?'}"
+        )
+
+    @staticmethod
+    def _format_consumer_stats(info) -> str:
+        """Format ConsumerInfo into a compact runtime stats string.
+
+        All nats-py ConsumerInfo fields are Optional, so defensive access is
+        required: this method renders missing values as '?'. Used for both
+        reuse-announcements and forensic cleanup lines.
+        """
+        delivered = info.delivered.consumer_seq if info.delivered is not None else "?"
+        ack_floor = info.ack_floor.consumer_seq if info.ack_floor is not None else "?"
+        return (
+            f"delivered={delivered} "
+            f"ack_floor={ack_floor} "
+            f"num_pending={info.num_pending if info.num_pending is not None else '?'} "
+            f"num_ack_pending={info.num_ack_pending if info.num_ack_pending is not None else '?'} "
+            f"num_redelivered={info.num_redelivered if info.num_redelivered is not None else '?'}"
+        )
 
     async def __aenter__(self):
         """Create connection on enter."""
@@ -127,27 +244,72 @@ async def _stream_exists(self, stream_name: str) -> bool:
             return False
 
     async def _ensure_stream(self, job_id: int):
-        """Ensure stream exists for the given job."""
+        """Ensure stream exists for the given job.
+
+        Logs a lifecycle line to both the module and job logger the first time it
+        sees a given job in this manager session (creation or reuse). Subsequent
+        calls in the same session skip the NATS round-trip entirely via the
+        ``_streams_logged`` set.
+
+        Concurrency note: ``Job.cancel()`` can trigger ``cleanup_async_job_resources``
+        in the request thread while this manager is still in its publish loop in
+        the Celery worker, so the stream *can* be deleted mid-flight from a
+        different manager session. The early-return is still safe in that case —
+        subsequent ``publish_task`` calls will fail loudly (``self.js.publish``
+        returns an error, caught and logged by ``publish_task``) rather than
+        silently recreating the stream without a consumer. Failing loud on a
+        cancel race is the correct behavior.
+        """
+        if job_id in self._streams_logged:
+            return
         if self.js is None:
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
 
-        if not await self._job_stream_exists(job_id):
-            stream_name = self._get_stream_name(job_id)
-            subject = self._get_subject(job_id)
-            logger.warning(f"Stream {stream_name} does not exist")
-            # Stream doesn't exist, create it
-            await asyncio.wait_for(
-                self.js.add_stream(
-                    name=stream_name,
-                    subjects=[subject],
-                    max_age=86400,  # 24 hours retention
-                ),
-                timeout=NATS_JETSTREAM_TIMEOUT,
+        stream_name = self._get_stream_name(job_id)
+        subject = self._get_subject(job_id)
+
+        try:
+            info = await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT)
+            state = info.state
+            messages = state.messages if state is not None else "?"
+            last_seq = state.last_seq if state is not None else "?"
+            await self.log_async(
+                logging.INFO,
+                f"Reusing NATS stream {stream_name} (messages={messages}, last_seq={last_seq})",
             )
-            logger.info(f"Created stream {stream_name}")
+            self._streams_logged.add(job_id)
+            return
+        except nats.js.errors.NotFoundError:
+            pass
+
+        await asyncio.wait_for(
+            self.js.add_stream(
+                name=stream_name,
+                subjects=[subject],
+                max_age=86400,  # 24 hours retention
+            ),
+            timeout=NATS_JETSTREAM_TIMEOUT,
+        )
+        await self.log_async(logging.INFO, f"Created NATS stream {stream_name}")
+        self._streams_logged.add(job_id)
 
     async def _ensure_consumer(self, job_id: int):
-        """Ensure consumer exists for the given job."""
+        """Ensure consumer exists for the given job.
+
+        On first sight in this manager session (creation or reuse), emits a line
+        to both the module and job logger. On creation the line includes the
+        config snapshot (max_deliver, ack_wait, max_ack_pending, deliver_policy,
+        ack_policy) so forensic readers can see exactly what delivery semantics
+        were in effect. Subsequent calls skip the NATS round-trip via the
+        ``_consumers_logged`` set.
+
+        Same concurrency caveat as ``_ensure_stream``: a concurrent cancel can
+        delete the consumer mid-flight. The early-return stays safe because
+        downstream ``publish_task`` fails loudly rather than silently recreating
+        an orphan consumer.
+        """
+        if job_id in self._consumers_logged:
+            return
         if self.js is None:
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
 
@@ -160,27 +322,39 @@ async def _ensure_consumer(self, job_id: int):
                 self.js.consumer_info(stream_name, consumer_name),
                 timeout=NATS_JETSTREAM_TIMEOUT,
             )
-            logger.debug(f"Consumer {consumer_name} already exists: {info}")
-        except asyncio.TimeoutError:
-            raise  # NATS unreachable — let caller handle it
-        except Exception:
-            # Consumer doesn't exist, create it
-            await asyncio.wait_for(
-                self.js.add_consumer(
-                    stream=stream_name,
-                    config=ConsumerConfig(
-                        durable_name=consumer_name,
-                        ack_policy=AckPolicy.EXPLICIT,
-                        ack_wait=TASK_TTR,  # Visibility timeout (TTR)
-                        max_deliver=5,  # Max retry attempts
-                        deliver_policy=DeliverPolicy.ALL,
-                        max_ack_pending=self.max_ack_pending,
-                        filter_subject=subject,
-                    ),
-                ),
-                timeout=NATS_JETSTREAM_TIMEOUT,
+            await self.log_async(
+                logging.INFO,
+                f"Reusing NATS consumer {consumer_name} ({self._format_consumer_stats(info)})",
             )
-            logger.info(f"Created consumer {consumer_name}")
+            self._consumers_logged.add(job_id)
+            return
+        except nats.js.errors.NotFoundError:
+            # Consumer doesn't exist, fall through to create it. Other
+            # JetStream errors (auth, API, transient) and asyncio.TimeoutError
+            # propagate naturally — we don't want to mask them as "missing
+            # consumer" and emit misleading creation logs.
+            pass
+
+        info = await asyncio.wait_for(
+            self.js.add_consumer(
+                stream=stream_name,
+                config=ConsumerConfig(
+                    durable_name=consumer_name,
+                    ack_policy=AckPolicy.EXPLICIT,
+                    ack_wait=TASK_TTR,  # Visibility timeout (TTR)
+                    max_deliver=5,  # Max retry attempts
+                    deliver_policy=DeliverPolicy.ALL,
+                    max_ack_pending=self.max_ack_pending,
+                    filter_subject=subject,
+                ),
+            ),
+            timeout=NATS_JETSTREAM_TIMEOUT,
+        )
+        await self.log_async(
+            logging.INFO,
+            f"Created NATS consumer {consumer_name} ({self._format_consumer_config(info)})",
+        )
+        self._consumers_logged.add(job_id)
 
     async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
         """
@@ -212,7 +386,10 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
             return True
 
         except Exception as e:
-            logger.error(f"Failed to publish task to stream for job '{job_id}': {e}")
+            # Per-message success logs stay at module level (noise in 10k-image
+            # jobs), but a failure on even a single publish deserves to surface
+            # in the job log — otherwise the failure path is invisible to users.
+            await self.log_async(logging.ERROR, f"Failed to publish task to stream for job '{job_id}': {e}")
             return False
 
     async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> list[PipelineProcessingTask]:
@@ -292,6 +469,35 @@ async def acknowledge_task(self, reply_subject: str) -> bool:
             logger.error(f"Failed to acknowledge task: {e}")
             return False
 
+    async def _log_final_consumer_stats(self, job_id: int) -> None:
+        """Log one forensic line about the consumer state before deletion.
+
+        This is the single most useful line in a post-mortem: it tells you how
+        many messages were delivered, how many were acked, and how many were
+        redelivered before the consumer vanished. Failures here must NOT block
+        cleanup — if the consumer or stream is already gone, just skip it.
+        """
+        if self.js is None:
+            return
+        stream_name = self._get_stream_name(job_id)
+        consumer_name = self._get_consumer_name(job_id)
+        try:
+            info = await asyncio.wait_for(
+                self.js.consumer_info(stream_name, consumer_name),
+                timeout=NATS_JETSTREAM_TIMEOUT,
+            )
+        except Exception as e:
+            # Broad catch is intentional here (unlike _ensure_consumer): at
+            # cleanup time we tolerate any failure — stream gone, consumer
+            # already deleted, auth, timeout — so the delete calls below
+            # still get a chance to run.
+            logger.debug(f"Could not fetch consumer info for {consumer_name} before deletion: {e}")
+            return
+        await self.log_async(
+            logging.INFO,
+            f"Finalizing NATS consumer {consumer_name} before deletion ({self._format_consumer_stats(info)})",
+        )
+
     async def delete_consumer(self, job_id: int) -> bool:
         """
         Delete the consumer for a job.
@@ -313,10 +519,10 @@ async def delete_consumer(self, job_id: int) -> bool:
                 self.js.delete_consumer(stream_name, consumer_name),
                 timeout=NATS_JETSTREAM_TIMEOUT,
             )
-            logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'")
+            await self.log_async(logging.INFO, f"Deleted NATS consumer {consumer_name} for job '{job_id}'")
             return True
         except Exception as e:
-            logger.error(f"Failed to delete consumer for job '{job_id}': {e}")
+            await self.log_async(logging.ERROR, f"Failed to delete NATS consumer for job '{job_id}': {e}")
             return False
 
     async def delete_stream(self, job_id: int) -> bool:
@@ -339,10 +545,10 @@ async def delete_stream(self, job_id: int) -> bool:
                 self.js.delete_stream(stream_name),
                 timeout=NATS_JETSTREAM_TIMEOUT,
             )
-            logger.info(f"Deleted stream {stream_name} for job '{job_id}'")
+            await self.log_async(logging.INFO, f"Deleted NATS stream {stream_name} for job '{job_id}'")
             return True
         except Exception as e:
-            logger.error(f"Failed to delete stream for job '{job_id}': {e}")
+            await self.log_async(logging.ERROR, f"Failed to delete NATS stream for job '{job_id}': {e}")
             return False
 
     async def _setup_advisory_stream(self):
@@ -482,6 +688,10 @@ async def cleanup_job_resources(self, job_id: int) -> bool:
         Returns:
             bool: True if successful, False otherwise
         """
+        # Log a forensic snapshot of the consumer state BEFORE we destroy it.
+        # This is the highest-leverage line for post-mortem investigations.
+        await self._log_final_consumer_stats(job_id)
+
         # Delete consumer first, then stream, then the durable DLQ advisory consumer
         consumer_deleted = await self.delete_consumer(job_id)
         stream_deleted = await self.delete_stream(job_id)
diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py
index da47f3429..d1d651450 100644
--- a/ami/ml/orchestration/tests/test_nats_queue.py
+++ b/ami/ml/orchestration/tests/test_nats_queue.py
@@ -1,6 +1,7 @@
 """Unit tests for TaskQueueManager."""
 
 import json
+import logging
 import unittest
 from unittest.mock import AsyncMock, MagicMock, patch
 
@@ -235,3 +236,273 @@ async def test_get_dead_letter_image_ids_no_messages(self):
 
         self.assertEqual(result, [])
         mock_psub.unsubscribe.assert_called_once()
+
+
+class TestTaskQueueManagerJobLogger(unittest.IsolatedAsyncioTestCase):
+    """Tests covering the job_logger lifecycle-mirroring behavior (#1220)."""
+
+    def _create_sample_task(self):
+        return PipelineProcessingTask(
+            id="task-1",
+            image_id="img-1",
+            image_url="https://example.com/image.jpg",
+        )
+
+    def _create_mock_nats_connection(self):
+        """Duplicate of the sibling helper — kept local so the two test classes
+        can evolve independently."""
+        nc = MagicMock()
+        nc.is_closed = False
+        nc.close = AsyncMock()
+        nc.flush = AsyncMock()
+
+        js = MagicMock()
+        js.stream_info = AsyncMock()
+        js.add_stream = AsyncMock()
+        js.add_consumer = AsyncMock()
+        js.consumer_info = AsyncMock()
+        js.publish = AsyncMock(return_value=MagicMock(seq=1))
+        js.pull_subscribe = AsyncMock()
+        js.delete_consumer = AsyncMock()
+        js.delete_stream = AsyncMock()
+
+        return nc, js
+
+    def _make_consumer_info(
+        self,
+        delivered=10,
+        ack_floor=8,
+        num_pending=2,
+        num_ack_pending=2,
+        num_redelivered=1,
+        max_deliver=5,
+        ack_wait=30,
+        max_ack_pending=1000,
+        deliver_policy="all",
+        ack_policy="explicit",
+    ):
+        """Build a ConsumerInfo-like MagicMock with nested SequenceInfo stubs
+        and a config sub-object for creation-time logging."""
+        info = MagicMock()
+        info.delivered = MagicMock(consumer_seq=delivered)
+        info.ack_floor = MagicMock(consumer_seq=ack_floor)
+        info.num_pending = num_pending
+        info.num_ack_pending = num_ack_pending
+        info.num_redelivered = num_redelivered
+        info.config = MagicMock(
+            max_deliver=max_deliver,
+            ack_wait=ack_wait,
+            max_ack_pending=max_ack_pending,
+            deliver_policy=deliver_policy,
+            ack_policy=ack_policy,
+        )
+        return info
+
+    def _make_stream_info(self, messages=5, last_seq=5):
+        info = MagicMock()
+        info.state = MagicMock(messages=messages, last_seq=last_seq)
+        return info
+
+    def _make_captured_logger(self) -> logging.Logger:
+        """A real Logger that captures to a list — better than MagicMock.log
+        because it exercises the actual `logger.log(level, msg)` dispatch and
+        surfaces any type surprises in the call site."""
+        log_logger = logging.getLogger(f"test.job_logger.{id(self)}")
+        log_logger.handlers.clear()
+        log_logger.setLevel(logging.DEBUG)
+
+        captured = []
+
+        class CaptureHandler(logging.Handler):
+            def emit(self, record):
+                captured.append((record.levelno, record.getMessage()))
+
+        log_logger.addHandler(CaptureHandler())
+        log_logger._captured = captured  # type: ignore[attr-defined]
+        return log_logger
+
+    async def test_create_stream_and_consumer_logs_to_job_logger(self):
+        """First publish on a brand-new job should log stream/consumer creation
+        to both the module logger and the passed-in job_logger."""
+        nc, js = self._create_mock_nats_connection()
+        js.stream_info.side_effect = nats.js.errors.NotFoundError()
+        js.consumer_info.side_effect = nats.js.errors.NotFoundError()
+        js.add_consumer = AsyncMock(return_value=self._make_consumer_info(delivered=0, ack_floor=0))
+
+        job_logger = self._make_captured_logger()
+        captured = job_logger._captured  # type: ignore[attr-defined]
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager(job_logger=job_logger) as manager:
+                await manager.publish_task(42, self._create_sample_task())
+
+        messages = [m for _, m in captured]
+        self.assertTrue(
+            any("Created NATS stream job_42" in m for m in messages),
+            f"expected stream-create log on job_logger, got {messages}",
+        )
+        self.assertTrue(
+            any("Created NATS consumer job-42-consumer" in m for m in messages),
+            f"expected consumer-create log on job_logger, got {messages}",
+        )
+        # Config snapshot should appear on the creation line.
+        self.assertTrue(
+            any("max_deliver=5" in m and "ack_policy=" in m for m in messages),
+            f"expected consumer config snapshot in log, got {messages}",
+        )
+
+    async def test_publish_success_does_not_spam_job_logger(self):
+        """After the first publish, subsequent publishes in the same session
+        must NOT emit new setup lines — per-message logging is forbidden for
+        10k-image jobs."""
+        nc, js = self._create_mock_nats_connection()
+        # First call hits NotFound (create path), subsequent calls succeed (reuse path)
+        js.stream_info.side_effect = [nats.js.errors.NotFoundError(), self._make_stream_info()]
+        js.consumer_info.side_effect = [nats.js.errors.NotFoundError(), self._make_consumer_info()]
+        js.add_consumer = AsyncMock(return_value=self._make_consumer_info(delivered=0, ack_floor=0))
+
+        job_logger = self._make_captured_logger()
+        captured = job_logger._captured  # type: ignore[attr-defined]
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager(job_logger=job_logger) as manager:
+                await manager.publish_task(42, self._create_sample_task())
+                captured_after_first = list(captured)
+                await manager.publish_task(42, self._create_sample_task())
+
+        new_messages = captured[len(captured_after_first) :]
+        # The second publish should not add any lifecycle log lines — dedup set
+        # should swallow them after the first publish for this job_id.
+        lifecycle_terms = ("Created NATS", "Reusing NATS")
+        for _, m in new_messages:
+            self.assertFalse(
+                any(term in m for term in lifecycle_terms),
+                f"unexpected lifecycle log on second publish: {m}",
+            )
+
+    async def test_reuse_stream_and_consumer_logs_with_stats(self):
+        """When stream and consumer already exist, the reuse line should include
+        a summary of current consumer state so forensic readers can tell whether
+        the queue is empty, backed up, or mid-redelivery."""
+        nc, js = self._create_mock_nats_connection()
+        js.stream_info.return_value = self._make_stream_info(messages=17, last_seq=17)
+        js.consumer_info.return_value = self._make_consumer_info(
+            delivered=12, ack_floor=10, num_pending=5, num_ack_pending=2, num_redelivered=3
+        )
+
+        job_logger = self._make_captured_logger()
+        captured = job_logger._captured  # type: ignore[attr-defined]
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager(job_logger=job_logger) as manager:
+                await manager.publish_task(99, self._create_sample_task())
+
+        messages = [m for _, m in captured]
+        self.assertTrue(
+            any("Reusing NATS stream job_99" in m and "messages=17" in m and "last_seq=17" in m for m in messages),
+            f"expected reuse-stream log with state, got {messages}",
+        )
+        self.assertTrue(
+            any(
+                "Reusing NATS consumer job-99-consumer" in m
+                and "delivered=12" in m
+                and "ack_floor=10" in m
+                and "num_pending=5" in m
+                and "num_redelivered=3" in m
+                for m in messages
+            ),
+            f"expected reuse-consumer log with stats, got {messages}",
+        )
+
+    async def test_cleanup_logs_final_consumer_stats_before_delete(self):
+        """cleanup_job_resources must emit a forensic snapshot of the consumer
+        state BEFORE the delete calls land. This is the single most useful line
+        for a post-mortem — without it, the consumer is already gone by the
+        time anyone investigates."""
+        nc, js = self._create_mock_nats_connection()
+        final_info = self._make_consumer_info(
+            delivered=434, ack_floor=420, num_pending=0, num_ack_pending=14, num_redelivered=5
+        )
+        js.consumer_info.return_value = final_info
+
+        job_logger = self._make_captured_logger()
+        captured = job_logger._captured  # type: ignore[attr-defined]
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager(job_logger=job_logger) as manager:
+                await manager.cleanup_job_resources(123)
+
+        messages = [m for _, m in captured]
+        finalizing_idx = None
+        delete_idx = None
+        for i, m in enumerate(messages):
+            if "Finalizing NATS consumer job-123-consumer" in m:
+                finalizing_idx = i
+            if delete_idx is None and "Deleted NATS consumer job-123-consumer" in m:
+                delete_idx = i
+
+        self.assertIsNotNone(finalizing_idx, f"expected forensic finalize-log, got {messages}")
+        self.assertIsNotNone(delete_idx, f"expected delete-log, got {messages}")
+        self.assertLess(
+            finalizing_idx,  # type: ignore[arg-type]
+            delete_idx,  # type: ignore[arg-type]
+            "finalize snapshot must log BEFORE the delete",
+        )
+        # The stats themselves should make it into the line.
+        final_line = messages[finalizing_idx]  # type: ignore[index]
+        for expected in ("delivered=434", "ack_floor=420", "num_redelivered=5"):
+            self.assertIn(expected, final_line)
+
+    async def test_cleanup_tolerates_missing_consumer(self):
+        """If the consumer is already gone when cleanup runs, the pre-delete
+        stats call must NOT raise or block — cleanup is called in failure
+        paths where the consumer may have already been deleted."""
+        nc, js = self._create_mock_nats_connection()
+        js.consumer_info.side_effect = nats.js.errors.NotFoundError()
+
+        job_logger = self._make_captured_logger()
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager(job_logger=job_logger) as manager:
+                # Must not raise.
+                result = await manager.cleanup_job_resources(77)
+
+        # delete_consumer / delete_stream are still called on the mock and
+        # return truthy, so overall cleanup is reported successful.
+        self.assertTrue(result)
+
+    async def test_publish_failure_surfaces_on_job_logger(self):
+        """A failed publish (which today only logs to the module logger) must
+        now also land on the job_logger so users see the failure in the UI."""
+        nc, js = self._create_mock_nats_connection()
+        js.stream_info.return_value = self._make_stream_info()
+        js.consumer_info.return_value = self._make_consumer_info()
+        js.publish = AsyncMock(side_effect=RuntimeError("simulated nats outage"))
+
+        job_logger = self._make_captured_logger()
+        captured = job_logger._captured  # type: ignore[attr-defined]
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager(job_logger=job_logger) as manager:
+                result = await manager.publish_task(55, self._create_sample_task())
+
+        self.assertFalse(result)
+        messages = [m for level, m in captured if level >= logging.ERROR]
+        self.assertTrue(
+            any("Failed to publish task" in m and "simulated nats outage" in m for m in messages),
+            f"expected publish failure on job_logger, got {messages}",
+        )
+
+    async def test_no_job_logger_falls_back_to_module_logger_only(self):
+        """When job_logger is None (e.g., module-level uses like advisory
+        listener), lifecycle logs must still be emitted to the module logger
+        without crashing on a None attribute access."""
+        nc, js = self._create_mock_nats_connection()
+        js.stream_info.side_effect = nats.js.errors.NotFoundError()
+        js.consumer_info.side_effect = nats.js.errors.NotFoundError()
+        js.add_consumer = AsyncMock(return_value=self._make_consumer_info(delivered=0, ack_floor=0))
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:  # no job_logger passed
+                # Must not raise.
+                await manager.publish_task(1, self._create_sample_task())
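+
+    async def test_log_async_skips_disabled_levels(self):
+        """Illustrative sketch, not part of the original change set: exercises
+        the ``isEnabledFor`` gate described in the ``log_async`` docstring.
+        When neither the module logger nor the job logger is enabled for the
+        level, the call should return before logging anywhere. Assumes the
+        module logger is patchable as ``ami.ml.orchestration.nats_queue.logger``
+        and that ``log_async`` never touches the NATS connection."""
+        job_logger = self._make_captured_logger()
+        job_logger.setLevel(logging.WARNING)
+        captured = job_logger._captured  # type: ignore[attr-defined]
+
+        # No context manager needed here: log_async only uses the loggers.
+        manager = TaskQueueManager(job_logger=job_logger)
+        with patch("ami.ml.orchestration.nats_queue.logger") as module_logger:
+            module_logger.isEnabledFor.return_value = False
+            await manager.log_async(logging.DEBUG, "should be dropped")
+            module_logger.log.assert_not_called()
+
+        self.assertEqual(captured, [])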