diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py
index e48a09ac19..05b7b99fd9 100644
--- a/evals/cli/oaieval.py
+++ b/evals/cli/oaieval.py
@@ -5,6 +5,7 @@
 import logging
 import shlex
 import sys
+from numbers import Number
 from typing import Any, Mapping, Optional, Union, cast
 
 import evals
@@ -18,6 +19,40 @@
 logger = logging.getLogger(__name__)
 
 
+def _flatten_usage_metrics(usage: Any, prefix: str = "") -> dict[str, Number]:
+    """
+    Flatten a possibly nested usage payload into a flat `{metric_name: number}` dict.
+
+    Nested keys are joined with underscores, e.g.
+    `{"prompt_tokens_details": {"cached_tokens": 4}}` becomes
+    `{"prompt_tokens_details_cached_tokens": 4}`. Objects exposing `model_dump()`
+    are dumped with `exclude_none=True` and then flattened like a mapping.
+    `None`, booleans and any other non-numeric leaves are skipped.
+    """
+    if usage is None:
+        return {}
+
+    # bool is a subclass of int; a flag is not a token count and must not be summed.
+    if isinstance(usage, bool):
+        return {}
+
+    if isinstance(usage, Number):
+        # A bare number with no key has no metric name to be reported under.
+        return {prefix: usage} if prefix else {}
+
+    if hasattr(usage, "model_dump"):
+        return _flatten_usage_metrics(usage.model_dump(exclude_none=True), prefix)
+
+    if isinstance(usage, Mapping):
+        flattened: dict[str, Number] = {}
+        for key, value in usage.items():
+            nested_prefix = f"{prefix}_{key}" if prefix else str(key)
+            flattened.update(_flatten_usage_metrics(value, nested_prefix))
+        return flattened
+
+    return {}
+
+
 def _purple(str: str) -> str:
     return f"\033[1;35m{str}\033[0m"
 
@@ -274,13 +309,14 @@ def add_token_usage_to_result(result: dict[str, Any], recorder: RecorderBase) ->
     sampling_events = recorder.get_events("sampling")
     for event in sampling_events:
         if "usage" in event.data:
-            usage_events.append(dict(event.data["usage"]))
+            usage_events.append(_flatten_usage_metrics(event.data["usage"]))
     logger.info(f"Found {len(usage_events)}/{len(sampling_events)} sampling events with usage data")
     if usage_events:
-        # Sum up the usage of all samples (assumes the usage is the same for all samples)
+        # Sum up token usage across all sampling events, including nested usage breakdowns.
+        usage_keys = set().union(*(usage_event.keys() for usage_event in usage_events))
         total_usage = {
-            key: sum(u[key] if u[key] is not None else 0 for u in usage_events)
-            for key in usage_events[0]
+            key: sum(usage_event.get(key, 0) for usage_event in usage_events)
+            for key in sorted(usage_keys)
         }
         total_usage_str = "\n".join(f"{key}: {value:,}" for key, value in total_usage.items())
         logger.info(f"Token usage from {len(usage_events)} sampling events:\n{total_usage_str}")
diff --git a/evals/cli/oaieval_test.py b/evals/cli/oaieval_test.py
new file mode 100644
index 0000000000..6330ef4069
--- /dev/null
+++ b/evals/cli/oaieval_test.py
@@ -0,0 +1,64 @@
+from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
+
+from evals.base import RunSpec
+from evals.cli.oaieval import _flatten_usage_metrics, add_token_usage_to_result
+from evals.record import DummyRecorder
+
+
+def test_add_token_usage_to_result_flattens_nested_usage_details() -> None:
+    spec = RunSpec(
+        completion_fns=[""],
+        eval_name="",
+        base_eval="",
+        split="",
+        run_config={},
+        created_by="",
+        run_id="",
+        created_at="",
+    )
+    recorder = DummyRecorder(spec)
+
+    with recorder.as_default_recorder("sample-1"):
+        recorder.record_sampling(
+            prompt="prompt-1",
+            sampled="answer-1",
+            usage=CompletionUsage(
+                prompt_tokens=10,
+                completion_tokens=5,
+                total_tokens=15,
+                completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2),
+                prompt_tokens_details=PromptTokensDetails(cached_tokens=4),
+            ),
+        )
+
+    with recorder.as_default_recorder("sample-2"):
+        recorder.record_sampling(
+            prompt="prompt-2",
+            sampled="answer-2",
+            usage=CompletionUsage(
+                prompt_tokens=7,
+                completion_tokens=6,
+                total_tokens=13,
+                completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
+                prompt_tokens_details=PromptTokensDetails(cached_tokens=1),
+            ),
+        )
+
+    result: dict[str, int] = {}
+    add_token_usage_to_result(result, recorder)
+
+    assert result["usage_prompt_tokens"] == 17
+    assert result["usage_completion_tokens"] == 11
+    assert result["usage_total_tokens"] == 28
+    assert result["usage_completion_tokens_details_reasoning_tokens"] == 5
+    assert result["usage_prompt_tokens_details_cached_tokens"] == 5
+
+
+def test_flatten_usage_metrics_skips_booleans_and_none() -> None:
+    usage = {
+        "prompt_tokens": 3,
+        "cached": True,
+        "details": {"reasoning_tokens": None, "audio_tokens": 2},
+    }
+
+    assert _flatten_usage_metrics(usage) == {"prompt_tokens": 3, "details_audio_tokens": 2}