async def get_scored_semantic_refs_from_ordinals_iter(
    semantic_refs: ISemanticRefCollection,
    semantic_ref_matches: list[ScoredSemanticRefOrdinal],
    knowledge_type: KnowledgeType,
) -> list[Scored[SemanticRef]]:
    """Resolve scored ordinals to full SemanticRefs of the given knowledge type.

    Avoids the N+1 get_item() pattern with a two-phase batch fetch:
    first pull lightweight metadata for every ordinal, filter by
    knowledge type, then fetch the full (knowledge-deserializing)
    refs only for the survivors.

    Returns Scored wrappers pairing each surviving ref with the score
    of its originating match, preserving the input match order.
    """
    # NOTE(review): first parameter reconstructed from body usage — the
    # hunk context cuts the signature; confirm name/order against the file.
    if not semantic_ref_matches:
        return []
    ordinals = [m.semantic_ref_ordinal for m in semantic_ref_matches]
    # Phase 1: cheap metadata only — no knowledge payload deserialization.
    metadata = await semantic_refs.get_metadata_multiple(ordinals)
    # get_metadata_multiple returns results in input-ordinal order, so
    # zip() pairs each match with its own metadata.
    matching = [
        (sr_match, m.ordinal)
        for sr_match, m in zip(semantic_ref_matches, metadata)
        if m.knowledge_type == knowledge_type
    ]
    if not matching:
        return []
    # Phase 2: full refs, fetched in one batch, for type-matching ordinals.
    full_refs = await semantic_refs.get_multiple([o for _, o in matching])
    return [
        Scored(item=ref, score=sr_match.score)
        for (sr_match, _), ref in zip(matching, full_refs)
    ]
    async def group_matches_by_type(
        self,
        semantic_refs: ISemanticRefCollection,
    ) -> dict[KnowledgeType, "SemanticRefAccumulator"]:
        """Partition this accumulator's matches by their refs' knowledge type.

        Fetches metadata for all match ordinals in one batched call
        (instead of one get_item() per match) and buckets each match
        into a per-type SemanticRefAccumulator that shares this
        accumulator's search_term_matches.
        """
        matches = list(self)
        if not matches:
            return {}
        ordinals = [match.value for match in matches]
        # Single round trip; results come back in `ordinals` order.
        metadata = await semantic_refs.get_metadata_multiple(ordinals)
        groups: dict[KnowledgeType, SemanticRefAccumulator] = {}
        for match, m in zip(matches, metadata):
            group = groups.get(m.knowledge_type)
            if group is None:
                group = SemanticRefAccumulator(self.search_term_matches)
                groups[m.knowledge_type] = group
            group.set_match(match)
        return groups

    async def get_matches_in_scope(
        self,
        semantic_refs: ISemanticRefCollection,
        ranges_in_scope: "TextRangesInScope",
    ) -> "SemanticRefAccumulator":
        """Return a new accumulator with only matches whose text range is in scope.

        Uses one batched metadata fetch for all match ordinals; the
        range check never needs the full SemanticRef.
        """
        matches = list(self)
        if not matches:
            return SemanticRefAccumulator(self.search_term_matches)
        ordinals = [match.value for match in matches]
        metadata = await semantic_refs.get_metadata_multiple(ordinals)
        accumulator = SemanticRefAccumulator(self.search_term_matches)
        for match, m in zip(matches, metadata):
            if ranges_in_scope.is_range_in_scope(m.range):
                accumulator.set_match(match)
        return accumulator
    def contains_range(self, inner_range: TextRange) -> bool:
        """Return True if some stored range fully contains *inner_range*.

        _ranges is kept sorted by start, so any containing range must
        have start <= inner_range.start. bisect_right (keyed on start
        only) finds the first index past that bound; scanning backwards
        from there tries the most likely container — the range with the
        largest start still <= inner's — first. The scan cannot stop
        early, because a range with a smaller start may have a larger end.
        """
        # NOTE(review): relies on `import bisect` at module top — the hunk
        # does not show it being added; confirm it already exists.
        if not self._ranges:
            return False
        hi = bisect.bisect_right(self._ranges, inner_range.start, key=lambda r: r.start)
        for i in range(hi - 1, -1, -1):
            # `in` delegates to TextRange.__contains__ (start/end containment).
            if inner_range in self._ranges[i]:
                return True
        return False
class SemanticRefMetadata(NamedTuple):
    """Lightweight projection of a SemanticRef for filtering.

    Carries only the fields needed by scope- and type-filters so storage
    providers can serve it without deserializing the full knowledge
    payload. Batch producers (get_metadata_multiple) return these in
    input-ordinal order, which callers rely on when zipping.
    """

    # Ordinal of the full SemanticRef this metadata describes; callers
    # use it to fetch the full ref after filtering.
    ordinal: SemanticRefOrdinal
    # Text range the ref covers; used by in-scope checks.
    range: TextRange
    # Discriminator used for knowledge-type filtering.
    knowledge_type: KnowledgeType
async def lookup_term_filtered(
    semantic_ref_index: ITermToSemanticRefIndex,
    term: Term,
    semantic_refs: ISemanticRefCollection,
    filter: Callable[[SemanticRefMetadata, ScoredSemanticRefOrdinal], bool],
) -> list[ScoredSemanticRefOrdinal] | None:
    """Look up a term in the semantic reference index and filter the results.

    Returns None when the term has no index entries at all; otherwise
    the (possibly empty) list of scored ordinals whose metadata passes
    *filter*. The filter receives lightweight SemanticRefMetadata —
    fetched in a single batch — rather than fully deserialized
    SemanticRefs, eliminating one get_item() round trip per candidate.
    (`filter` shadows the builtin, but the name is part of the existing
    call contract.)
    """
    scored_refs = await semantic_ref_index.lookup_term(term.text)
    if scored_refs:
        ordinals = [sr.semantic_ref_ordinal for sr in scored_refs]
        # Batched metadata fetch; results come back in `ordinals` order,
        # so zip() pairs each scored ref with its own metadata.
        metadata = await semantic_refs.get_metadata_multiple(ordinals)
        return [sr for sr, m in zip(scored_refs, metadata) if filter(m, sr)]
    return None
sr.knowledge.knowledge_type == knowledge_type - ) - and ranges_in_scope.is_range_in_scope(sr.range), + lambda m, _: (not knowledge_type or m.knowledge_type == knowledge_type) + and ranges_in_scope.is_range_in_scope(m.range), ) return await semantic_ref_index.lookup_term(term.text) diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index 9973a290..8a5b14eb 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -10,6 +10,7 @@ IMessage, MessageOrdinal, SemanticRef, + SemanticRefMetadata, SemanticRefOrdinal, ) @@ -63,6 +64,18 @@ async def extend(self, items: Iterable[T]) -> None: class MemorySemanticRefCollection(MemoryCollection[SemanticRef, SemanticRefOrdinal]): """A collection of semantic references.""" + async def get_metadata_multiple( + self, ordinals: list[SemanticRefOrdinal] + ) -> list[SemanticRefMetadata]: + return [ + SemanticRefMetadata( + ordinal=o, + range=self.items[o].range, + knowledge_type=self.items[o].knowledge.knowledge_type, + ) + for o in ordinals + ] + class MemoryMessageCollection[TMessage: IMessage]( MemoryCollection[TMessage, MessageOrdinal] diff --git a/src/typeagent/storage/memory/propindex.py b/src/typeagent/storage/memory/propindex.py index f9717b24..ecb3e85d 100644 --- a/src/typeagent/storage/memory/propindex.py +++ b/src/typeagent/storage/memory/propindex.py @@ -330,12 +330,13 @@ async def lookup_property_in_property_index( property_value, ) if ranges_in_scope is not None and scored_refs: - filtered_refs = [] - for sr in scored_refs: - semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal) - if ranges_in_scope.is_range_in_scope(semantic_ref.range): - filtered_refs.append(sr) - scored_refs = filtered_refs + ordinals = [sr.semantic_ref_ordinal for sr in scored_refs] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + scored_refs = [ + sr + for sr, m in zip(scored_refs, metadata) + if 
ranges_in_scope.is_range_in_scope(m.range) + ] return scored_refs or None # Return None if no results diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index 9730f6d1..fe394dcb 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -340,6 +340,50 @@ async def get_multiple(self, arg: list[int]) -> list[interfaces.SemanticRef]: assert set(rowdict) == set(arg) return [self._deserialize_semantic_ref_from_row(rowdict[ordl]) for ordl in arg] + async def get_metadata_multiple( + self, ordinals: list[int] + ) -> list[interfaces.SemanticRefMetadata]: + if not ordinals: + return [] + cursor = self.db.cursor() + placeholders = ",".join("?" * len(ordinals)) + cursor.execute( + f""" + SELECT semref_id, range_json, knowledge_type + FROM SemanticRefs WHERE semref_id IN ({placeholders}) + """, + ordinals, + ) + rows = cursor.fetchall() + rowdict = {r[0]: r for r in rows} + result = [] + for o in ordinals: + row = rowdict[o] + range_data = json.loads(row[1]) + start = range_data["start"] + end_data = range_data.get("end") + result.append( + interfaces.SemanticRefMetadata( + ordinal=row[0], + range=interfaces.TextRange( + start=interfaces.TextLocation( + start["messageOrdinal"], + start.get("chunkOrdinal", 0), + ), + end=( + interfaces.TextLocation( + end_data["messageOrdinal"], + end_data.get("chunkOrdinal", 0), + ) + if end_data + else None + ), + ), + knowledge_type=row[2], + ) + ) + return result + async def append(self, item: interfaces.SemanticRef) -> None: cursor = self.db.cursor() semref_id, range_json, knowledge_type, knowledge_json = ( diff --git a/tests/benchmarks/test_benchmark_query.py b/tests/benchmarks/test_benchmark_query.py new file mode 100644 index 00000000..8c4dd137 --- /dev/null +++ b/tests/benchmarks/test_benchmark_query.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Benchmark for lookup_term_filtered — measures the N+1 query pattern. + +After indexing 200 synthetic messages, looks up a high-frequency term +and filters results via lookup_term_filtered. Each call triggers +one get_item() SELECT per matching semantic ref (N+1 pattern). + +Run: + uv run python -m pytest tests/benchmarks/test_benchmark_query.py -v -s +""" + +import os +import tempfile + +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import Term +from typeagent.knowpro.query import lookup_term_filtered +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + + +def make_settings() -> ConversationSettings: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + return settings + + +def synthetic_messages(n: int) -> list[TranscriptMessage]: + return [ + TranscriptMessage( + text_chunks=[f"Message {i} about topic {i % 10}"], + metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"), + tags=[f"tag{i % 5}"], + ) + for i in range(n) + ] + + +async def create_indexed_transcript( + db_path: str, settings: ConversationSettings, n_messages: int +) -> Transcript: + """Create and index a transcript, returning it ready for queries.""" + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="bench") + messages = synthetic_messages(n_messages) + await transcript.add_messages_with_indexing(messages) + return transcript + + +@pytest.mark.asyncio +async 
@pytest.mark.asyncio
async def test_benchmark_lookup_term_filtered(async_benchmark):
    """Benchmark lookup_term_filtered with an accept-all filter.

    Indexes 200 synthetic messages into a temp SQLite db, picks the
    highest-frequency indexed term, then benchmarks lookup_term_filtered
    on it. The accept-all filter isolates per-ref fetch overhead.

    Cleanup fixes vs. the original: the temp dir is removed and the
    storage provider closed even if setup (not just the benchmark)
    raises, and `shutil` is imported up front rather than mid-teardown.
    """
    import shutil  # local: only needed for teardown of the temp dir

    settings = make_settings()
    tmpdir = tempfile.mkdtemp()
    try:
        db_path = os.path.join(tmpdir, "query_bench.db")
        transcript = await create_indexed_transcript(db_path, settings, 200)
        try:
            # Find the term with the most matches to maximize N in the
            # N+1 pattern being measured.
            semref_index = transcript.semantic_ref_index
            terms = await semref_index.get_terms()
            best_term = None
            best_count = 0
            for t in terms:
                refs = await semref_index.lookup_term(t)
                if refs and len(refs) > best_count:
                    best_count = len(refs)
                    best_term = t

            assert best_term is not None, "No terms found after indexing"
            print(f"\nBenchmarking term '{best_term}' with {best_count} matches")

            term = Term(text=best_term)
            semantic_refs = transcript.semantic_refs

            def accept_all(metadata, scored):
                # Accept everything — isolates the per-ref fetch overhead.
                return True

            async def target():
                await lookup_term_filtered(
                    semref_index, term, semantic_refs, accept_all
                )

            await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20)
        finally:
            await settings.storage_provider.close()
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
"1.3.0" @@ -2447,6 +2456,7 @@ dev = [ { name = "opentelemetry-instrumentation-httpx" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-async-benchmark" }, { name = "pytest-asyncio" }, { name = "pytest-benchmark" }, { name = "pytest-mock" }, @@ -2486,6 +2496,7 @@ dev = [ { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.3.5" }, + { name = "pytest-async-benchmark", git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode" }, { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "pytest-benchmark", specifier = ">=5.1.0" }, { name = "pytest-mock", specifier = ">=3.14.0" },