diff --git a/pyproject.toml b/pyproject.toml index 1339e34e..3cb25d7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,9 @@ Documentation = "https://github.com/microsoft/typeagent-py/tree/main/docs/README [tool.uv.build-backend] module-root = "src" +[tool.uv.sources] +pytest-async-benchmark = { git = "https://github.com/KRRT7/pytest-async-benchmark.git", rev = "feat/pedantic-mode" } + [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] @@ -91,6 +94,7 @@ dev = [ "opentelemetry-instrumentation-httpx>=0.57b0", "pyright>=1.1.408", # 407 has a regression "pytest>=8.3.5", + "pytest-async-benchmark", "pytest-asyncio>=0.26.0", "pytest-mock>=3.14.0", ] diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index adba401a..46e81fc1 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -197,7 +197,11 @@ def parse_azure_endpoint( f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field" ) - return azure_endpoint, m.group(1) + # Strip query string — AsyncAzureOpenAI expects a clean base URL and + # receives api_version as a separate parameter. + clean_endpoint = azure_endpoint.split("?", 1)[0] + + return clean_endpoint, m.group(1) def get_azure_api_key(azure_api_key: str) -> str: diff --git a/src/typeagent/knowpro/answers.py b/src/typeagent/knowpro/answers.py index 2c300506..9bc41b4b 100644 --- a/src/typeagent/knowpro/answers.py +++ b/src/typeagent/knowpro/answers.py @@ -452,19 +452,22 @@ async def get_scored_semantic_refs_from_ordinals_iter( semantic_ref_matches: list[ScoredSemanticRefOrdinal], knowledge_type: KnowledgeType, ) -> list[Scored[SemanticRef]]: - result = [] - for semantic_ref_match in semantic_ref_matches: - semantic_ref = await semantic_refs.get_item( - semantic_ref_match.semantic_ref_ordinal - ) - if semantic_ref.knowledge.knowledge_type == knowledge_type: - result.append( - Scored( - item=semantic_ref, - score=semantic_ref_match.score, - ) - ) - return result + if not semantic_ref_matches: + return [] + ordinals = [m.semantic_ref_ordinal for m in semantic_ref_matches] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + matching = [ + (sr_match, m.ordinal) + for sr_match, m in zip(semantic_ref_matches, metadata) + if m.knowledge_type == knowledge_type + ] + if not matching: + return [] + full_refs = await semantic_refs.get_multiple([o for _, o in matching]) + return [ + Scored(item=ref, score=sr_match.score) + for (sr_match, _), ref in zip(matching, full_refs) + ] def merge_scored_concrete_entities( diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index a2716577..d7c07b19 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -331,13 +331,17 @@ async def group_matches_by_type( self, semantic_refs: ISemanticRefCollection, ) -> dict[KnowledgeType, "SemanticRefAccumulator"]: + matches = list(self) + if not matches: + return {} + ordinals = [match.value for match in matches] + metadata = await semantic_refs.get_metadata_multiple(ordinals) groups: dict[KnowledgeType, SemanticRefAccumulator] = {} - for match in self: - semantic_ref = await semantic_refs.get_item(match.value) - group = groups.get(semantic_ref.knowledge.knowledge_type) + for match, m in zip(matches, metadata): + group = groups.get(m.knowledge_type) if group is None: group = SemanticRefAccumulator(self.search_term_matches) - groups[semantic_ref.knowledge.knowledge_type] = group + groups[m.knowledge_type] = group group.set_match(match) return groups @@ -346,11 +350,14 @@ async def get_matches_in_scope( semantic_refs: ISemanticRefCollection, ranges_in_scope: "TextRangesInScope", ) -> "SemanticRefAccumulator": + matches = list(self) + if not matches: + return SemanticRefAccumulator(self.search_term_matches) + ordinals = [match.value for match in matches] + metadata = await semantic_refs.get_metadata_multiple(ordinals) accumulator = SemanticRefAccumulator(self.search_term_matches) - for match in self: - if ranges_in_scope.is_range_in_scope( - (await semantic_refs.get_item(match.value)).range - ): + for match, m in zip(matches, metadata): + if ranges_in_scope.is_range_in_scope(m.range): accumulator.set_match(match) return accumulator @@ -519,12 +526,16 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No self.add_range(text_range) def contains_range(self, inner_range: TextRange) -> bool: - # Since ranges are sorted by start, once we pass inner_range's start - # no further range can contain it. - for outer_range in self._ranges: - if outer_range.start > inner_range.start: - break - if inner_range in outer_range: + if not self._ranges: + return False + # Bisect on start only to find all ranges with start <= inner.start, + # then scan backwards — the most likely containing range has the + # largest start still <= inner's. + hi = bisect.bisect_right( + self._ranges, inner_range.start, key=lambda r: r.start + ) + for i in range(hi - 1, -1, -1): + if inner_range in self._ranges[i]: return True return False diff --git a/src/typeagent/knowpro/interfaces_core.py b/src/typeagent/knowpro/interfaces_core.py index 105e45b6..73584121 100644 --- a/src/typeagent/knowpro/interfaces_core.py +++ b/src/typeagent/knowpro/interfaces_core.py @@ -249,32 +249,24 @@ def __repr__(self) -> str: else: return f"{self.__class__.__name__}({self.start}, {self.end})" + @staticmethod + def _effective_end(tr: "TextRange") -> tuple[int, int]: + """Return (message_ordinal, chunk_ordinal) for the effective end.""" + if tr.end is not None: + return (tr.end.message_ordinal, tr.end.chunk_ordinal) + return (tr.start.message_ordinal, tr.start.chunk_ordinal + 1) + def __eq__(self, other: object) -> bool: if not isinstance(other, TextRange): return NotImplemented - if self.start != other.start: return False - - # Get the effective end for both ranges - self_end = self.end or TextLocation( - self.start.message_ordinal, self.start.chunk_ordinal + 1 - ) - other_end = other.end or TextLocation( - other.start.message_ordinal, other.start.chunk_ordinal + 1 - ) - return self_end == other_end + return TextRange._effective_end(self) == TextRange._effective_end(other) def __lt__(self, other: Self) -> bool: if self.start != other.start: return self.start < other.start - self_end = self.end or TextLocation( - self.start.message_ordinal, self.start.chunk_ordinal + 1 - ) - other_end = other.end or TextLocation( - other.start.message_ordinal, other.start.chunk_ordinal + 1 - ) - return self_end < other_end + return TextRange._effective_end(self) < TextRange._effective_end(other) def __gt__(self, other: Self) -> bool: return other.__lt__(self) @@ -286,13 +278,9 @@ def __le__(self, other: Self) -> bool: return not other.__lt__(self) def __contains__(self, other: Self) -> bool: - other_end = other.end or TextLocation( - other.start.message_ordinal, other.start.chunk_ordinal + 1 - ) - self_end = self.end or TextLocation( - self.start.message_ordinal, self.start.chunk_ordinal + 1 - ) - return self.start <= other.start and other_end <= self_end + if not (self.start <= other.start): + return False + return TextRange._effective_end(other) <= TextRange._effective_end(self) def serialize(self) -> TextRangeData: return self.__pydantic_serializer__.to_python( # type: ignore diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index a82fe7ad..97f7b600 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -6,16 +6,18 @@ from collections.abc import AsyncIterable, Iterable from datetime import datetime as Datetime -from typing import Any, Protocol, Self +from typing import Any, NamedTuple, Protocol, Self from pydantic.dataclasses import dataclass from .interfaces_core import ( IMessage, ITermToSemanticRefIndex, + KnowledgeType, MessageOrdinal, SemanticRef, SemanticRefOrdinal, + TextRange, ) from .interfaces_indexes import ( IConversationSecondaryIndexes, @@ -57,6 +59,14 @@ class ConversationMetadata: extra: dict[str, str] | None = None +class SemanticRefMetadata(NamedTuple): + """Lightweight metadata for filtering without full knowledge deserialization.""" + + ordinal: SemanticRefOrdinal + range: TextRange + knowledge_type: KnowledgeType + + class IReadonlyCollection[T, TOrdinal](AsyncIterable[T], Protocol): async def size(self) -> int: ... @@ -91,6 +101,12 @@ class IMessageCollection[TMessage: IMessage]( class ISemanticRefCollection(ICollection[SemanticRef, SemanticRefOrdinal], Protocol): """A collection of SemanticRefs.""" + async def get_metadata_multiple( + self, ordinals: list[SemanticRefOrdinal] + ) -> list[SemanticRefMetadata]: + """Batch-fetch lightweight metadata without deserializing knowledge.""" + ... + class IStorageProvider[TMessage: IMessage](Protocol): """API spec for storage providers -- maybe in-memory or persistent.""" @@ -190,4 +206,5 @@ class IConversation[ "ISemanticRefCollection", "IStorageProvider", "STATUS_INGESTED", + "SemanticRefMetadata", ] diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index 44fa06ec..5859e3bc 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -37,6 +37,7 @@ ScoredSemanticRefOrdinal, SearchTerm, SemanticRef, + SemanticRefMetadata, SemanticRefOrdinal, SemanticRefSearchResult, Term, @@ -174,17 +175,14 @@ async def lookup_term_filtered( semantic_ref_index: ITermToSemanticRefIndex, term: Term, semantic_refs: ISemanticRefCollection, - filter: Callable[[SemanticRef, ScoredSemanticRefOrdinal], bool], + filter: Callable[[SemanticRefMetadata, ScoredSemanticRefOrdinal], bool], ) -> list[ScoredSemanticRefOrdinal] | None: """Look up a term in the semantic reference index and filter the results.""" scored_refs = await semantic_ref_index.lookup_term(term.text) if scored_refs: - filtered = [] - for sr in scored_refs: - semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal) - if filter(semantic_ref, sr): - filtered.append(sr) - return filtered + ordinals = [sr.semantic_ref_ordinal for sr in scored_refs] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + return [sr for sr, m in zip(scored_refs, metadata) if filter(m, sr)] return None @@ -202,10 +200,10 @@ async def lookup_term( semantic_ref_index, term, semantic_refs, - lambda sr, _: ( - not knowledge_type or sr.knowledge.knowledge_type == knowledge_type + lambda m, _: ( + not knowledge_type or m.knowledge_type == knowledge_type ) - and ranges_in_scope.is_range_in_scope(sr.range), + and ranges_in_scope.is_range_in_scope(m.range), ) return await semantic_ref_index.lookup_term(term.text) diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index 9973a290..8a5b14eb 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -10,6 +10,7 @@ IMessage, MessageOrdinal, SemanticRef, + SemanticRefMetadata, SemanticRefOrdinal, ) @@ -63,6 +64,18 @@ async def extend(self, items: Iterable[T]) -> None: class MemorySemanticRefCollection(MemoryCollection[SemanticRef, SemanticRefOrdinal]): """A collection of semantic references.""" + async def get_metadata_multiple( + self, ordinals: list[SemanticRefOrdinal] + ) -> list[SemanticRefMetadata]: + return [ + SemanticRefMetadata( + ordinal=o, + range=self.items[o].range, + knowledge_type=self.items[o].knowledge.knowledge_type, + ) + for o in ordinals + ] + class MemoryMessageCollection[TMessage: IMessage]( MemoryCollection[TMessage, MessageOrdinal] diff --git a/src/typeagent/storage/memory/propindex.py b/src/typeagent/storage/memory/propindex.py index acc7b89a..6290671a 100644 --- a/src/typeagent/storage/memory/propindex.py +++ b/src/typeagent/storage/memory/propindex.py @@ -252,12 +252,13 @@ async def lookup_property_in_property_index( property_value, ) if ranges_in_scope is not None and scored_refs: - filtered_refs = [] - for sr in scored_refs: - semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal) - if ranges_in_scope.is_range_in_scope(semantic_ref.range): - filtered_refs.append(sr) - scored_refs = filtered_refs + ordinals = [sr.semantic_ref_ordinal for sr in scored_refs] + metadata = await semantic_refs.get_metadata_multiple(ordinals) + scored_refs = [ + sr + for sr, m in zip(scored_refs, metadata) + if ranges_in_scope.is_range_in_scope(m.range) + ] return scored_refs or None # Return None if no results diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index 9730f6d1..fe394dcb 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -340,6 +340,50 @@ async def get_multiple(self, arg: list[int]) -> list[interfaces.SemanticRef]: assert set(rowdict) == set(arg) return [self._deserialize_semantic_ref_from_row(rowdict[ordl]) for ordl in arg] + async def get_metadata_multiple( + self, ordinals: list[int] + ) -> list[interfaces.SemanticRefMetadata]: + if not ordinals: + return [] + cursor = self.db.cursor() + placeholders = ",".join("?" * len(ordinals)) + cursor.execute( + f""" + SELECT semref_id, range_json, knowledge_type + FROM SemanticRefs WHERE semref_id IN ({placeholders}) + """, + ordinals, + ) + rows = cursor.fetchall() + rowdict = {r[0]: r for r in rows} + result = [] + for o in ordinals: + row = rowdict[o] + range_data = json.loads(row[1]) + start = range_data["start"] + end_data = range_data.get("end") + result.append( + interfaces.SemanticRefMetadata( + ordinal=row[0], + range=interfaces.TextRange( + start=interfaces.TextLocation( + start["messageOrdinal"], + start.get("chunkOrdinal", 0), + ), + end=( + interfaces.TextLocation( + end_data["messageOrdinal"], + end_data.get("chunkOrdinal", 0), + ) + if end_data + else None + ), + ), + knowledge_type=row[2], + ) + ) + return result + async def append(self, item: interfaces.SemanticRef) -> None: cursor = self.db.cursor() semref_id, range_json, knowledge_type, knowledge_json = ( diff --git a/tests/benchmarks/test_benchmark_query.py b/tests/benchmarks/test_benchmark_query.py new file mode 100644 index 00000000..7e948ad2 --- /dev/null +++ b/tests/benchmarks/test_benchmark_query.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benchmarks for batch metadata query optimization. + +After indexing 200 synthetic messages, exercises each function that was +converted from N+1 get_item() to batch get_metadata_multiple(). + +Run: + uv run python -m pytest tests/benchmarks/test_benchmark_query.py -v -s +""" + +import os +import shutil +import tempfile + +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro.answers import get_scored_semantic_refs_from_ordinals_iter +from typeagent.knowpro.collections import ( + SemanticRefAccumulator, + TextRangeCollection, + TextRangesInScope, +) +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import Term, TextLocation, TextRange +from typeagent.knowpro.query import lookup_term_filtered +from typeagent.storage.memory.propindex import ( + PropertyNames, + lookup_property_in_property_index, +) +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + + +def make_settings() -> ConversationSettings: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + return settings + + +def synthetic_messages(n: int) -> list[TranscriptMessage]: + return [ + TranscriptMessage( + text_chunks=[f"Message {i} about topic {i % 10}"], + metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"), + tags=[f"tag{i % 5}"], + ) + for i in range(n) + ] + + +async def create_indexed_transcript( + db_path: str, settings: ConversationSettings, n_messages: int +) -> Transcript: + """Create and index a transcript, returning it ready for queries.""" + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="bench") + messages = synthetic_messages(n_messages) + await transcript.add_messages_with_indexing(messages) + return transcript + + +async def find_best_term(semref_index) -> tuple[str, int]: + """Find the term with the most matches in the semantic ref index.""" + terms = await semref_index.get_terms() + best_term = None + best_count = 0 + for t in terms: + refs = await semref_index.lookup_term(t) + if refs and len(refs) > best_count: + best_count = len(refs) + best_term = t + assert best_term is not None, "No terms found after indexing" + return best_term, best_count + + +def make_scope_first_half(n_messages: int) -> TextRangesInScope: + """Build a TextRangesInScope covering the first half of messages.""" + ranges = [ + TextRange( + start=TextLocation(i, 0), + end=TextLocation(i, 0), + ) + for i in range(n_messages // 2) + ] + scope = TextRangesInScope() + scope.add_text_ranges(TextRangeCollection(ranges)) + return scope + + +@pytest.mark.asyncio +async def test_benchmark_lookup_term_filtered(async_benchmark): + """Benchmark lookup_term_filtered with batch get_metadata_multiple.""" + settings = make_settings() + tmpdir = tempfile.mkdtemp() + db_path = os.path.join(tmpdir, "bench_ltf.db") + + transcript = await create_indexed_transcript(db_path, settings, 200) + semref_index = transcript.semantic_ref_index + best_term, best_count = await find_best_term(semref_index) + print(f"\nBenchmarking term '{best_term}' with {best_count} matches") + + term = Term(text=best_term) + semantic_refs = transcript.semantic_refs + accept_all = lambda sr, scored: True + + async def target(): + await lookup_term_filtered(semref_index, term, semantic_refs, accept_all) + + try: + await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20) + finally: + await settings.storage_provider.close() + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_benchmark_lookup_property_in_property_index(async_benchmark): + """Benchmark property lookup with range filtering.""" + settings = make_settings() + tmpdir = tempfile.mkdtemp() + db_path = os.path.join(tmpdir, "bench_prop.db") + + transcript = await create_indexed_transcript(db_path, settings, 200) + assert transcript.secondary_indexes is not None + property_index = transcript.secondary_indexes.property_to_semantic_ref_index + assert property_index is not None, "Property index not built" + + # Verify there are matches for entity type "person" + refs = await property_index.lookup_property( + PropertyNames.EntityType.value, "person" + ) + match_count = len(refs) if refs else 0 + print(f"\nBenchmarking property 'type=person' with {match_count} matches") + assert match_count > 0 + + scope = make_scope_first_half(200) + + async def target(): + await lookup_property_in_property_index( + property_index, + PropertyNames.EntityType.value, + "person", + transcript.semantic_refs, + ranges_in_scope=scope, + ) + + try: + await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20) + finally: + await settings.storage_provider.close() + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_benchmark_group_matches_by_type(async_benchmark): + """Benchmark grouping accumulated matches by knowledge type.""" + settings = make_settings() + tmpdir = tempfile.mkdtemp() + db_path = os.path.join(tmpdir, "bench_group.db") + + transcript = await create_indexed_transcript(db_path, settings, 200) + semref_index = transcript.semantic_ref_index + best_term, best_count = await find_best_term(semref_index) + print(f"\nBenchmarking group_matches_by_type: term '{best_term}' ({best_count} matches)") + + scored_refs = await semref_index.lookup_term(best_term) + accumulator = SemanticRefAccumulator() + accumulator.add_term_matches( + Term(text=best_term), scored_refs, is_exact_match=True + ) + + async def target(): + await accumulator.group_matches_by_type(transcript.semantic_refs) + + try: + await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20) + finally: + await settings.storage_provider.close() + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_benchmark_get_matches_in_scope(async_benchmark): + """Benchmark filtering accumulated matches by range scope.""" + settings = make_settings() + tmpdir = tempfile.mkdtemp() + db_path = os.path.join(tmpdir, "bench_scope.db") + + transcript = await create_indexed_transcript(db_path, settings, 200) + semref_index = transcript.semantic_ref_index + best_term, best_count = await find_best_term(semref_index) + print(f"\nBenchmarking get_matches_in_scope: term '{best_term}' ({best_count} matches)") + + scored_refs = await semref_index.lookup_term(best_term) + accumulator = SemanticRefAccumulator() + accumulator.add_term_matches( + Term(text=best_term), scored_refs, is_exact_match=True + ) + + scope = make_scope_first_half(200) + + async def target(): + await accumulator.get_matches_in_scope(transcript.semantic_refs, scope) + + try: + await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20) + finally: + await settings.storage_provider.close() + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_benchmark_get_scored_semantic_refs_from_ordinals_iter(async_benchmark): + """Benchmark two-phase metadata filter + batch fetch for scored refs.""" + settings = make_settings() + tmpdir = tempfile.mkdtemp() + db_path = os.path.join(tmpdir, "bench_scored.db") + + transcript = await create_indexed_transcript(db_path, settings, 200) + semref_index = transcript.semantic_ref_index + best_term, best_count = await find_best_term(semref_index) + print( + f"\nBenchmarking get_scored_semantic_refs_from_ordinals_iter: " + f"term '{best_term}' ({best_count} matches), filter=entity" + ) + + scored_refs = await semref_index.lookup_term(best_term) + assert scored_refs is not None + + async def target(): + await get_scored_semantic_refs_from_ordinals_iter( + transcript.semantic_refs, scored_refs, "entity" + ) + + try: + await async_benchmark.pedantic(target, rounds=200, warmup_rounds=20) + finally: + await settings.storage_provider.close() + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5966af61..7f806f74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -67,7 +67,7 @@ def test_api_version_after_question_mark( ) endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") assert version == "2025-01-01-preview" - assert endpoint.startswith("https://") + assert endpoint == "https://myhost.openai.azure.com/openai/deployments/gpt-4" def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> None: """api-version preceded by & (not the first query parameter).""" @@ -84,6 +84,44 @@ def test_missing_env_var_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: with pytest.raises(RuntimeError, match="not found"): utils.parse_azure_endpoint("NONEXISTENT_ENDPOINT") + def test_query_string_stripped_from_endpoint( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Returned endpoint should not contain query string parameters.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com?api-version=2024-06-01", + ) + endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") + assert endpoint == "https://myhost.openai.azure.com" + assert version == "2024-06-01" + + def test_query_string_stripped_with_path( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Query string stripped even when endpoint includes a path.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com/openai/deployments/gpt-4?api-version=2025-01-01-preview", + ) + endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") + assert endpoint == "https://myhost.openai.azure.com/openai/deployments/gpt-4" + assert "?" not in endpoint + assert version == "2025-01-01-preview" + + def test_query_string_stripped_multiple_params( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """All query parameters stripped, not just api-version.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com?foo=bar&api-version=2024-06-01", + ) + endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT") + assert endpoint == "https://myhost.openai.azure.com" + assert "foo" not in endpoint + assert version == "2024-06-01" + def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: """RuntimeError when the endpoint has no api-version field.""" monkeypatch.setenv( diff --git a/uv.lock b/uv.lock index 4eab7ead..e2b66a3d 100644 --- a/uv.lock +++ b/uv.lock @@ -1922,6 +1922,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-async-benchmark" +version = "0.2.0" +source = { git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode#029d03634d140789baebc6c3c8f72d5c81a67f9a" } +dependencies = [ + { name = "pytest" }, + { name = "rich" }, +] + [[package]] name = "pytest-asyncio" version = "1.3.0" @@ -2398,6 +2407,7 @@ dev = [ { name = "opentelemetry-instrumentation-httpx" }, { name = "pyright" }, { name = "pytest" }, + { name = "pytest-async-benchmark" }, { name = "pytest-asyncio" }, { name = "pytest-mock" }, ] @@ -2436,6 +2446,7 @@ dev = [ { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.3.5" }, + { name = "pytest-async-benchmark", git = "https://github.com/KRRT7/pytest-async-benchmark.git?rev=feat%2Fpedantic-mode" }, { name = "pytest-asyncio", specifier = ">=0.26.0" }, { name = "pytest-mock", specifier = ">=3.14.0" }, ]