Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
ecbf6f5
Defer black import to first use
KRRT7 Apr 10, 2026
d4bc744
Batch schema DDL into executescript and pre-compile regex
KRRT7 Apr 10, 2026
88c8d5e
Revert "Batch schema DDL into executescript and pre-compile regex"
KRRT7 Apr 10, 2026
bc9f2df
Batch SQLite INSERTs for semantic ref and property indexing
KRRT7 Apr 10, 2026
2140df4
Add pytest-async-benchmark tests for indexing pipeline
KRRT7 Apr 10, 2026
d0d070b
Rewrite benchmarks to use async_benchmark.pedantic()
KRRT7 Apr 10, 2026
93a05a7
Remove section-divider comments from benchmark tests
KRRT7 Apr 10, 2026
dd5d738
Extract shared benchmark harness to reduce duplication
KRRT7 Apr 10, 2026
4148b7a
Remove underscore prefix from collect helper functions
KRRT7 Apr 10, 2026
aa7e958
Add pytest-async-benchmark as dev dependency
KRRT7 Apr 10, 2026
771e481
Fix pyright errors: use Sequence for batch method signatures
KRRT7 Apr 10, 2026
9e3de1c
Optimize fuzzy_lookup_embedding with numpy vectorized ops
KRRT7 Apr 10, 2026
cfac657
Add batch metadata query to avoid N+1 in lookup_term_filtered
KRRT7 Apr 10, 2026
f1537b4
Extend batch metadata query to remaining N+1 call sites
KRRT7 Apr 10, 2026
c8f69fd
Fix parse_azure_endpoint passing query string to AsyncAzureOpenAI
KRRT7 Apr 10, 2026
ee86888
Add benchmarks for all batch metadata query call sites
KRRT7 Apr 10, 2026
163ef51
Speed up scope-filtering: bisect in contains_range, inline tuple comp…
KRRT7 Apr 10, 2026
c3eff5f
perf: Replace black with stdlib pprint for runtime formatting
KRRT7 Apr 10, 2026
8f858c9
perf: Defer query-time imports in conversation_base
KRRT7 Apr 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ classifiers = [
]
dependencies = [
"azure-identity>=1.22.0",
"black>=25.12.0",
"colorama>=0.4.6",
"mcp[cli]>=1.12.1",
"numpy>=2.2.6",
Expand Down Expand Up @@ -58,6 +57,9 @@ Documentation = "https://github.com/microsoft/typeagent-py/tree/main/docs/README
[tool.uv.build-backend]
module-root = "src"

[tool.uv.sources]
pytest-async-benchmark = { git = "https://github.com/KRRT7/pytest-async-benchmark.git", rev = "feat/pedantic-mode" }

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
Expand All @@ -81,6 +83,7 @@ known_local_folder = ["conftest"]
dev = [
"azure-mgmt-authorization>=4.0.0",
"azure-mgmt-keyvault>=12.1.1",
"black>=25.12.0",
"coverage[toml]>=7.9.1",
"google-api-python-client>=2.184.0",
"google-auth-httplib2>=0.2.0",
Expand All @@ -91,6 +94,7 @@ dev = [
"opentelemetry-instrumentation-httpx>=0.57b0",
"pyright>=1.1.408", # 407 has a regression
"pytest>=8.3.5",
"pytest-async-benchmark",
"pytest-asyncio>=0.26.0",
"pytest-mock>=3.14.0",
]
28 changes: 15 additions & 13 deletions src/typeagent/aitools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import sys
import time

import black
import colorama

import typechat
Expand Down Expand Up @@ -45,25 +44,24 @@ def timelog(label: str, verbose: bool = True):


def pretty_print(obj: object, prefix: str = "", suffix: str = "") -> None:
    """Pretty-print an object using pprint.

    The formatted output is wrapped between *prefix* and *suffix*, and the
    line width is capped at 200 columns (or the terminal width, whichever
    is smaller).
    """
    import pprint

    line_width = min(200, shutil.get_terminal_size().columns)
    # Bug fix: the pprint rewrite dropped the prefix/suffix parameters the
    # old black-based implementation honored; restore them so existing
    # callers keep their decoration.
    print(prefix + pprint.pformat(obj, width=line_width) + suffix)


def format_code(text: str, line_width=None) -> str:
    """Format a Python literal expression using pprint.

    NOTE: *text* must be a valid Python literal expression (as produced
    by repr()); it is parsed with ast.literal_eval before formatting.
    """
    import ast
    import pprint

    # Default width: the terminal width, capped at 200 characters.
    width = (
        min(200, shutil.get_terminal_size().columns)
        if line_width is None
        else line_width
    )
    return pprint.pformat(ast.literal_eval(text), width=width)


def reindent(text: str) -> str:
Expand Down Expand Up @@ -197,7 +195,11 @@ def parse_azure_endpoint(
f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field"
)

return azure_endpoint, m.group(1)
# Strip query string — AsyncAzureOpenAI expects a clean base URL and
# receives api_version as a separate parameter.
clean_endpoint = azure_endpoint.split("?", 1)[0]

return clean_endpoint, m.group(1)


def get_azure_api_key(azure_api_key: str) -> str:
Expand Down
65 changes: 50 additions & 15 deletions src/typeagent/aitools/vectorbase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from collections.abc import Callable, Iterable
from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
Expand Down Expand Up @@ -132,28 +132,63 @@ def fuzzy_lookup_embedding(
min_score = 0.0
if len(self._vectors) == 0:
return []
# This line does most of the work:
scores: Iterable[float] = np.dot(self._vectors, embedding)
scored_ordinals = [
ScoredInt(i, score)
for i, score in enumerate(scores)
if score >= min_score and (predicate is None or predicate(i))
]
scored_ordinals.sort(key=lambda x: x.score, reverse=True)
return scored_ordinals[:max_hits]
scores = np.dot(self._vectors, embedding)

if predicate is None:
# Fast numpy path: filter and top-k without Python-level iteration.
indices = np.flatnonzero(scores >= min_score)
if len(indices) == 0:
return []
filtered_scores = scores[indices]
if len(indices) <= max_hits:
order = np.argsort(filtered_scores)[::-1]
else:
# argpartition is O(n) vs O(n log n) for full sort.
top_k = np.argpartition(filtered_scores, -max_hits)[-max_hits:]
order = top_k[np.argsort(filtered_scores[top_k])[::-1]]
return [
ScoredInt(int(indices[i]), float(filtered_scores[i])) for i in order
]
else:
# Predicate path: pre-filter by score in numpy, then apply predicate
# only to candidates that pass the score threshold.
candidates = np.flatnonzero(scores >= min_score)
scored_ordinals = [
ScoredInt(int(i), float(scores[i]))
for i in candidates
if predicate(int(i))
]
scored_ordinals.sort(key=lambda x: x.score, reverse=True)
return scored_ordinals[:max_hits]

# TODO: Make this and fuzzy_lookup_embedding() more similar.
def fuzzy_lookup_embedding_in_subset(
self,
embedding: NormalizedEmbedding,
ordinals_of_subset: list[int],
max_hits: int | None = None,
min_score: float | None = None,
) -> list[ScoredInt]:
ordinals_set = set(ordinals_of_subset)
return self.fuzzy_lookup_embedding(
embedding, max_hits, min_score, lambda i: i in ordinals_set
)
if max_hits is None:
max_hits = 10
if min_score is None:
min_score = 0.0
if not ordinals_of_subset or len(self._vectors) == 0:
return []
# Compute dot products only for the subset instead of all vectors.
subset = np.asarray(ordinals_of_subset)
scores = np.dot(self._vectors[subset], embedding)
indices = np.flatnonzero(scores >= min_score)
if len(indices) == 0:
return []
filtered_scores = scores[indices]
if len(indices) <= max_hits:
order = np.argsort(filtered_scores)[::-1]
else:
top_k = np.argpartition(filtered_scores, -max_hits)[-max_hits:]
order = top_k[np.argsort(filtered_scores[top_k])[::-1]]
return [
ScoredInt(int(subset[indices[i]]), float(filtered_scores[i])) for i in order
]

async def fuzzy_lookup(
self,
Expand Down
35 changes: 19 additions & 16 deletions src/typeagent/knowpro/answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from dataclasses import dataclass
from typing import Any

import black

import typechat

from .answer_context_schema import AnswerContext, RelevantKnowledge, RelevantMessage
Expand Down Expand Up @@ -127,10 +125,12 @@ def create_question_prompt(question: str) -> str:

def create_context_prompt(context: AnswerContext) -> str:
    """Render the answer context as a delimited prompt section."""
    # TODO: Use a more compact representation of the context than JSON.
    import pprint

    body = pprint.pformat(dictify(context), width=200)
    return "\n".join(["[ANSWER CONTEXT]", "===", body, "==="])
Expand Down Expand Up @@ -452,19 +452,22 @@ async def get_scored_semantic_refs_from_ordinals_iter(
semantic_ref_matches: list[ScoredSemanticRefOrdinal],
knowledge_type: KnowledgeType,
) -> list[Scored[SemanticRef]]:
result = []
for semantic_ref_match in semantic_ref_matches:
semantic_ref = await semantic_refs.get_item(
semantic_ref_match.semantic_ref_ordinal
)
if semantic_ref.knowledge.knowledge_type == knowledge_type:
result.append(
Scored(
item=semantic_ref,
score=semantic_ref_match.score,
)
)
return result
if not semantic_ref_matches:
return []
ordinals = [m.semantic_ref_ordinal for m in semantic_ref_matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
matching = [
(sr_match, m.ordinal)
for sr_match, m in zip(semantic_ref_matches, metadata)
if m.knowledge_type == knowledge_type
]
if not matching:
return []
full_refs = await semantic_refs.get_multiple([o for _, o in matching])
return [
Scored(item=ref, score=sr_match.score)
for (sr_match, _), ref in zip(matching, full_refs)
]


def merge_scored_concrete_entities(
Expand Down
39 changes: 25 additions & 14 deletions src/typeagent/knowpro/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,17 @@ async def group_matches_by_type(
self,
semantic_refs: ISemanticRefCollection,
) -> dict[KnowledgeType, "SemanticRefAccumulator"]:
matches = list(self)
if not matches:
return {}
ordinals = [match.value for match in matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
groups: dict[KnowledgeType, SemanticRefAccumulator] = {}
for match in self:
semantic_ref = await semantic_refs.get_item(match.value)
group = groups.get(semantic_ref.knowledge.knowledge_type)
for match, m in zip(matches, metadata):
group = groups.get(m.knowledge_type)
if group is None:
group = SemanticRefAccumulator(self.search_term_matches)
groups[semantic_ref.knowledge.knowledge_type] = group
groups[m.knowledge_type] = group
group.set_match(match)
return groups

Expand All @@ -346,11 +350,14 @@ async def get_matches_in_scope(
semantic_refs: ISemanticRefCollection,
ranges_in_scope: "TextRangesInScope",
) -> "SemanticRefAccumulator":
matches = list(self)
if not matches:
return SemanticRefAccumulator(self.search_term_matches)
ordinals = [match.value for match in matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
accumulator = SemanticRefAccumulator(self.search_term_matches)
for match in self:
if ranges_in_scope.is_range_in_scope(
(await semantic_refs.get_item(match.value)).range
):
for match, m in zip(matches, metadata):
if ranges_in_scope.is_range_in_scope(m.range):
accumulator.set_match(match)
return accumulator

Expand Down Expand Up @@ -519,12 +526,16 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No
self.add_range(text_range)

def contains_range(self, inner_range: TextRange) -> bool:
# Since ranges are sorted by start, once we pass inner_range's start
# no further range can contain it.
for outer_range in self._ranges:
if outer_range.start > inner_range.start:
break
if inner_range in outer_range:
if not self._ranges:
return False
# Bisect on start only to find all ranges with start <= inner.start,
# then scan backwards — the most likely containing range has the
# largest start still <= inner's.
hi = bisect.bisect_right(
self._ranges, inner_range.start, key=lambda r: r.start
)
for i in range(hi - 1, -1, -1):
if inner_range in self._ranges[i]:
return True
return False

Expand Down
13 changes: 8 additions & 5 deletions src/typeagent/knowpro/conversation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@

"""Base class for conversations with incremental indexing support."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Generic, Self, TypeVar
from typing import TYPE_CHECKING, Generic, Self, TypeVar

import typechat

from . import (
answer_response_schema,
answers,
convknowledge,
knowledge_schema as kplib,
search_query_schema,
searchlang,
secindex,
)
from ..aitools import model_adapters, utils
Expand All @@ -35,6 +33,9 @@
Topic,
)

if TYPE_CHECKING:
from . import answer_response_schema, answers, search_query_schema, searchlang

TMessage = TypeVar("TMessage", bound=IMessage)


Expand Down Expand Up @@ -350,6 +351,8 @@ async def query(
>>> answer = await conv.query("What topics were discussed?")
>>> print(answer)
"""
from . import answer_response_schema, answers, search_query_schema, searchlang

# Create translators lazily (once per conversation instance)
if self._query_translator is None:
model = model_adapters.create_chat_model()
Expand Down
Loading
Loading