diff --git a/src/backend/app.py b/src/backend/app.py index 4685b7fd1..37748f976 100644 --- a/src/backend/app.py +++ b/src/backend/app.py @@ -171,7 +171,7 @@ async def handle_chat(): # Extract request fields conversation_id = data.get("conversation_id") or str(uuid.uuid4()) - user_id = data.get("user_id", "anonymous") + user_id = data.get("user_id") or "anonymous" message = data.get("message", "") action = data.get("action") payload = data.get("payload", {}) @@ -337,7 +337,7 @@ async def _handle_parse_brief( if not has_existing_title: title_service = get_title_service() - generated_title = await title_service.generate_title(message) + generated_title = await title_service.generate_title(message, user_id=user_id, conversation_id=conversation_id) await cosmos_service.add_message_to_conversation( conversation_id=conversation_id, @@ -353,7 +353,7 @@ async def _handle_parse_brief( logger.exception(f"Failed to save message to CosmosDB: {e}") # Parse the brief - brief, questions, blocked = await orchestrator.parse_brief(message) + brief, questions, blocked = await orchestrator.parse_brief(message, user_id=user_id, conversation_id=conversation_id) if blocked: track_event_if_configured("Error_RAI_Check_Failed", {"conversation_id": conversation_id, "user_id": user_id, "status": "Brief parse blocked by RAI"}) @@ -537,7 +537,7 @@ async def _handle_refine_brief( logger.exception(f"Failed to save refinement message: {e}") # Use orchestrator to refine the brief - brief, questions, blocked = await orchestrator.parse_brief(message) + brief, questions, blocked = await orchestrator.parse_brief(message, user_id=user_id, conversation_id=conversation_id) if blocked: track_event_if_configured("Error_RAI_Check_Failed", {"conversation_id": conversation_id, "user_id": user_id, "status": "Brief refinement blocked by RAI"}) @@ -687,7 +687,9 @@ async def _handle_search_products( result = await orchestrator.select_products( request_text=message, current_products=current_products, - 
available_products=available_products + available_products=available_products, + user_id=user_id, + conversation_id=conversation_id ) # Save assistant response @@ -943,7 +945,9 @@ async def _run_regeneration_task( modification_request=modification_request, brief=brief, products=products_data, - previous_image_prompt=previous_image_prompt + previous_image_prompt=previous_image_prompt, + user_id=user_id, + conversation_id=conversation_id ) # Check for RAI block @@ -1132,7 +1136,7 @@ async def _handle_general_chat( if not has_existing_title: title_service = get_title_service() - generated_title = await title_service.generate_title(message) + generated_title = await title_service.generate_title(message, user_id=user_id, conversation_id=conversation_id) await cosmos_service.add_message_to_conversation( conversation_id=conversation_id, @@ -1151,7 +1155,8 @@ async def _handle_general_chat( response_content = "" async for response in orchestrator.process_message( message=message, - conversation_id=conversation_id + conversation_id=conversation_id, + user_id=user_id ): if response.get("content"): response_content += response.get("content", "") @@ -1197,7 +1202,9 @@ async def _run_generation_task(task_id: str, brief: CreativeBrief, products_data response = await orchestrator.generate_content( brief=brief, products=products_data, - generate_images=generate_images + generate_images=generate_images, + user_id=user_id, + conversation_id=conversation_id ) logger.info(f"Generation task {task_id} completed. 
Response keys: {list(response.keys()) if response else 'None'}") @@ -1303,7 +1310,7 @@ async def start_generation(): products_data = data.get("products", []) generate_images = data.get("generate_images", True) conversation_id = data.get("conversation_id") or str(uuid.uuid4()) - user_id = data.get("user_id", "anonymous") + user_id = data.get("user_id") or "anonymous" try: brief = CreativeBrief(**brief_data) diff --git a/src/backend/llm_token_telemetry.py b/src/backend/llm_token_telemetry.py new file mode 100644 index 000000000..745fceafd --- /dev/null +++ b/src/backend/llm_token_telemetry.py @@ -0,0 +1,1043 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Cross-accelerator LLM token-usage telemetry helpers. + +A single, dependency-light helper module that can be dropped into any Microsoft +Solution Accelerator to capture LLM token usage and emit standardized custom +events to Application Insights. + +Why this file exists +-------------------- +Seven solution accelerators have independently shipped near-identical +``token_usage_utils.py`` modules (see PRs: content-generation #860, CKM #933, +content-processing #586, Container-Migration #257, agentic-data-foundation +#383, customer-chatbot #218, MACAE #1003). They all: + +* extract token counts from agent_framework / Azure OpenAI responses, +* emit the same three custom events (``LLM_Token_Usage_Summary``, + ``LLM_Agent_Token_Usage``, ``LLM_Model_Token_Usage``), +* defensively swallow telemetry errors, +* duplicate the same KQL queries and Azure Workbook. + +This module consolidates the union of those behaviours behind one stable API +so each accelerator can replace its bespoke helper with an import. 
+ +Public API +---------- +- ``TokenUsage`` -- immutable dataclass for counts +- ``extract_usage(obj)`` -- agent_framework run result / message +- ``extract_usage_from_dict(d)`` -- raw dict from any SDK +- ``extract_usage_from_stream_chunk`` -- streaming chunks +- ``extract_realtime_usage(resp)`` -- Azure AI Voice Live response.done +- ``TokenUsageEmitter`` -- emits the three events + optional + per-user / per-team / speech events +- ``TokenUsageScope`` -- context-manager that accumulates and + auto-emits on exit +- ``track_tokens`` -- decorator wrapper around the scope + +Design rules +------------ +* Telemetry NEVER raises. Extraction failures return ``None``; emission + failures are logged at WARNING. +* No hard dependency on ``azure-monitor-events-extension``; if absent the + emitter degrades to logging only. +* Arbitrary correlation dimensions are passed as ``**dimensions`` kwargs and + surface verbatim as custom-event properties. This is how each accelerator + attaches its own keys (``conversation_id``, ``process_id``, ``team_name``, + ``file_name``, ``tenant``, etc.) without forking the helper. +""" +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import random +import time +from contextlib import AbstractContextManager +from dataclasses import dataclass, field +from typing import Any, Callable, Iterable, Mapping, Optional +from unittest.mock import NonCallableMock + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Event-name constants -- keep these stable; KQL queries and workbooks bind +# to these exact strings. 
+# --------------------------------------------------------------------------- +EVENT_SUMMARY = "LLM_Token_Usage_Summary" +EVENT_AGENT = "LLM_Agent_Token_Usage" +EVENT_MODEL = "LLM_Model_Token_Usage" +EVENT_USER = "LLM_User_Token_Usage" +EVENT_TEAM = "LLM_Team_Token_Usage" +EVENT_SPEECH = "Speech_Usage" + + +# Token-count field aliases observed across model providers / SDK versions. +_INPUT_KEYS = ( + "input_token_count", + "input_tokens", + "prompt_tokens", + "promptTokens", +) +_OUTPUT_KEYS = ( + "output_token_count", + "output_tokens", + "completion_tokens", + "completionTokens", +) +_TOTAL_KEYS = ( + "total_token_count", + "total_tokens", + "totalTokens", +) + + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- +@dataclass(frozen=True) +class TokenUsage: + """Normalized token-usage record.""" + + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + + # Optional realtime / voice fields (None unless populated) + input_audio_tokens: Optional[int] = None + input_text_tokens: Optional[int] = None + input_cached_tokens: Optional[int] = None + output_audio_tokens: Optional[int] = None + output_text_tokens: Optional[int] = None + + @property + def has_any(self) -> bool: + return bool(self.input_tokens or self.output_tokens or self.total_tokens) + + def __add__(self, other: "TokenUsage") -> "TokenUsage": + if not isinstance(other, TokenUsage): + return NotImplemented + + def _sum(a: Optional[int], b: Optional[int]) -> Optional[int]: + if a is None and b is None: + return None + return (a or 0) + (b or 0) + + return TokenUsage( + input_tokens=self.input_tokens + other.input_tokens, + output_tokens=self.output_tokens + other.output_tokens, + total_tokens=self.total_tokens + other.total_tokens, + input_audio_tokens=_sum(self.input_audio_tokens, other.input_audio_tokens), + input_text_tokens=_sum(self.input_text_tokens, 
def _to_int(value: Any, default: int = 0) -> int:
    """Best-effort conversion of *value* to ``int``; never raises.

    Args:
        value: Candidate count -- an ``int``, ``float``, numeric string, or
            any other object accepted by ``int()``.
        default: Returned when *value* is ``None``, a ``bool``, or cannot be
            converted.

    Returns:
        The converted integer, or *default* when conversion is impossible.
    """
    # bool is an int subclass; True/False are flags, not token counts.
    if value is None or isinstance(value, bool):
        return default
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        value = value.strip()
    # A single guarded int() call replaces the former ``str.isdigit()`` fast
    # path: isdigit() is True for characters such as superscript digits that
    # int() rejects, which made the unguarded conversion raise.
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # ValueError: non-numeric text or float NaN; OverflowError: +/-inf.
        return default
def _first_count(usage_obj: Any, keys: Iterable[str]) -> int:
    """Return the first non-zero count found under any alias in *keys*, else 0."""
    for key in keys:
        count = _to_int(_get(usage_obj, key))
        if count:
            return count
    return 0


def _read_counts(usage_obj: Any) -> Optional[TokenUsage]:
    """Read ``input/output/total`` counts from a usage-bearing object or dict.

    Each count is looked up under the aliases in ``_INPUT_KEYS``,
    ``_OUTPUT_KEYS`` and ``_TOTAL_KEYS`` in order. An alias whose value
    coerces to 0 (missing, ``"n/a"``, ``False`` ...) no longer ends the scan,
    so a valid count stored under a later alias is still found.

    Args:
        usage_obj: Object or mapping that may carry token counts, or ``None``.

    Returns:
        A ``TokenUsage`` when at least one count is non-zero, else ``None``.
    """
    if usage_obj is None:
        return None

    inp = _first_count(usage_obj, _INPUT_KEYS)
    out = _first_count(usage_obj, _OUTPUT_KEYS)
    tot = _first_count(usage_obj, _TOTAL_KEYS)

    # Derive the total when the provider reported only the two halves.
    if not tot:
        tot = inp + out
    if not (inp or out or tot):
        return None
    return TokenUsage(input_tokens=inp, output_tokens=out, total_tokens=tot)
def extract_realtime_usage(response_obj: Any) -> Optional[TokenUsage]:
    """Extract usage from an Azure AI Voice Live ``response.done`` payload.

    Reads ``usage.input_tokens`` / ``output_tokens`` / ``total_tokens`` and,
    when present, the audio / text / cached sub-counts found under
    ``usage.input_token_details`` and ``usage.output_token_details``.

    Args:
        response_obj: Object or mapping exposing a ``usage`` member.

    Returns:
        A ``TokenUsage`` when at least one count is non-zero, otherwise
        ``None``. Never raises: any unexpected payload shape is logged at
        DEBUG and reported as ``None``, matching ``extract_usage``.
    """
    try:
        usage = _get(response_obj, "usage")
        if usage is None:
            return None

        inp = _to_int(_get(usage, "input_tokens"))
        out = _to_int(_get(usage, "output_tokens"))
        tot = _to_int(_get(usage, "total_tokens"))
        if tot == 0 and (inp or out):
            tot = inp + out  # total omitted by the service; derive it

        in_details = _get(usage, "input_token_details") or {}
        out_details = _get(usage, "output_token_details") or {}

        record = TokenUsage(
            input_tokens=inp,
            output_tokens=out,
            total_tokens=tot,
            input_audio_tokens=_to_int_or_none(_get(in_details, "audio_tokens")),
            input_text_tokens=_to_int_or_none(_get(in_details, "text_tokens")),
            input_cached_tokens=_to_int_or_none(_get(in_details, "cached_tokens")),
            output_audio_tokens=_to_int_or_none(_get(out_details, "audio_tokens")),
            output_text_tokens=_to_int_or_none(_get(out_details, "text_tokens")),
        )
        sub_counts = (
            record.input_audio_tokens,
            record.input_text_tokens,
            record.input_cached_tokens,
            record.output_audio_tokens,
            record.output_text_tokens,
        )
        # Only return a record if at least one non-zero count surfaced.
        if record.has_any or any(sub_counts):
            return record
        return None
    except Exception as exc:  # telemetry must never break the caller
        logger.debug("extract_realtime_usage failed: %s", exc, exc_info=True)
        return None
class TokenUsageEmitter:
    """Emit standardized token-usage custom events.

    Parameters
    ----------
    connection_string:
        Application Insights connection string. If ``None`` (default), the
        ``APPLICATIONINSIGHTS_CONNECTION_STRING`` env var is consulted. When
        no connection string is configured the emitter logs and skips the
        ``track_event`` call.
    static_dimensions:
        Properties merged into every event (e.g. ``{"app": "customer-chatbot"}``).
    event_sink:
        Callable ``(event_name, props_dict) -> None``. Defaults to
        ``azure.monitor.events.extension.track_event``. Override in tests.
    pricing:
        Optional mapping ``{model_deployment_name -> (usd_per_1k_input,
        usd_per_1k_output)}``. When provided, an ``estimated_cost_usd``
        property is attached to agent / model / summary / speech events.
        Model lookup is case-insensitive.
    user_id_hasher:
        Optional callable ``str -> str`` applied to any ``user_id`` value
        before it leaves the emitter. Applied to both
        ``static_dimensions['user_id']`` (at construction) and per-call
        ``user_id`` kwargs.
    sample_rate:
        Fraction of high-cardinality events (agent / model / user / team /
        speech) actually shipped, clamped to ``[0.0, 1.0]``. The summary
        event always fires regardless of ``sample_rate``. Defaults to
        ``1.0`` (no sampling).
    logger:
        Override the module logger.

    Notes
    -----
    No emit method raises. When a caller-supplied dimension shares a key
    with a property the emitter computes itself (``total_tokens``,
    ``agent_name``, ``sample_rate`` ...), the emitter's value wins instead of
    the call failing with ``TypeError``.
    """

    # Caller dimensions that would clash with emit_summary()'s own keyword
    # parameters when emit_all() forwards ``**dimensions`` to it.
    _SUMMARY_RESERVED = ("agent_count", "model_count", "primary_model")

    def __init__(
        self,
        *,
        connection_string: Optional[str] = None,
        static_dimensions: Optional[Mapping[str, Any]] = None,
        event_sink: "Optional[EventSink]" = None,
        pricing: Optional[Mapping[str, tuple[float, float]]] = None,
        user_id_hasher: Optional[Callable[[str], str]] = None,
        sample_rate: float = 1.0,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self._cs = (
            connection_string
            if connection_string is not None
            else os.getenv("APPLICATIONINSIGHTS_CONNECTION_STRING")
        )
        self._sink = event_sink if event_sink is not None else _default_event_sink()
        self._log = logger or logging.getLogger(__name__)

        # PII hashing applied to user_id everywhere.
        self._user_id_hasher = user_id_hasher

        # Clamp sampling to [0, 1]; unparsable values fall back to "no sampling".
        try:
            sr = float(sample_rate)
        except (TypeError, ValueError):
            sr = 1.0
        self._sample_rate = max(0.0, min(1.0, sr))

        # Case-insensitive pricing lookup. Values stored as an (in, out) tuple.
        self._pricing: dict[str, tuple[float, float]] = {}
        for model, rates in (pricing or {}).items():
            if not model or rates is None:
                continue
            try:
                inp, out = rates
                self._pricing[str(model).lower()] = (float(inp), float(out))
            except (TypeError, ValueError):
                self._log.warning("Ignoring malformed pricing entry: %s=%r", model, rates)

        # Pre-stringify static dims once. user_id (if present) is hashed here
        # so the raw value is never retained on the emitter.
        raw_static = dict(static_dimensions or {})
        if "user_id" in raw_static:
            raw_static["user_id"] = self._apply_user_id_hash(raw_static["user_id"])
        self._static: dict[str, str] = {
            k: ("" if v is None else str(v)) for k, v in raw_static.items()
        }

        # Wall-clock nanoseconds spent inside emit(), exposed via perf_stats().
        # An individual emit slower than ``perf_slow_emit_threshold_ms`` is
        # logged at WARNING.
        self._perf_total_ns: int = 0
        self._perf_emit_count: int = 0
        self._perf_max_ns: int = 0
        self.perf_slow_emit_threshold_ms: float = 50.0

    # -- public surface ---------------------------------------------------
    @property
    def enabled(self) -> bool:
        """True when both a connection string and an event sink are available."""
        return bool(self._cs) and self._sink is not None

    @property
    def sample_rate(self) -> float:
        """Effective (clamped) sampling fraction."""
        return self._sample_rate

    # -- internal helpers -------------------------------------------------
    def _apply_user_id_hash(self, value: Any) -> Any:
        """Apply the configured user_id_hasher; never raises.

        Returns *value* unchanged when it is ``None`` / empty, when no hasher
        is configured, or when the hasher itself raises.
        """
        if value is None or value == "" or self._user_id_hasher is None:
            return value
        try:
            return self._user_id_hasher(str(value))
        except Exception as exc:  # never let hashing break telemetry
            self._log.warning("user_id_hasher raised: %s", exc)
            return value

    def _should_sample(self) -> bool:
        """Sampling decision for high-cardinality events."""
        if self._sample_rate >= 1.0:
            return True
        if self._sample_rate <= 0.0:
            return False
        return random.random() < self._sample_rate

    def _cost_props(
        self, model_deployment_name: Optional[str], usage: "TokenUsage"
    ) -> dict[str, str]:
        """Return ``{'estimated_cost_usd': '...'}`` when pricing is configured
        for the given model, else ``{}``. 6-decimal formatting."""
        if not self._pricing or not model_deployment_name:
            return {}
        # str() mirrors the key normalisation in __init__ and keeps a
        # non-string deployment name from raising AttributeError.
        rate = self._pricing.get(str(model_deployment_name).lower())
        if not rate:
            return {}
        inp_rate, out_rate = rate
        cost = (usage.input_tokens * inp_rate + usage.output_tokens * out_rate) / 1000.0
        return {"estimated_cost_usd": f"{cost:.6f}"}

    def _summary_cost_props(
        self,
        primary_model: Optional[str],
        additional_agents: Mapping[str, str],
        usage: "TokenUsage",
    ) -> dict[str, str]:
        """Best-effort cost for the summary event.

        The full usage is charged at the primary model's rate; when that
        model has no configured rate, the first additional agent's model
        with a known rate is used. Returns ``{}`` when no rate is known.
        """
        if primary_model:
            cost = self._cost_props(primary_model, usage)
            if cost:
                return cost
        for m in additional_agents.values():
            cost = self._cost_props(m, usage)
            if cost:
                return cost
        return {}

    def _emit_typed(
        self,
        event_name: str,
        dimensions: Mapping[str, Any],
        *typed: Mapping[str, Any],
    ) -> None:
        """Merge caller dimensions with emitter-computed props and emit once.

        Caller dimensions are applied first so the standardized fields in
        *typed* win on a key clash. Splatting both into one call would raise
        ``TypeError`` on a duplicate keyword.
        """
        merged: dict[str, Any] = dict(dimensions)
        for block in typed:
            merged.update(block)
        # ``event_name`` is emit()'s own parameter and cannot be forwarded
        # as a keyword property.
        merged.pop("event_name", None)
        self.emit(event_name, **merged)

    def emit(self, event_name: str, **dimensions: Any) -> None:
        """Low-level: emit an event with arbitrary properties.

        Non-string values are stringified. ``None`` values are dropped. Any
        ``user_id`` value is passed through the configured hasher.
        Never raises. Wall-clock duration is recorded for performance audit
        (see :meth:`perf_stats`).
        """
        start_ns = time.perf_counter_ns()
        try:
            props = dict(self._static)  # cheap shallow copy of pre-stringified dims
            for k, v in dimensions.items():
                if v is None:
                    continue
                if k == "user_id":
                    v = self._apply_user_id_hash(v)
                    if v is None or v == "":
                        continue
                props[k] = v if isinstance(v, str) else str(v)

            if not self.enabled:
                self._log.debug(
                    "App Insights not configured -- skipping event %s (%s)",
                    event_name, props,
                )
                return
            try:
                self._sink(event_name, props)  # type: ignore[misc]
            except Exception as exc:  # never break the caller
                self._log.warning("track_event(%s) failed: %s", event_name, exc)
        except Exception as exc:
            # Property building failed (e.g. a value whose __str__ raises).
            self._log.warning("Token telemetry emit(%s) failed: %s", event_name, exc)
        finally:
            elapsed_ns = time.perf_counter_ns() - start_ns
            self._perf_total_ns += elapsed_ns
            self._perf_emit_count += 1
            if elapsed_ns > self._perf_max_ns:
                self._perf_max_ns = elapsed_ns
            elapsed_ms = elapsed_ns / 1_000_000.0
            if elapsed_ms > self.perf_slow_emit_threshold_ms:
                self._log.warning(
                    "Token telemetry emit slow: event=%s duration_ms=%.3f",
                    event_name, elapsed_ms,
                )
            else:
                self._log.debug(
                    "Token telemetry emit: event=%s duration_ms=%.3f",
                    event_name, elapsed_ms,
                )

    # -- performance audit ------------------------------------------------
    def perf_stats(self) -> dict[str, float]:
        """Return cumulative telemetry-overhead stats since construction
        (or since :meth:`reset_perf_stats`).

        Keys:
            ``emit_count`` -- number of ``emit`` calls
            ``total_ms``   -- total wall-clock time spent inside ``emit``
            ``avg_ms``     -- mean per-call duration
            ``max_ms``     -- slowest single emit observed
        """
        count = self._perf_emit_count
        total_ms = self._perf_total_ns / 1_000_000.0
        return {
            "emit_count": float(count),
            "total_ms": total_ms,
            "avg_ms": (total_ms / count) if count else 0.0,
            "max_ms": self._perf_max_ns / 1_000_000.0,
        }

    def reset_perf_stats(self) -> None:
        """Zero the perf counters (useful for tests and load-tests)."""
        self._perf_total_ns = 0
        self._perf_emit_count = 0
        self._perf_max_ns = 0

    # -- typed convenience emitters --------------------------------------
    def emit_agent(
        self,
        *,
        agent_name: str,
        model_deployment_name: str,
        usage: "TokenUsage",
        **dimensions: Any,
    ) -> None:
        """Emit the per-agent event (sampled; skipped when usage is empty)."""
        if not usage.has_any or not self._should_sample():
            return
        self._emit_typed(
            EVENT_AGENT,
            dimensions,
            {"agent_name": agent_name, "model_deployment_name": model_deployment_name},
            usage.to_event_props(),
            self._cost_props(model_deployment_name, usage),
        )

    def emit_model(
        self,
        *,
        model_deployment_name: str,
        usage: "TokenUsage",
        **dimensions: Any,
    ) -> None:
        """Emit the per-model event (sampled; skipped when usage is empty)."""
        if not usage.has_any or not self._should_sample():
            return
        self._emit_typed(
            EVENT_MODEL,
            dimensions,
            {"model_deployment_name": model_deployment_name},
            usage.to_event_props(),
            self._cost_props(model_deployment_name, usage),
        )

    def emit_user(
        self,
        *,
        user_id: str,
        usage: "TokenUsage",
        **dimensions: Any,
    ) -> None:
        """Emit the per-user event (sampled; needs a non-empty ``user_id``)."""
        if not usage.has_any or not user_id or not self._should_sample():
            return
        self._emit_typed(
            EVENT_USER,
            dimensions,
            {"user_id": user_id},
            usage.to_event_props(),
        )

    def emit_team(
        self,
        *,
        team_name: str,
        usage: "TokenUsage",
        **dimensions: Any,
    ) -> None:
        """Emit the per-team event (sampled; needs a non-empty ``team_name``)."""
        if not usage.has_any or not team_name or not self._should_sample():
            return
        self._emit_typed(
            EVENT_TEAM,
            dimensions,
            {"team_name": team_name},
            usage.to_event_props(),
        )

    def emit_summary(
        self,
        *,
        usage: "TokenUsage",
        agent_count: int = 1,
        model_count: int = 1,
        primary_model: Optional[str] = None,
        additional_agents: Optional[Mapping[str, str]] = None,
        **dimensions: Any,
    ) -> None:
        """The summary event always fires (ignores ``sample_rate``) so per-
        request totals remain accurate even when high-cardinality events are
        sampled. Skipped only when *usage* carries no counts."""
        if not usage.has_any:
            return
        # Summary historically uses ``total_input_tokens`` / ``total_output_tokens``
        # field names; preserve that wire format for backward compatibility.
        props = {
            "total_input_tokens": str(usage.input_tokens),
            "total_output_tokens": str(usage.output_tokens),
            "total_tokens": str(usage.total_tokens),
            "agent_count": str(agent_count),
            "model_count": str(model_count),
            "sample_rate": f"{self._sample_rate:.4f}",
        }
        # Carry over the remaining usage properties (plain and realtime
        # sub-counts) without overwriting the summary-specific names above.
        for k, v in usage.to_event_props().items():
            props.setdefault(k, v)
        # Optional total cost.
        props.update(self._summary_cost_props(primary_model, additional_agents or {}, usage))
        self._emit_typed(EVENT_SUMMARY, dimensions, props)

    def emit_speech(
        self,
        *,
        model_deployment_name: str,
        source: str,
        usage: "TokenUsage",
        **dimensions: Any,
    ) -> None:
        """Voice-Live / realtime speech usage event (sampled)."""
        if not self._should_sample():
            return
        self._emit_typed(
            EVENT_SPEECH,
            dimensions,
            {"model_deployment_name": model_deployment_name, "source": source},
            usage.to_event_props(),
            self._cost_props(model_deployment_name, usage),
        )

    # -- combined emit: summary + agent + per-distinct-model ---------------
    def emit_all(
        self,
        *,
        agent_name: str,
        model_deployment_name: str,
        usage: "TokenUsage",
        additional_agents: Optional[Mapping[str, str]] = None,
        emit_user_event: bool = False,
        emit_team_event: bool = False,
        **dimensions: Any,
    ) -> None:
        """Convenience: emit agent, one model event per distinct model
        deployment, optional user/team events, and finally the summary.

        ``additional_agents`` maps sub-agent name -> its model deployment name
        so callers can describe orchestrators that involve multiple agents.

        ``emit_user_event`` / ``emit_team_event`` opt in to the user/team
        events; ``user_id`` / ``team_name`` must be present in dimensions for
        those to fire.
        """
        if not usage.has_any:
            return

        agents = {agent_name: model_deployment_name}
        if additional_agents:
            agents.update({k: v for k, v in additional_agents.items() if k})
        # dict.fromkeys de-duplicates while keeping first-seen order, so model
        # events are emitted deterministically (a set iterates in hash order).
        models = list(dict.fromkeys(m for m in agents.values() if m))

        # Time the whole batch so the summary can report telemetry overhead.
        batch_start_ns = time.perf_counter_ns()

        # Summary is deferred until last so the batch overhead can be stamped on it.
        self.emit_agent(
            agent_name=agent_name,
            model_deployment_name=model_deployment_name,
            usage=usage,
            **dimensions,
        )
        for model in models:
            self.emit_model(
                model_deployment_name=model,
                usage=usage,
                **dimensions,
            )
        # NOTE(review): user/team events carry only agent and model, not the
        # other caller dimensions (e.g. conversation_id) -- confirm intended.
        if emit_user_event and dimensions.get("user_id"):
            self.emit_user(
                user_id=str(dimensions["user_id"]),
                usage=usage,
                agent_name=agent_name,
                model_deployment_name=model_deployment_name,
            )
        if emit_team_event and dimensions.get("team_name"):
            self.emit_team(
                team_name=str(dimensions["team_name"]),
                usage=usage,
                agent_name=agent_name,
                model_deployment_name=model_deployment_name,
            )

        batch_overhead_ms = (time.perf_counter_ns() - batch_start_ns) / 1_000_000.0
        # Drop caller dimensions that collide with emit_summary()'s keyword
        # parameters, and let the measured overhead override any caller key.
        summary_dims = {
            k: v for k, v in dimensions.items() if k not in self._SUMMARY_RESERVED
        }
        summary_dims["telemetry_overhead_ms"] = f"{batch_overhead_ms:.3f}"
        self.emit_summary(
            usage=usage,
            agent_count=len(agents),
            model_count=len(models) or 1,
            primary_model=model_deployment_name,
            additional_agents=additional_agents,
            **summary_dims,
        )

        self._log.debug(
            "[TOKEN USAGE] agent=%s model=%s input=%d output=%d total=%d",
            agent_name,
            model_deployment_name,
            usage.input_tokens,
            usage.output_tokens,
            usage.total_tokens,
        )
# ---------------------------------------------------------------------------
# Scope / decorator sugar
# ---------------------------------------------------------------------------
@dataclass
class TokenUsageScope(AbstractContextManager):
    """Accumulate usage across multiple results, then emit on exit.

    The scope starts from an empty :class:`TokenUsage`, grows it through
    :meth:`add`, :meth:`add_usage` and :meth:`add_chunks`, and passes the
    running total to ``emitter.emit_all`` when the ``with`` block ends,
    whether or not the block raised. Exceptions from the body are never
    suppressed.

    Example::

        with TokenUsageScope(emitter,
                             agent_name="chat",
                             model_deployment_name=cfg.model,
                             user_id=user_id) as scope:
            result = await agent.run(prompt)
            scope.add(result)        # extracts and accumulates
    """

    emitter: TokenUsageEmitter
    agent_name: str
    model_deployment_name: str
    dimensions: dict[str, Any] = field(default_factory=dict)
    additional_agents: dict[str, str] = field(default_factory=dict)
    emit_user_event: bool = False
    emit_team_event: bool = False
    usage: TokenUsage = field(default_factory=TokenUsage)

    # The explicit ``__init__`` below is the one that runs: ``@dataclass``
    # does not replace an ``__init__`` the class defines itself. It exists so
    # arbitrary keyword arguments can be collected into ``dimensions``.
    def __init__(
        self,
        emitter: TokenUsageEmitter,
        *,
        agent_name: str,
        model_deployment_name: str,
        additional_agents: Optional[Mapping[str, str]] = None,
        emit_user_event: bool = False,
        emit_team_event: bool = False,
        **dimensions: Any,
    ) -> None:
        self.emitter = emitter
        self.agent_name = agent_name
        self.model_deployment_name = model_deployment_name
        # Copy the caller's mappings so later mutation on either side does
        # not leak into the other.
        self.additional_agents = dict(additional_agents or {})
        self.emit_user_event = emit_user_event
        self.emit_team_event = emit_team_event
        self.dimensions = dict(dimensions)
        self.usage = TokenUsage()
        # Wall-clock nanoseconds spent inside extraction (``add*``) and the
        # final ``__exit__`` emit, respectively. Surfaced for callers that
        # want to verify the helper doesn't add measurable latency. Available
        # as ``scope.extract_ms`` / ``scope.emit_ms`` after the scope closes.
        self._extract_ns: int = 0
        self._emit_ns: int = 0

    # -- accumulation -----------------------------------------------------
    def add(self, source: Any) -> Optional[TokenUsage]:
        """Extract usage from any supported shape and add to the running total.

        Returns whatever ``extract_usage_from_stream_chunk`` produced for
        *source*. Never raises -- extraction failures return ``None`` and are
        logged at DEBUG. Time spent here is added to :attr:`extract_ms`.
        """
        start_ns = time.perf_counter_ns()
        try:
            found = extract_usage_from_stream_chunk(source)
        except Exception as exc:  # belt + braces; extractors are already safe
            logger.debug("TokenUsageScope.add failed: %s", exc, exc_info=True)
            return None
        finally:
            # Runs on both the success and the failure path.
            self._extract_ns += time.perf_counter_ns() - start_ns
        if found:
            self.usage = self.usage + found
        return found

    def add_usage(self, usage: TokenUsage) -> None:
        """Add an already-extracted :class:`TokenUsage` to the running total."""
        self.usage = self.usage + usage

    def add_chunks(self, chunks: Iterable[Any]) -> None:
        """Call :meth:`add` on every item of *chunks*."""
        for c in chunks:
            self.add(c)

    # -- timing properties -----------------------------------------------
    @property
    def extract_ms(self) -> float:
        """Total ms spent inside :meth:`add` / :meth:`add_chunks`."""
        return self._extract_ns / 1_000_000.0

    @property
    def emit_ms(self) -> float:
        """Total ms spent in the on-exit emit batch."""
        return self._emit_ns / 1_000_000.0

    @property
    def total_overhead_ms(self) -> float:
        """Total telemetry overhead added by this scope (extract + emit)."""
        return self.extract_ms + self.emit_ms

    # -- context manager --------------------------------------------------
    def __exit__(self, exc_type, exc, tb) -> None:
        """Emit the accumulated usage; never suppress the body's exception."""
        # Always emit (best-effort) regardless of exception status.
        emit_start_ns = time.perf_counter_ns()
        try:
            self.emitter.emit_all(
                agent_name=self.agent_name,
                model_deployment_name=self.model_deployment_name,
                usage=self.usage,
                additional_agents=self.additional_agents,
                emit_user_event=self.emit_user_event,
                emit_team_event=self.emit_team_event,
                **self.dimensions,
            )
        except Exception as emit_exc:  # pragma: no cover - belt + braces
            logger.warning("TokenUsageScope emit failed: %s", emit_exc)
        finally:
            self._emit_ns += time.perf_counter_ns() - emit_start_ns
            logger.debug(
                "TokenUsageScope overhead: agent=%s extract_ms=%.3f "
                "emit_ms=%.3f total_ms=%.3f",
                self.agent_name,
                self.extract_ms,
                self.emit_ms,
                self.total_overhead_ms,
            )
        return None  # do not suppress exceptions


def track_tokens(
    emitter: TokenUsageEmitter,
    *,
    agent_name: str,
    model_deployment_name: str,
    dimension_args: Optional[Mapping[str, str]] = None,
    additional_agents: Optional[Mapping[str, str]] = None,
    emit_user_event: bool = False,
    emit_team_event: bool = False,
):
    """Decorator: wrap an async or sync function that returns an LLM result.

    Each call of the wrapped function runs inside a fresh
    :class:`TokenUsageScope`; the function's return value is passed to
    ``scope.add`` and then returned unchanged.

    ``dimension_args`` maps emitted-property-name -> name of a parameter of
    the wrapped function, so per-call values (e.g. ``user_id``) are forwarded
    to the event. The value is picked up whether the caller passes it by
    keyword or positionally; arguments that are ``None`` or not supplied
    (including ones left at their default) are omitted.

    Example::

        @track_tokens(emitter,
                      agent_name="chat",
                      model_deployment_name=settings.model,
                      dimension_args={"user_id": "user_id",
                                      "session_id": "session_id"})
        async def run_chat(prompt, *, user_id, session_id): ...
    """
    # Function-scope import: only this decorator needs ``inspect``.
    import inspect

    dim_args = dict(dimension_args or {})
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    def _decorator(fn: Callable[..., Any]):
        # Resolve the signature once, at decoration time, so positional
        # arguments can be mapped back to parameter names on every call.
        try:
            signature = inspect.signature(fn)
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature; fall back
            # to keyword-only lookup for those.
            signature = None

        def _named_arguments(
            call_args: tuple, call_kwargs: Mapping[str, Any]
        ) -> dict[str, Any]:
            """Return the call's arguments keyed by parameter name."""
            named = dict(call_kwargs)
            if signature is None or not call_args:
                return named
            try:
                bound = signature.bind_partial(*call_args, **call_kwargs)
            except TypeError:
                # Invalid call; the real invocation below raises the proper
                # error, telemetry just works from the keywords.
                return named
            for name, value in bound.arguments.items():
                if signature.parameters[name].kind in variadic_kinds:
                    continue  # *args / **kwargs buckets are not dimensions
                named.setdefault(name, value)
            return named

        def _scope_for(
            call_args: tuple, call_kwargs: Mapping[str, Any]
        ) -> TokenUsageScope:
            """Build the per-call scope with the requested dimensions."""
            named = _named_arguments(call_args, call_kwargs)
            dimensions = {
                prop: named[param]
                for prop, param in dim_args.items()
                if named.get(param) is not None
            }
            return TokenUsageScope(
                emitter,
                agent_name=agent_name,
                model_deployment_name=model_deployment_name,
                additional_agents=additional_agents,
                emit_user_event=emit_user_event,
                emit_team_event=emit_team_event,
                **dimensions,
            )

        if _is_coroutine_function(fn):
            @functools.wraps(fn)
            async def _aw(*args, **kwargs) -> Any:
                with _scope_for(args, kwargs) as scope:
                    result = await fn(*args, **kwargs)
                    scope.add(result)
                    return result
            return _aw

        @functools.wraps(fn)
        def _sw(*args, **kwargs) -> Any:
            with _scope_for(args, kwargs) as scope:
                result = fn(*args, **kwargs)
                scope.add(result)
                return result
        return _sw

    return _decorator


def _is_coroutine_function(fn: Callable[..., Any]) -> bool:
    """Return ``True`` when *fn* is a coroutine function (``async def``).

    Uses ``inspect.iscoroutinefunction``; the ``asyncio`` alias previously
    used here is deprecated since Python 3.14.
    """
    import inspect

    return inspect.iscoroutinefunction(fn)


__all__ = [
    "EVENT_SUMMARY",
    "EVENT_AGENT",
    "EVENT_MODEL",
    "EVENT_USER",
    "EVENT_TEAM",
    "EVENT_SPEECH",
    "TokenUsage",
    "TokenUsageEmitter",
    "TokenUsageScope",
    "track_tokens",
    "extract_usage",
    "extract_usage_from_dict",
    "extract_usage_from_stream_chunk",
    "extract_realtime_usage",
    "detect_invoked_tools",
]
class _RequestTokenTracker:
    """Collect token usage for one orchestrator request and emit it in one go.

    While the request runs, usage is bucketed three ways: per agent
    (``by_agent``, together with the model attributed to that agent), per
    model deployment (``by_model``) and overall (``total``). :meth:`flush`
    forwards those buckets to the shared :class:`TokenUsageEmitter` through
    ``emit_agent``, ``emit_model`` and ``emit_summary`` -- the standardized
    ``LLM_Agent_Token_Usage`` / ``LLM_Model_Token_Usage`` /
    ``LLM_Token_Usage_Summary`` custom events, with the same event names and
    dimension keys as the cross-accelerator helper in
    :mod:`llm_token_telemetry`.

    The methods here do not catch emitter or extractor exceptions
    themselves; callers are expected to treat every call as best-effort.
    """

    __slots__ = (
        "_emitter",
        "_user_id",
        "_conversation_id",
        "_agent_model_map",
        "_default_model",
        "by_agent",
        "by_model",
        "total",
    )

    def __init__(
        self,
        emitter: TokenUsageEmitter,
        *,
        user_id: str = "",
        conversation_id: str = "",
        agent_model_map: Optional[Mapping[str, str]] = None,
        default_model: str = "",
    ) -> None:
        self._emitter = emitter
        # Falsy inputs (``None`` included) are normalised to empty values.
        self._user_id = user_id if user_id else ""
        self._conversation_id = conversation_id if conversation_id else ""
        # Private copy: ``record_image_api_response`` may add entries and
        # must not touch the caller's mapping.
        self._agent_model_map: dict[str, str] = dict(agent_model_map or {})
        self._default_model = default_model if default_model else ""
        # agent -> (accumulated usage, attributed model)
        self.by_agent: dict[str, tuple[TokenUsage, str]] = {}
        # model deployment -> accumulated usage
        self.by_model: dict[str, TokenUsage] = {}
        self.total: TokenUsage = TokenUsage()

    def _resolve_model(self, agent_name: str) -> str:
        """Return the deployment mapped to *agent_name*, else the default."""
        mapped = self._agent_model_map.get(agent_name)
        return mapped if mapped else self._default_model

    def _add(self, agent_name: str, usage: TokenUsage) -> None:
        """Fold *usage* into the per-agent, per-model and overall totals."""
        if not usage.has_any:
            return

        agent_key = agent_name if agent_name else "unknown_agent"
        current_model = self._resolve_model(agent_key)

        if agent_key in self.by_agent:
            running, known_model = self.by_agent[agent_key]
        else:
            running, known_model = TokenUsage(), current_model

        # An agent seen under two different non-empty models is reported as
        # "multiple" rather than under either one.
        if known_model and current_model and known_model != current_model:
            attributed = "multiple"
        else:
            attributed = known_model or current_model
        self.by_agent[agent_key] = (running + usage, attributed)

        if current_model:
            model_running = self.by_model.get(current_model, TokenUsage())
            self.by_model[current_model] = model_running + usage

        self.total = self.total + usage

    def record(self, agent_name: str, usage: TokenUsage) -> None:
        """Record a pre-extracted :class:`TokenUsage` for the named agent."""
        self._add(agent_name, usage)

    def record_response(self, *, agent_name: str, response: Any) -> bool:
        """Extract usage from an ``AgentResponse`` and record it.

        Returns ``True`` when ``extract_usage`` yielded a truthy usage
        object, ``False`` otherwise.
        """
        extracted = extract_usage(response)
        if not extracted:
            return False
        self._add(agent_name, extracted)
        return True

    def record_event(self, event: Any) -> bool:
        """Extract usage from a workflow ``run_stream`` event and record it.

        The event's ``executor_id`` is used as the agent name and its
        ``data`` is handed to ``extract_usage_from_stream_chunk`` (which
        tries the top-level shape then ``metadata.usage``) to cover both
        ``AgentRunUpdateEvent`` and ``AgentRunEvent`` data shapes. Returns
        ``False`` when the event is ``None``, lacks either attribute, or
        carries no usage.
        """
        if event is None:
            return False
        source_agent = getattr(event, "executor_id", None)
        payload = getattr(event, "data", None)
        if payload is None or not source_agent:
            return False
        extracted = extract_usage_from_stream_chunk(payload)
        if not extracted:
            return False
        self._add(source_agent, extracted)
        return True

    def record_image_api_response(
        self, *, agent_name: str, response_json: Optional[dict], model: str = ""
    ) -> bool:
        """Record token usage from an image-generation REST response (OpenAI shape).

        Reads the ``"usage"`` entry of *response_json*. When *model* is
        given, it becomes the deployment attributed to *agent_name* for this
        tracker from now on.
        """
        if not isinstance(response_json, dict):
            return False
        extracted = extract_usage_from_dict(response_json.get("usage"))
        if not extracted:
            return False
        if model:
            self._agent_model_map[agent_name] = model
        self._add(agent_name, extracted)
        return True

    def has_data(self) -> bool:
        """Return whether any usage has been recorded so far."""
        return self.total.has_any

    def flush(self, *, source: str = "") -> None:
        """Emit aggregated LLM_*_Token_Usage events. Safe to call once per request.

        Does nothing when no usage was recorded. Totals are not reset, so a
        second call would emit the same numbers again.
        """
        if not self.has_data():
            return

        shared_dims = {
            "user_id": self._user_id,
            "conversation_id": self._conversation_id,
            "source": source,
        }
        emitter = self._emitter

        # One event per agent ...
        for agent_name, agent_entry in self.by_agent.items():
            agent_usage, agent_model = agent_entry
            emitter.emit_agent(
                agent_name=agent_name,
                model_deployment_name=agent_model or self._default_model,
                usage=agent_usage,
                **shared_dims,
            )

        # ... one per model deployment ...
        for model_name, model_usage in self.by_model.items():
            emitter.emit_model(
                model_deployment_name=model_name,
                usage=model_usage,
                **shared_dims,
            )

        # ... and a single summary. The first model recorded is reported as
        # the primary one, falling back to the default when none was.
        primary_model = next(iter(self.by_model), self._default_model)
        emitter.emit_summary(
            usage=self.total,
            agent_count=len(self.by_agent),
            model_count=len(self.by_model) or 1,
            primary_model=primary_model,
            **shared_dims,
        )

        per_agent_totals = {
            name: entry[0].total_tokens for name, entry in self.by_agent.items()
        }
        per_model_totals = {
            name: used.total_tokens for name, used in self.by_model.items()
        }
        logger.debug(
            "[TOKEN USAGE] source=%s total=%d (in=%d, out=%d) "
            "agents=%s models=%s",
            source,
            self.total.total_tokens,
            self.total.input_tokens,
            self.total.output_tokens,
            per_agent_totals,
            per_model_totals,
        )
+ chat_model = ( + app_settings.ai_foundry.model_deployment + if self._use_foundry + else app_settings.azure_openai.gpt_model + ) or app_settings.azure_openai.gpt_model or "" + image_model = ( + app_settings.ai_foundry.image_deployment + if self._use_foundry + else app_settings.azure_openai.image_model + ) or app_settings.azure_openai.image_model or "" + self._default_model = chat_model + self._image_model = image_model + self._agent_model_map = { + f"triage{name_sep}agent": chat_model, + f"planning{name_sep}agent": chat_model, + f"research{name_sep}agent": chat_model, + f"text{name_sep}content{name_sep}agent": chat_model, + f"image{name_sep}content{name_sep}agent": chat_model, + f"compliance{name_sep}agent": chat_model, + f"rai{name_sep}agent": chat_model, + } + logger.info(f"Content Generation Orchestrator initialized successfully ({mode_str} mode)") + def _new_token_accumulator( + self, conversation_id: str = "", user_id: str = "" + ) -> _RequestTokenTracker: + """Create a :class:`_RequestTokenTracker` pre-populated with this + orchestrator's agent->model map and default chat model. Telemetry + is best-effort. + + If ``user_id`` / ``conversation_id`` are not provided, falls back to + the per-request values stored in the ``_current_user_id`` / + ``_current_conversation_id`` ContextVars so trackers created deep + inside the workflow still carry the caller's correlation ids. + """ + return _RequestTokenTracker( + token_emitter, + conversation_id=conversation_id or _current_conversation_id.get(""), + user_id=user_id or _current_user_id.get(""), + agent_model_map=self._agent_model_map, + default_model=self._default_model, + ) + async def process_message( self, message: str, conversation_id: str, - context: Optional[dict] = None + context: Optional[dict] = None, + user_id: str = "" ) -> AsyncIterator[dict]: """ Process a user message through the orchestrated workflow. 
@@ -741,12 +964,26 @@ async def process_message( if context: full_input = f"Context:\n{json.dumps(context, indent=2)}\n\nUser Message:\n{message}" + _ctx_token = _current_user_id.set(user_id or "") + _ctx_conv = _current_conversation_id.set(conversation_id or "") + # Defined outside the try so the except/finally branches can safely + # reference ``token_acc`` even if creation fails. Each flush call is + # guarded by ``if token_acc is not None`` to avoid NoneType errors. + token_acc: Optional[_RequestTokenTracker] = None try: + token_acc = self._new_token_accumulator(conversation_id, user_id) + # Collect events from the workflow stream events = [] async for event in self._workflow.run_stream(full_input): events.append(event) + # Best-effort token-usage capture; never break the user flow. + try: + token_acc.record_event(event) + except Exception as _tu_err: + logger.debug("token_usage record_event failed: %s", _tu_err) + # Handle different event types from the workflow if event.type == EVENT_STATUS: status_name = event.state.name if event.state else str(event.data) @@ -805,20 +1042,36 @@ async def process_message( "metadata": {"conversation_id": conversation_id} } + # Emit aggregated LLM_*_Token_Usage events for the request. 
+ if token_acc is not None: + try: + token_acc.flush(source="process_message") + except Exception as _tu_err: + logger.debug("token_usage flush failed: %s", _tu_err) + except Exception as e: logger.exception(f"Error processing message: {e}") + if token_acc is not None: + try: + token_acc.flush(source="process_message:error") + except Exception: + pass yield { "type": "error", "content": f"An error occurred: {str(e)}", "is_final": True, "metadata": {"conversation_id": conversation_id} } + finally: + _current_user_id.reset(_ctx_token) + _current_conversation_id.reset(_ctx_conv) async def send_user_response( self, request_id: str, user_response: str, - conversation_id: str + conversation_id: str, + user_id: str = "" ) -> AsyncIterator[dict]: """ Send a user response to a pending workflow request. @@ -849,9 +1102,19 @@ async def send_user_response( } return # Exit immediately - do not continue workflow + _ctx_token = _current_user_id.set(user_id or "") + _ctx_conv = _current_conversation_id.set(conversation_id or "") + # See process_message for the rationale of the None-init pattern. 
+ token_acc: Optional[_RequestTokenTracker] = None try: + token_acc = self._new_token_accumulator(conversation_id, user_id) responses = {request_id: user_response} async for event in self._workflow.send_responses_streaming(responses): + try: + token_acc.record_event(event) + except Exception as _tu_err: + logger.debug("token_usage record_event failed: %s", _tu_err) + if event.type == EVENT_STATUS: status_name = event.state.name if event.state else str(event.data) yield { @@ -901,18 +1164,40 @@ async def send_user_response( "metadata": {"conversation_id": conversation_id} } + if token_acc is not None: + try: + token_acc.flush(source="send_user_response") + except Exception as _tu_err: + logger.debug("token_usage flush failed: %s", _tu_err) + except Exception as e: logger.exception(f"Error sending user response: {e}") + if token_acc is not None: + try: + token_acc.flush(source="send_user_response:error") + except Exception: + pass yield { "type": "error", "content": f"An error occurred: {str(e)}", "is_final": True, "metadata": {"conversation_id": conversation_id} } + finally: + try: + _current_user_id.reset(_ctx_token) + except (LookupError, ValueError, NameError): + pass + try: + _current_conversation_id.reset(_ctx_conv) + except (LookupError, ValueError, NameError): + pass async def parse_brief( self, - brief_text: str + brief_text: str, + user_id: str = "", + conversation_id: str = "" ) -> tuple[CreativeBrief, str | None, bool]: """ Parse a free-text creative brief into structured format. @@ -920,6 +1205,9 @@ async def parse_brief( Args: brief_text: Free-text creative brief from user + user_id: Optional caller's user id, propagated to token usage telemetry + conversation_id: Optional conversation id, propagated to token usage + telemetry for correlation in Application Insights. 
Returns: tuple: (CreativeBrief, clarifying_questions_or_none, is_blocked) @@ -930,6 +1218,19 @@ async def parse_brief( if not self._initialized: self.initialize() + _ctx_token = _current_user_id.set(user_id or "") + _ctx_conv = _current_conversation_id.set(conversation_id or "") + try: + return await self._parse_brief_impl(brief_text, user_id) + finally: + _current_user_id.reset(_ctx_token) + _current_conversation_id.reset(_ctx_conv) + + async def _parse_brief_impl( + self, + brief_text: str, + user_id: str = "" + ) -> tuple[CreativeBrief, str | None, bool]: # PROACTIVE CONTENT SAFETY CHECK - Block harmful content at input layer is_harmful, matched_pattern = _check_input_for_harmful_content(brief_text) if is_harmful: @@ -949,8 +1250,13 @@ async def parse_brief( return empty_brief, RAI_HARMFUL_CONTENT_RESPONSE, True # SECONDARY RAI CHECK - Use LLM-based classifier for comprehensive safety/scope validation + token_acc = self._new_token_accumulator(user_id=user_id) try: rai_response = await self._rai_agent.run(brief_text) + try: + token_acc.record_response(agent_name="rai_agent", response=rai_response) + except Exception as _tu_err: + logger.debug("token_usage record (rai_agent) failed: %s", _tu_err) rai_result = str(rai_response).strip().upper() logger.info(f"RAI agent response for parse_brief: {rai_result}") @@ -967,6 +1273,10 @@ async def parse_brief( visual_guidelines="", cta="" ) + try: + token_acc.flush(source="parse_brief:rai_blocked") + except Exception: + pass return empty_brief, RAI_HARMFUL_CONTENT_RESPONSE, True except Exception as rai_error: # Log the error but continue - don't block legitimate requests due to RAI agent failures @@ -1020,6 +1330,14 @@ async def parse_brief( # Use the agent's run method response = await planning_agent.run(analysis_prompt) + try: + token_acc.record_response(agent_name="planning_agent", response=response) + except Exception as _tu_err: + logger.debug("token_usage record (planning_agent) failed: %s", _tu_err) + try: + 
token_acc.flush(source="parse_brief") + except Exception: + pass # Parse the analysis response try: @@ -1129,7 +1447,9 @@ async def select_products( self, request_text: str, current_products: list = None, - available_products: list = None + available_products: list = None, + user_id: str = "", + conversation_id: str = "" ) -> dict: """ Select or modify product selection via natural language. @@ -1180,7 +1500,18 @@ async def select_products( """ try: + token_acc = self._new_token_accumulator( + conversation_id=conversation_id, user_id=user_id + ) response = await research_agent.run(select_prompt) + try: + token_acc.record_response(agent_name="research_agent", response=response) + except Exception as _tu_err: + logger.debug("token_usage record (research_agent) failed: %s", _tu_err) + try: + token_acc.flush(source="select_products") + except Exception: + pass response_text = str(response) # Extract JSON from response @@ -1304,6 +1635,19 @@ async def _generate_foundry_image(self, image_prompt: str, results: dict) -> Non response_data = response.json() + # Capture token usage from image API response (gpt-image-1 returns + # a 'usage' field with input/output/total token counts). + try: + img_acc = self._new_token_accumulator() + img_acc.record_image_api_response( + agent_name="image_content_agent", + response_json=response_data, + model=image_deployment or self._image_model, + ) + img_acc.flush(source="foundry_image_generation") + except Exception as _tu_err: + logger.debug("token_usage capture (foundry image) failed: %s", _tu_err) + # Extract image data from response data = response_data.get("data", []) if not data: @@ -1394,7 +1738,9 @@ async def generate_content( self, brief: CreativeBrief, products: list = None, - generate_images: bool = True + generate_images: bool = True, + user_id: str = "", + conversation_id: str = "" ) -> dict: """ Generate complete content package from a confirmed creative brief. 
@@ -1403,6 +1749,10 @@ async def generate_content( brief: Confirmed creative brief products: List of products to feature generate_images: Whether to generate images + user_id: Optional caller's user id, propagated to token usage telemetry + conversation_id: Optional conversation id, propagated to token usage + telemetry (including from deep helpers like ``_generate_foundry_image``) + so image-generation events can be correlated by conversation in KQL. Returns: dict: Generated content with compliance results @@ -1410,6 +1760,21 @@ async def generate_content( if not self._initialized: self.initialize() + _ctx_token = _current_user_id.set(user_id or "") + _ctx_conv = _current_conversation_id.set(conversation_id or "") + try: + return await self._generate_content_impl(brief, products, generate_images, user_id) + finally: + _current_user_id.reset(_ctx_token) + _current_conversation_id.reset(_ctx_conv) + + async def _generate_content_impl( + self, + brief: CreativeBrief, + products: list = None, + generate_images: bool = True, + user_id: str = "" + ) -> dict: results = { "text_content": None, "image_prompt": None, @@ -1433,9 +1798,16 @@ async def generate_content( Products to feature: {json.dumps(products or [])} """ + # Created outside the try so the post-try flush is safe even if an + # exception fires before the first assignment inside the try block. 
+ token_acc = self._new_token_accumulator(user_id=user_id) try: # Generate text content text_response = await self._agents["text_content"].run(text_request) + try: + token_acc.record_response(agent_name="text_content_agent", response=text_response) + except Exception as _tu_err: + logger.debug("token_usage record (text_content_agent) failed: %s", _tu_err) results["text_content"] = str(text_response) # Generate image prompt if requested @@ -1515,6 +1887,10 @@ async def generate_content( else: # Direct mode: use image agent to create prompt, then generate via image generation model image_response = await self._agents["image_content"].run(image_request) + try: + token_acc.record_response(agent_name="image_content_agent", response=image_response) + except Exception as _tu_err: + logger.debug("token_usage record (image_content_agent) failed: %s", _tu_err) results["image_prompt"] = str(image_response) # Extract clean prompt from the response and generate actual image @@ -1585,6 +1961,10 @@ async def generate_content( Check against brand guidelines and flag any issues. """ compliance_response = await self._agents["compliance"].run(compliance_request) + try: + token_acc.record_response(agent_name="compliance_agent", response=compliance_response) + except Exception as _tu_err: + logger.debug("token_usage record (compliance_agent) failed: %s", _tu_err) results["compliance"] = str(compliance_response) # Try to parse compliance violations @@ -1619,6 +1999,12 @@ async def generate_content( logger.exception(f"Error generating content: {e}") results["error"] = str(e) + # Emit aggregated token usage events for the generate_content request. 
+ try: + token_acc.flush(source="generate_content") + except Exception as _tu_err: + logger.debug("token_usage flush (generate_content) failed: %s", _tu_err) + # Log results summary before returning logger.info(f"Orchestrator returning results with keys: {list(results.keys())}") has_image = bool(results.get("image_base64")) @@ -1632,7 +2018,9 @@ async def regenerate_image( modification_request: str, brief: CreativeBrief, products: list = None, - previous_image_prompt: str = None + previous_image_prompt: str = None, + user_id: str = "", + conversation_id: str = "" ) -> dict: """ Regenerate just the image based on a user modification request. @@ -1645,6 +2033,9 @@ async def regenerate_image( brief: The confirmed creative brief products: List of products to feature previous_image_prompt: The previous image prompt (if available) + user_id: Optional caller's user id, propagated to token usage telemetry + conversation_id: Optional conversation id, propagated to token usage + telemetry for correlation in Application Insights. 
Returns: dict: Regenerated image with updated prompt @@ -1652,6 +2043,23 @@ async def regenerate_image( if not self._initialized: self.initialize() + _ctx_token = _current_user_id.set(user_id or "") + _ctx_conv = _current_conversation_id.set(conversation_id or "") + try: + return await self._regenerate_image_impl( + modification_request, brief, products, previous_image_prompt + ) + finally: + _current_user_id.reset(_ctx_token) + _current_conversation_id.reset(_ctx_conv) + + async def _regenerate_image_impl( + self, + modification_request: str, + brief: CreativeBrief, + products: list = None, + previous_image_prompt: str = None + ) -> dict: logger.info(f"Regenerating image with modification: {modification_request[:100]}...") # PROACTIVE CONTENT SAFETY CHECK @@ -1764,6 +2172,12 @@ async def regenerate_image( else: # Direct mode: use image agent to interpret the modification image_response = await self._agents["image_content"].run(modification_prompt) + try: + regen_acc = self._new_token_accumulator() + regen_acc.record_response(agent_name="image_content_agent", response=image_response) + regen_acc.flush(source="regenerate_image") + except Exception as _tu_err: + logger.debug("token_usage capture (regenerate_image) failed: %s", _tu_err) prompt_text = str(image_response) # Extract the prompt from JSON response diff --git a/src/backend/services/title_service.py b/src/backend/services/title_service.py index 92289ef12..0f0b2e161 100644 --- a/src/backend/services/title_service.py +++ b/src/backend/services/title_service.py @@ -14,6 +14,8 @@ from azure.identity import DefaultAzureCredential from settings import app_settings +from telemetry import token_emitter +from llm_token_telemetry import TokenUsageScope logger = logging.getLogger(__name__) @@ -82,7 +84,13 @@ def _fallback_title(message: str) -> str: words = message.strip().split()[:4] return " ".join(words) if words else "New Conversation" - async def generate_title(self, first_user_message: str) -> str: + async def 
generate_title( + self, + first_user_message: str, + *, + user_id: str = "", + conversation_id: str = "", + ) -> str: """ Generate a concise conversation title from the first user message. @@ -109,7 +117,20 @@ async def generate_title(self, first_user_message: str) -> str: ) try: - response = await self._agent.run(prompt) + deployment = ( + (app_settings.ai_foundry.model_deployment or app_settings.azure_openai.gpt_model) + if app_settings.ai_foundry.use_foundry + else app_settings.azure_openai.gpt_model + ) or "" + with TokenUsageScope( + token_emitter, + agent_name="title_agent", + model_deployment_name=deployment, + user_id=user_id, + conversation_id=conversation_id, + ) as scope: + response = await self._agent.run(prompt) + scope.add(response) # Clean up the response title = str(response).strip().splitlines()[0].strip() diff --git a/src/backend/telemetry.py b/src/backend/telemetry.py new file mode 100644 index 000000000..0a2c34f80 --- /dev/null +++ b/src/backend/telemetry.py @@ -0,0 +1,92 @@ +"""Process-wide telemetry singletons. + +A single :class:`TokenUsageEmitter` is constructed at import time so every +router/utility shares the same App Insights connection-string resolution and +static dimensions. Beyond reading ``APPLICATIONINSIGHTS_CONNECTION_STRING`` and +the env vars documented below, constructing that emitter also resolves the +optional App Insights event sink, which may import +``azure.monitor.events.extension`` when the package is installed. + +Optional environment variables +------------------------------ +LLM_TOKEN_SAMPLE_RATE + Float in [0, 1]. Fraction of high-cardinality token events + (agent/model/user/team/speech) to ship. The summary event always fires. + Defaults to ``1.0``. + +LLM_TOKEN_USER_ID_HMAC_KEY + When set, ``user_id`` values are replaced with an HMAC-SHA256 hex digest + (truncated to 16 chars) before leaving the process. Use to satisfy + GDPR / PII handling requirements without modifying call sites. 
def _parse_sample_rate() -> float:
    """Return the event sample rate from ``LLM_TOKEN_SAMPLE_RATE``.

    The value is clamped to ``[0.0, 1.0]``. An unset, empty or non-numeric
    variable yields ``1.0`` (the non-numeric case also logs a warning).
    """
    raw = os.getenv("LLM_TOKEN_SAMPLE_RATE")
    if not raw:
        return 1.0
    try:
        # A NaN input also ends up as 1.0: comparisons against NaN are
        # false, so ``min`` keeps its first argument.
        return max(0.0, min(1.0, float(raw)))
    except ValueError:
        _log.warning("Invalid LLM_TOKEN_SAMPLE_RATE=%r; defaulting to 1.0", raw)
        return 1.0


def _build_user_id_hasher() -> Optional[Callable[[str], str]]:
    """Return an HMAC-based pseudonymiser for user ids, or ``None``.

    When ``LLM_TOKEN_USER_ID_HMAC_KEY`` is set, the returned callable maps a
    value to the first 16 hex characters of its HMAC-SHA256 digest under
    that key. When the variable is unset or empty, ``None`` is returned.
    """
    key = os.getenv("LLM_TOKEN_USER_ID_HMAC_KEY")
    if not key:
        return None
    # Encode once; the closure reuses the bytes for every call.
    key_bytes = key.encode("utf-8")

    def _hash(value: str) -> str:
        digest = hmac.new(key_bytes, value.encode("utf-8"), hashlib.sha256).hexdigest()
        return digest[:16]

    return _hash


def _parse_pricing() -> dict[str, tuple[float, float]]:
    """Parse ``LLM_TOKEN_PRICING`` into ``{model: (in_per_1k, out_per_1k)}``.

    The variable is a comma-separated list of ``model=in:out`` entries.
    Model names are stripped and lower-cased; a later entry for the same
    model replaces an earlier one. Entries are dropped when they

    * lack the ``=`` or ``:`` separator or have an empty model name
      (silently), or
    * carry a rate that is not a number, not finite, or negative (with a
      warning) -- such a rate would corrupt every cost estimate derived
      from it.

    Returns an empty dict when the variable is unset or empty.
    """
    # Function-scope import: only this helper needs ``math``.
    import math

    raw = os.getenv("LLM_TOKEN_PRICING")
    if not raw:
        return {}

    pricing: dict[str, tuple[float, float]] = {}
    for entry in raw.split(","):
        entry = entry.strip()
        model, eq, rates = entry.partition("=")
        in_s, colon, out_s = rates.partition(":")
        model = model.strip().lower()
        if not (eq and colon and model):
            # Structurally incomplete entry, or a price with no model name.
            continue
        try:
            in_rate, out_rate = float(in_s), float(out_s)
        except ValueError:
            _log.warning("Ignoring malformed pricing entry: %s", entry)
            continue
        # ``float`` happily parses "nan", "inf" and negative numbers; none of
        # them is a usable price.
        usable = all(math.isfinite(r) and r >= 0.0 for r in (in_rate, out_rate))
        if not usable:
            _log.warning("Ignoring malformed pricing entry: %s", entry)
            continue
        pricing[model] = (in_rate, out_rate)
    return pricing
b/src/tests/test_app_title_endpoints.py index 6e4fe0650..f8867d132 100644 --- a/src/tests/test_app_title_endpoints.py +++ b/src/tests/test_app_title_endpoints.py @@ -315,7 +315,9 @@ async def test_generates_title_for_new_conversation(self, client): assert resp.status_code == 200 mock_title_svc.generate_title.assert_called_once_with( - "I need a social media post about paint products" + "I need a social media post about paint products", + user_id="user-1", + conversation_id="conv-chat-1", ) @pytest.mark.asyncio diff --git a/src/tests/test_llm_token_telemetry.py b/src/tests/test_llm_token_telemetry.py new file mode 100644 index 000000000..f465f4c2d --- /dev/null +++ b/src/tests/test_llm_token_telemetry.py @@ -0,0 +1,220 @@ +"""Focused unit tests for the token-usage telemetry helpers. + +Covers the supported usage response shapes (framework ``usage_details``, +aggregated message ``contents`` usage, raw OpenAI ``usage`` fallback, streaming +chunk metadata, and realtime/voice sub-counts) plus ``TokenUsage`` arithmetic, +``TokenUsageScope`` accumulation, and ``TokenUsageEmitter`` behaviour +(user_id hashing, pricing, and the disabled no-op path). + +These guard against regressions as the Agent Framework / OpenAI SDK usage +shapes evolve. 
+""" +from types import SimpleNamespace + +import pytest + +from llm_token_telemetry import ( + EVENT_AGENT, + EVENT_MODEL, + EVENT_SUMMARY, + TokenUsage, + TokenUsageEmitter, + TokenUsageScope, + extract_realtime_usage, + extract_usage, + extract_usage_from_dict, + extract_usage_from_stream_chunk, +) + + +def test_extract_usage_from_usage_details_attr(): + """Framework result exposing ``usage_details`` with *_token_count keys.""" + result = SimpleNamespace( + usage_details=SimpleNamespace( + input_token_count=120, output_token_count=30, total_token_count=150 + ) + ) + usage = extract_usage(result) + assert usage == TokenUsage(input_tokens=120, output_tokens=30, total_tokens=150) + + +def test_extract_usage_from_raw_openai_usage(): + """OpenAI ChatCompletion shape via ``raw_representation.usage``.""" + result = SimpleNamespace( + raw_representation=SimpleNamespace( + usage={"prompt_tokens": 10, "completion_tokens": 5} + ) + ) + usage = extract_usage(result) + assert usage == TokenUsage(input_tokens=10, output_tokens=5, total_tokens=15) + + +def test_extract_usage_aggregates_message_contents(): + """Usage spread across ``messages[*].contents[*].usage_details`` is summed.""" + msg = SimpleNamespace( + contents=[ + SimpleNamespace(usage_details={"input_tokens": 4, "output_tokens": 1}), + SimpleNamespace(usage_details={"input_tokens": 6, "output_tokens": 2}), + ] + ) + result = SimpleNamespace(messages=[msg]) + usage = extract_usage(result) + assert usage == TokenUsage(input_tokens=10, output_tokens=3, total_tokens=13) + + +def test_extract_usage_returns_none_for_unknown_shape(): + assert extract_usage(None) is None + assert extract_usage(SimpleNamespace(foo="bar")) is None + + +def test_extract_usage_from_dict_fallback(): + usage = extract_usage_from_dict({"prompt_tokens": 7, "completion_tokens": 3}) + assert usage == TokenUsage(input_tokens=7, output_tokens=3, total_tokens=10) + assert extract_usage_from_dict({}) is None + + +def 
test_extract_usage_from_stream_chunk_metadata(): + chunk = SimpleNamespace(metadata={"usage": {"input_tokens": 2, "output_tokens": 8}}) + usage = extract_usage_from_stream_chunk(chunk) + assert usage == TokenUsage(input_tokens=2, output_tokens=8, total_tokens=10) + + +def test_extract_realtime_usage_omits_absent_subcounts(): + """When the provider does not report sub-counts they stay ``None`` so the + event props omit them (rather than emitting a misleading ``0``).""" + response = {"usage": {"input_tokens": 100, "output_tokens": 20}} + usage = extract_realtime_usage(response) + assert usage.input_tokens == 100 + assert usage.output_tokens == 20 + assert usage.input_audio_tokens is None + assert usage.output_text_tokens is None + props = usage.to_event_props() + assert "input_audio_tokens" not in props + assert "output_text_tokens" not in props + + +def test_extract_realtime_usage_includes_present_subcounts(): + response = { + "usage": { + "input_tokens": 100, + "output_tokens": 20, + "input_token_details": {"audio_tokens": 80, "cached_tokens": 0}, + "output_token_details": {"text_tokens": 20}, + } + } + usage = extract_realtime_usage(response) + assert usage.input_audio_tokens == 80 + assert usage.input_cached_tokens == 0 + assert usage.output_text_tokens == 20 + props = usage.to_event_props() + assert props["input_audio_tokens"] == "80" + assert props["input_cached_tokens"] == "0" + + +def test_extract_realtime_usage_returns_none_when_no_usage(): + assert extract_realtime_usage({}) is None + + +def test_token_usage_add_handles_none_subcounts(): + a = TokenUsage(input_tokens=1, output_tokens=1, total_tokens=2) + b = TokenUsage(input_tokens=2, output_tokens=3, total_tokens=5, input_audio_tokens=4) + combined = a + b + assert combined.input_tokens == 3 + assert combined.output_tokens == 4 + assert combined.total_tokens == 7 + assert combined.input_audio_tokens == 4 + assert combined.output_text_tokens is None + + +def 
test_to_event_props_only_includes_set_subcounts(): + usage = TokenUsage(input_tokens=5, output_tokens=5, total_tokens=10, input_text_tokens=5) + props = usage.to_event_props() + assert props["input_text_tokens"] == "5" + assert "input_audio_tokens" not in props + + +def _emitter_with_sink(): + """Return (emitter, events) where events captures (name, props) tuples.""" + events: list[tuple[str, dict]] = [] + emitter = TokenUsageEmitter( + connection_string="InstrumentationKey=test", + event_sink=lambda name, props: events.append((name, props)), + ) + return emitter, events + + +def test_token_usage_scope_accumulates_and_emits(): + emitter, events = _emitter_with_sink() + with TokenUsageScope( + emitter, + agent_name="title_agent", + model_deployment_name="gpt-4o", + conversation_id="conv-1", + ) as scope: + scope.add(SimpleNamespace(usage={"prompt_tokens": 10, "completion_tokens": 5})) + scope.add(SimpleNamespace(usage={"prompt_tokens": 4, "completion_tokens": 1})) + + assert scope.usage == TokenUsage(input_tokens=14, output_tokens=6, total_tokens=20) + names = {name for name, _ in events} + assert EVENT_AGENT in names + assert EVENT_MODEL in names + assert EVENT_SUMMARY in names + agent_props = next(p for n, p in events if n == EVENT_AGENT) + assert agent_props["conversation_id"] == "conv-1" + + +def test_token_usage_scope_no_emit_when_no_usage(): + emitter, events = _emitter_with_sink() + with TokenUsageScope(emitter, agent_name="a", model_deployment_name="m"): + pass + assert events == [] + + +def test_emitter_hashes_user_id_before_emitting(): + events: list[tuple[str, dict]] = [] + emitter = TokenUsageEmitter( + connection_string="InstrumentationKey=test", + event_sink=lambda name, props: events.append((name, props)), + user_id_hasher=lambda v: "HASHED", + ) + emitter.emit_agent( + agent_name="a", + model_deployment_name="gpt-4o", + usage=TokenUsage(input_tokens=1, output_tokens=1, total_tokens=2), + user_id="alice@example.com", + ) + assert events + _, props = 
events[0] + assert props["user_id"] == "HASHED" + + +def test_emitter_attaches_estimated_cost_when_pricing_configured(): + events: list[tuple[str, dict]] = [] + emitter = TokenUsageEmitter( + connection_string="InstrumentationKey=test", + event_sink=lambda name, props: events.append((name, props)), + pricing={"gpt-4o": (0.0025, 0.01)}, + ) + emitter.emit_agent( + agent_name="a", + model_deployment_name="gpt-4o", + usage=TokenUsage(input_tokens=1000, output_tokens=1000, total_tokens=2000), + ) + _, props = events[0] + assert props["estimated_cost_usd"] == "0.012500" + + +def test_disabled_emitter_is_a_noop(): + """No connection string -> emitter disabled -> sink never invoked.""" + events: list[tuple[str, dict]] = [] + emitter = TokenUsageEmitter( + connection_string="", + event_sink=lambda name, props: events.append((name, props)), + ) + assert emitter.enabled is False + emitter.emit_agent( + agent_name="a", + model_deployment_name="m", + usage=TokenUsage(input_tokens=1, output_tokens=1, total_tokens=2), + ) + assert events == []