diff --git a/pyproject.toml b/pyproject.toml
index 1339e34e..3cb25d7a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,9 @@ Documentation = "https://github.com/microsoft/typeagent-py/tree/main/docs/README
 [tool.uv.build-backend]
 module-root = "src"
 
+[tool.uv.sources]
+pytest-async-benchmark = { git = "https://github.com/KRRT7/pytest-async-benchmark.git", rev = "feat/pedantic-mode" }
+
 [tool.pytest.ini_options]
 asyncio_default_fixture_loop_scope = "function"
 testpaths = ["tests"]
@@ -91,6 +94,7 @@ dev = [
     "opentelemetry-instrumentation-httpx>=0.57b0",
     "pyright>=1.1.408",  # 407 has a regression
     "pytest>=8.3.5",
+    "pytest-async-benchmark",
     "pytest-asyncio>=0.26.0",
     "pytest-mock>=3.14.0",
 ]
diff --git a/tests/benchmarks/test_benchmark_indexing.py b/tests/benchmarks/test_benchmark_indexing.py
new file mode 100644
index 00000000..d730ca5c
--- /dev/null
+++ b/tests/benchmarks/test_benchmark_indexing.py
@@ -0,0 +1,160 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Benchmarks for add_messages_with_indexing — the core indexing pipeline.
+
+Exercises: message storage, semantic ref creation, term index insertion,
+property index insertion, and embedding computation.
+
+Only the hot path (add_messages_with_indexing) is timed — DB creation,
+storage provider init, VTT parsing, and teardown are excluded via
+async_benchmark.pedantic().
+
+Run:
+    uv run python -m pytest tests/benchmarks/test_benchmark_indexing.py -v -s
+"""
+
+import itertools
+import os
+import shutil
+import tempfile
+
+import pytest
+
+from typeagent.aitools.model_adapters import create_test_embedding_model
+from typeagent.knowpro.convsettings import ConversationSettings
+from typeagent.knowpro.universal_message import ConversationMessage
+from typeagent.storage.sqlite.provider import SqliteStorageProvider
+from typeagent.transcripts.transcript import (
+    Transcript,
+    TranscriptMessage,
+    TranscriptMessageMeta,
+)
+from typeagent.transcripts.transcript_ingest import ingest_vtt_transcript
+
+TESTDATA = os.path.join(os.path.dirname(__file__), "..", "testdata")
+CONFUSE_A_CAT_VTT = os.path.join(TESTDATA, "Confuse-A-Cat.vtt")
+
+
+def make_settings() -> ConversationSettings:
+    """Create conversation settings with fake embedding model (no API keys)."""
+    model = create_test_embedding_model()
+    settings = ConversationSettings(model=model)
+    settings.semantic_ref_index_settings.auto_extract_knowledge = False
+    return settings
+
+
+async def extract_vtt_messages(vtt_path: str) -> list[ConversationMessage]:
+    """Parse a VTT file via ingest_vtt_transcript and return the messages.
+
+    The transcript is ingested into a throwaway SQLite database inside a
+    temporary directory; the messages are read back into memory and the
+    storage provider is closed before the directory is removed.
+    """
+    settings = make_settings()
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "parse.db")
+        transcript = await ingest_vtt_transcript(vtt_path, settings, dbname=db_path)
+        try:
+            n = await transcript.messages.size()
+            messages = await transcript.messages.get_slice(0, n)
+        finally:
+            # Close even if reading the messages back fails, so the provider
+            # never outlives the temporary directory holding its database.
+            await settings.storage_provider.close()
+    return messages
+
+
+def synthetic_messages(n: int) -> list[TranscriptMessage]:
+    """Build n synthetic TranscriptMessages."""
+    return [
+        TranscriptMessage(
+            text_chunks=[f"Message {i} about topic {i % 10}"],
+            metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"),
+            tags=[f"tag{i % 5}"],
+        )
+        for i in range(n)
+    ]
+
+
+async def run_indexing_benchmark(
+    async_benchmark,
+    messages,
+    message_type,
+    *,
+    rounds: int = 20,
+    warmup_rounds: int = 3,
+) -> None:
+    """Shared benchmark harness: fresh DB per round, only hot path timed.
+
+    Args:
+        async_benchmark: The ``async_benchmark`` fixture; its ``pedantic()``
+            method runs the setup/target/teardown cycle.
+        messages: Messages passed to ``add_messages_with_indexing`` each round.
+        message_type: Message class handed to ``SqliteStorageProvider``.
+        rounds: Number of rounds requested from ``pedantic()``.
+        warmup_rounds: Number of warm-up rounds requested from ``pedantic()``.
+    """
+    settings = make_settings()
+    tmpdir = tempfile.mkdtemp()
+    counter = itertools.count()
+
+    async def setup():
+        # A distinct database file per round keeps rounds independent.
+        i = next(counter)
+        db_path = os.path.join(tmpdir, f"bench_{i}.db")
+        storage = SqliteStorageProvider(
+            db_path,
+            message_type=message_type,
+            message_text_index_settings=settings.message_text_index_settings,
+            related_term_index_settings=settings.related_term_index_settings,
+        )
+        # Point the shared settings at this round's provider before the
+        # transcript is created from them.
+        settings.storage_provider = storage
+        transcript = await Transcript.create(settings, name="bench")
+        return transcript, storage, db_path
+
+    async def teardown(setup_rv):
+        _, storage, db_path = setup_rv
+        await storage.close()
+        os.remove(db_path)
+
+    async def target(transcript, storage, db_path):
+        # Parameters mirror setup()'s return tuple; only the transcript is used.
+        await transcript.add_messages_with_indexing(messages)
+
+    try:
+        await async_benchmark.pedantic(
+            target,
+            setup=setup,
+            teardown=teardown,
+            rounds=rounds,
+            warmup_rounds=warmup_rounds,
+        )
+    finally:
+        # Remove whatever is left in the directory, even if a round failed.
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+@pytest.mark.asyncio
+async def test_benchmark_vtt_ingest(async_benchmark):
+    """Benchmark indexing of pre-parsed VTT messages (Confuse-A-Cat, 40 msgs)."""
+    messages = await extract_vtt_messages(CONFUSE_A_CAT_VTT)
+    await run_indexing_benchmark(async_benchmark, messages, ConversationMessage)
+
+
+@pytest.mark.asyncio
+async def test_benchmark_add_messages_50(async_benchmark):
+    """Benchmark add_messages_with_indexing with 50 synthetic messages."""
+    await run_indexing_benchmark(
+        async_benchmark, synthetic_messages(50), TranscriptMessage
+    )
+
+
+@pytest.mark.asyncio
+async def test_benchmark_add_messages_200(async_benchmark):
+    """Benchmark add_messages_with_indexing with 200 synthetic messages."""
+    await run_indexing_benchmark(
+        async_benchmark, synthetic_messages(200), TranscriptMessage
+    )
diff --git a/uv.lock b/uv.lock
index 4eab7ead..e2b66a3d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1922,6 +1922,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
 ]
 
+[[package]]
+name = "pytest-async-benchmark"
+version = "0.2.0"
+source = { git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode#029d03634d140789baebc6c3c8f72d5c81a67f9a" }
+dependencies = [
+    { name = "pytest" },
+    { name = "rich" },
+]
+
 [[package]]
 name = "pytest-asyncio"
 version = "1.3.0"
@@ -2398,6 +2407,7 @@ dev = [
     { name = "opentelemetry-instrumentation-httpx" },
     { name = "pyright" },
     { name = "pytest" },
+    { name = "pytest-async-benchmark" },
     { name = "pytest-asyncio" },
     { name = "pytest-mock" },
 ]
@@ -2436,6 +2446,7 @@ dev = [
     { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" },
     { name = "pyright", specifier = ">=1.1.408" },
     { name = "pytest", specifier = ">=8.3.5" },
+    { name = "pytest-async-benchmark", git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode" },
     { name = "pytest-asyncio", specifier = ">=0.26.0" },
     { name = "pytest-mock", specifier = ">=3.14.0" },
 ]