From 916d0c4f5d82a62de002e62da093b6aad94b8754 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Thu, 23 Apr 2026 14:11:52 +0200 Subject: [PATCH 1/8] enhancing testcoverage for existing modules: - updated import statements - added additional test functions --- tests/test_mcp_server.py | 86 +++++++++++++-- tests/test_transcripts.py | 224 +++++++++++++++++++++++++++++++++++--- tests/test_utils.py | 158 ++++++++++++++++++++++++--- 3 files changed, 433 insertions(+), 35 deletions(-) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 16577f86..2bfa4b01 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -8,7 +8,7 @@ import sys from types import SimpleNamespace from typing import Any, cast -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -26,7 +26,16 @@ import typechat from typeagent.aitools.utils import create_async_openai_client, resolve_azure_model_name -from typeagent.mcp.server import MCPTypeChatModel, QuestionResponse +from typeagent.knowpro import answers, searchlang +from typeagent.knowpro.answer_response_schema import AnswerResponse +from typeagent.knowpro.convsettings import ConversationSettings +import typeagent.mcp.server as typeagent_mcp_server +from typeagent.mcp.server import ( + load_podcast_database_or_index, + MCPTypeChatModel, + ProcessingContext, + QuestionResponse, +) from conftest import EPISODE_53_INDEX @@ -204,9 +213,7 @@ async def test_mcp_server_empty_question(server_params: StdioServerParameters): def test_server_module_imports() -> None: """Importing the server module should not raise even without coverage.""" - import typeagent.mcp.server as mod - - assert hasattr(mod, "mcp") # The FastMCP instance exists + assert hasattr(typeagent_mcp_server, "mcp") # The FastMCP instance exists # --------------------------------------------------------------------------- @@ -342,8 +349,6 @@ def test_known_types(self) -> None: def test_answer_type_coverage(self) -> 
None: """AnswerResponse.type should only be 'Answered' or 'NoAnswer'.""" - from typeagent.knowpro.answer_response_schema import AnswerResponse - answered = AnswerResponse(type="Answered", answer="yes") assert answered.type == "Answered" no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno") @@ -408,3 +413,70 @@ async def test_sampling_callback_uses_azure_deployment_name( temperature=1.0, ) assert result.model == "gpt-4o-2" + + +# --------------------------------------------------------------------------- +# MCPTypeChatModel — additional response format coverage +# --------------------------------------------------------------------------- + + +class TestMCPTypeChatModelResponseFormats: + @staticmethod + def _make_model_with_result(content: Any) -> MCPTypeChatModel: + session = AsyncMock() + session.create_message.return_value = AsyncMock(content=content) + return MCPTypeChatModel(session) + + @pytest.mark.asyncio + async def test_list_content_no_text_items_returns_failure(self) -> None: + """A list response with no TextContent items should return Failure.""" + # Use a non-TextContent item type (ImageContent would work but we mock with a dict) + model = self._make_model_with_result([]) + result = await model.complete("test") + assert isinstance(result, typechat.Failure) + assert "No text content" in result.message + + @pytest.mark.asyncio + async def test_unknown_content_type_returns_failure(self) -> None: + """A response with an unrecognized content type should return Failure.""" + # Simulate some unknown object that is neither TextContent nor list + model = self._make_model_with_result(42) + result = await model.complete("test") + assert isinstance(result, typechat.Failure) + assert "No text content" in result.message + + +# --------------------------------------------------------------------------- +# ProcessingContext.__repr__ +# --------------------------------------------------------------------------- + + +class TestProcessingContextRepr: + def 
test_repr_contains_options(self) -> None: + lang_opts = searchlang.LanguageSearchOptions(max_message_matches=10) + ctx_opts = answers.AnswerContextOptions(entities_top_k=5) + + proc = ProcessingContext( + lang_search_options=lang_opts, + answer_context_options=ctx_opts, + query_context=MagicMock(), + embedding_model=MagicMock(), + query_translator=MagicMock(), + answer_translator=MagicMock(), + ) + r = repr(proc) + assert r.startswith("Context(") + assert "LanguageSearchOptions" in r + + +# --------------------------------------------------------------------------- +# load_podcast_database_or_index — ValueError path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_load_podcast_no_args_raises() -> None: + """Passing neither dbname nor podcast_index must raise ValueError.""" + settings = ConversationSettings() + with pytest.raises(ValueError, match="Either --database or --podcast-index"): + await load_podcast_database_or_index(settings, dbname=None, podcast_index=None) diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 9d98ae88..cc218bf9 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -11,7 +11,13 @@ from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH +from typeagent.storage.memory.collections import ( + MemoryMessageCollection, + MemorySemanticRefCollection, +) +from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex from typeagent.transcripts.transcript import ( + split_speaker_name, Transcript, TranscriptMessage, TranscriptMessageMeta, @@ -20,6 +26,7 @@ extract_speaker_from_text, get_transcript_duration, get_transcript_speakers, + parse_voice_tags, webvtt_timestamp_to_seconds, ) @@ -103,13 +110,6 @@ def conversation_settings( @pytest.mark.asyncio async def 
test_ingest_vtt_transcript(conversation_settings: ConversationSettings): """Test importing a VTT file into a Transcript object.""" - from typeagent.storage.memory.collections import ( - MemoryMessageCollection, - MemorySemanticRefCollection, - ) - from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex - from typeagent.transcripts.transcript_ingest import parse_voice_tags - vtt_file = CONFUSE_A_CAT_VTT # Use in-memory storage to avoid database cleanup issues @@ -252,12 +252,6 @@ async def test_transcript_knowledge_extraction_slow( 4. Verifies both mechanical extraction (entities/actions from metadata) and LLM extraction (topics from content) work correctly """ - from typeagent.storage.memory.collections import ( - MemoryMessageCollection, - MemorySemanticRefCollection, - ) - from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex - # Use in-memory storage for speed settings = ConversationSettings(embedding_model) @@ -345,3 +339,207 @@ async def test_transcript_knowledge_extraction_slow( ) print(f"Knowledge types: {knowledge_types}") print(f"Indexed terms: {len(terms)}") + + +# --------------------------------------------------------------------------- +# split_speaker_name +# --------------------------------------------------------------------------- + + +class TestSplitSpeakerName: + def test_single_word(self) -> None: + result = split_speaker_name("alice") + assert result is not None + assert result.first_name == "alice" + assert result.last_name is None + assert result.middle_name is None + + def test_two_words(self) -> None: + result = split_speaker_name("john smith") + assert result is not None + assert result.first_name == "john" + assert result.last_name == "smith" + assert result.middle_name is None + + def test_three_words(self) -> None: + result = split_speaker_name("john michael smith") + assert result is not None + assert result.first_name == "john" + assert result.middle_name == "michael" + assert result.last_name == 
"smith" + + def test_van_prefix_merged_into_last_name(self) -> None: + result = split_speaker_name("jan van eyck") + assert result is not None + assert result.first_name == "jan" + assert result.last_name == "van eyck" + assert result.middle_name is None + + def test_empty_string_returns_none(self) -> None: + result = split_speaker_name("") + assert result is None + + +# --------------------------------------------------------------------------- +# Serialize / deserialize roundtrip (in-memory, no LLM) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_transcript_serialize_deserialize_roundtrip() -> None: + """Serialize a transcript and deserialize into a fresh one — data is preserved.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + + # Build original transcript — use add_messages_with_indexing so the + # message text index (and its embeddings) are populated before serializing. + original = await Transcript.create(settings, name="roundtrip-test", tags=["foo"]) + msg1 = TranscriptMessage( + text_chunks=["Hello world"], + metadata=TranscriptMessageMeta(speaker="Alice", recipients=["Bob"]), + tags=["t1"], + timestamp="2024-01-01T00:00:00Z", + ) + msg2 = TranscriptMessage( + text_chunks=["Goodbye"], + metadata=TranscriptMessageMeta(speaker="Bob", recipients=[]), + tags=[], + timestamp="2024-01-01T00:01:00Z", + ) + await original.add_messages_with_indexing([msg1, msg2]) + data = await original.serialize() + + # Deserialize into a fresh transcript. 
+ fresh_settings = ConversationSettings(embedding_model) + fresh_settings.semantic_ref_index_settings.auto_extract_knowledge = False + fresh = await Transcript.create(fresh_settings, name="", tags=[]) + await fresh.deserialize(data) + + assert fresh.name_tag == "roundtrip-test" + assert "foo" in fresh.tags + assert await fresh.messages.size() == 2 + + first = await fresh.messages.get_item(0) + assert first.text_chunks == ["Hello world"] + assert first.metadata.speaker == "Alice" + assert first.metadata.recipients == ["Bob"] + assert first.timestamp == "2024-01-01T00:00:00Z" + + +@pytest.mark.asyncio +async def test_transcript_deserialize_non_empty_raises() -> None: + """Deserializing into a non-empty Transcript raises RuntimeError.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + + transcript = await Transcript.create(settings, name="test", tags=[]) + await transcript.messages.append( + TranscriptMessage( + text_chunks=["existing"], + metadata=TranscriptMessageMeta(speaker=None, recipients=[]), + ) + ) + data = await transcript.serialize() + + # Trying to deserialize into it again must raise. 
+ with pytest.raises(RuntimeError): + await transcript.deserialize(data) + + +# --------------------------------------------------------------------------- +# write_to_file / read_from_file roundtrip +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_write_and_read_from_file(tmp_path: os.PathLike[str]) -> None: + """write_to_file + read_from_file preserves names, tags, and messages.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + + original = await Transcript.create(settings, name="file-test", tags=["persisted"]) + msg = TranscriptMessage( + text_chunks=["Persisted message"], + metadata=TranscriptMessageMeta(speaker="Eve", recipients=[]), + timestamp="2024-06-01T12:00:00Z", + ) + # Use add_messages_with_indexing so embeddings are built before writing. + await original.add_messages_with_indexing([msg]) + prefix = os.path.join(str(tmp_path), "test_transcript") + await original.write_to_file(prefix) + + # Verify the _data.json file was written. + assert os.path.exists(prefix + "_data.json") + + # Read it back. 
+ fresh_settings = ConversationSettings(embedding_model) + fresh_settings.semantic_ref_index_settings.auto_extract_knowledge = False + loaded = await Transcript.read_from_file(prefix, fresh_settings) + + assert loaded.name_tag == "file-test" + assert "persisted" in loaded.tags + assert await loaded.messages.size() == 1 + first = await loaded.messages.get_item(0) + assert first.text_chunks == ["Persisted message"] + assert first.metadata.speaker == "Eve" + + +# --------------------------------------------------------------------------- +# Speaker alias building +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_build_speaker_aliases_full_name() -> None: + """Full-name speakers create first-name ↔ full-name aliases.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + + transcript = await Transcript.create(settings, name="alias-test", tags=[]) + msg = TranscriptMessage( + text_chunks=["Hi"], + metadata=TranscriptMessageMeta(speaker="John Smith", recipients=[]), + ) + await transcript.messages.append(msg) + + # Rebuild aliases explicitly. + await transcript._build_speaker_aliases() + + secondary = transcript._get_secondary_indexes() + assert secondary.term_to_related_terms_index is not None + aliases = secondary.term_to_related_terms_index.aliases + + # "john" should be aliased to "john smith" and vice-versa. 
+ john_aliases = await aliases.lookup_term("john") + assert john_aliases is not None + alias_texts = [t.text for t in john_aliases] + assert "john smith" in alias_texts + + full_aliases = await aliases.lookup_term("john smith") + assert full_aliases is not None + assert "john" in [t.text for t in full_aliases] + + +@pytest.mark.asyncio +async def test_build_speaker_aliases_single_name_no_alias() -> None: + """Single-word speaker names produce no aliases.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + + transcript = await Transcript.create(settings, name="alias-test2", tags=[]) + msg = TranscriptMessage( + text_chunks=["Hello"], + metadata=TranscriptMessageMeta(speaker="Alice", recipients=[]), + ) + await transcript.messages.append(msg) + await transcript._build_speaker_aliases() + + secondary = transcript._get_secondary_indexes() + assert secondary.term_to_related_terms_index is not None + aliases = secondary.term_to_related_terms_index.aliases + + # Single-name speaker — no alias entry expected. 
+ result = await aliases.lookup_term("alice") + assert not result diff --git a/tests/test_utils.py b/tests/test_utils.py index ad526129..22f930ae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,6 +8,7 @@ from dotenv import load_dotenv import pytest +from openai import AsyncAzureOpenAI, AsyncOpenAI import pydantic.dataclasses import typechat @@ -208,36 +209,163 @@ def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: with pytest.raises(RuntimeError, match="doesn't contain valid api-version"): utils.parse_azure_endpoint("TEST_ENDPOINT") + def test_no_deployment_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Endpoint without /deployments/ yields deployment_name=None.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com/openai?api-version=2024-06-01", + ) + endpoint, version, deployment = utils.parse_azure_endpoint_parts( + "TEST_ENDPOINT" + ) + assert endpoint == "https://myhost.openai.azure.com" + assert version == "2024-06-01" + assert deployment is None -class TestResolveAzureModelName: - """Tests for resolve_azure_model_name.""" - - def test_returns_deployment_name_from_endpoint( + def test_apim_style_deployment_extracted( self, monkeypatch: pytest.MonkeyPatch ) -> None: - """Deployment name in the endpoint takes precedence over the fallback.""" + """APIM-style URL: prefix before /openai kept, deployment name extracted.""" monkeypatch.setenv( "TEST_ENDPOINT", - "https://myhost.openai.azure.com/openai/deployments/gpt-4o-custom?api-version=2025-01-01-preview", + "https://apim.net/openai/openai/deployments/gpt-4o/chat/completions?api-version=2025-01-01-preview", + ) + endpoint, version, deployment = utils.parse_azure_endpoint_parts( + "TEST_ENDPOINT" + ) + assert endpoint == "https://apim.net/openai" + assert version == "2025-01-01-preview" + assert deployment == "gpt-4o" + + +class TestReindent: + def test_four_spaces_to_two(self) -> None: + text = "def foo():\n pass\n return 1" + 
result = utils.reindent(text) + assert result == "def foo():\n pass\n return 1" + + def test_empty_string(self) -> None: + assert utils.reindent("") == "" + + def test_no_indent(self) -> None: + assert utils.reindent("hello") == "hello" + + def test_nested_indent(self) -> None: + text = "a\n b\n c" + result = utils.reindent(text) + assert result == "a\n b\n c" + + +class TestTimelog: + def test_verbose_false_no_output(self) -> None: + buf = StringIO() + with redirect_stderr(buf): + with utils.timelog("silent", verbose=False): + pass + assert buf.getvalue() == "" + + def test_verbose_true_shows_label(self) -> None: + buf = StringIO() + with redirect_stderr(buf): + with utils.timelog("myblock", verbose=True): + pass + assert "myblock" in buf.getvalue() + + +class TestListDiff: + def test_identical_lists(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("a", [1, 2, 3], "b", [1, 2, 3], max_items=10) + out = buf.getvalue() + assert "1" in out + assert "2" in out + + def test_different_lists(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("left", [1, 2], "right", [1, 3], max_items=10) + assert buf.getvalue() != "" + + def test_no_max_items(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("a", [1], "b", [2], max_items=0) + assert "1" in buf.getvalue() or "2" in buf.getvalue() + + def test_empty_lists(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("a", [], "b", [], max_items=10) + # No output expected (nothing to diff) + assert buf.getvalue() == "" + + +class TestGetAzureApiKey: + def test_plain_key_returned_as_is(self) -> None: + assert utils.get_azure_api_key("my-secret-key") == "my-secret-key" + + def test_uppercase_identity_not_plain(self) -> None: + # "IDENTITY" as a plain key is not routed to token provider; only "identity" + # (lowercased) triggers that path. 
Since we can't call the identity provider + # in tests, just verify non-identity keys pass through unchanged. + assert utils.get_azure_api_key("APIKEY123") == "APIKEY123" + + +class TestCreateAsyncOpenAIClient: + def test_no_keys_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + with pytest.raises(RuntimeError, match="Neither OPENAI_API_KEY"): + utils.create_async_openai_client() + + def test_openai_key_returns_async_openai( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + client = utils.create_async_openai_client() + assert isinstance(client, AsyncOpenAI) + + def test_azure_key_returns_async_azure_openai( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "azure-key") + monkeypatch.setenv( + "AZURE_OPENAI_ENDPOINT", + "https://myhost.openai.azure.com/openai/deployments/gpt-4o?api-version=2025-01-01-preview", ) - result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT") - assert result == "gpt-4o-custom" + client = utils.create_async_openai_client() + assert isinstance(client, AsyncAzureOpenAI) + - def test_falls_back_to_model_name(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Falls back to the provided model_name when no deployment in the endpoint.""" +class TestMakeAgent: + def test_no_keys_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + with pytest.raises(RuntimeError, match="Neither OPENAI_API_KEY"): + utils.make_agent(str) + + +class TestResolveAzureModelName: + def test_returns_model_name_when_no_deployment( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: monkeypatch.setenv( - 
"TEST_ENDPOINT", "https://myhost.openai.azure.com?api-version=2024-06-01" + "AZURE_OPENAI_ENDPOINT", + "https://myhost.openai.azure.com/openai?api-version=2024-06-01", ) - result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT") + result = utils.resolve_azure_model_name("gpt-4o") assert result == "gpt-4o" - def test_uses_default_endpoint_envvar( + def test_returns_deployment_when_present( self, monkeypatch: pytest.MonkeyPatch ) -> None: - """Uses AZURE_OPENAI_ENDPOINT by default.""" monkeypatch.setenv( "AZURE_OPENAI_ENDPOINT", - "https://myhost.openai.azure.com/openai/deployments/my-deploy?api-version=2024-06-01", + "https://myhost.openai.azure.com/openai/deployments/my-deploy/chat?api-version=2024-06-01", ) result = utils.resolve_azure_model_name("gpt-4o") assert result == "my-deploy" From 91fd4fb9042fd36159d08722b02380c566dac78e Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Thu, 23 Apr 2026 14:14:04 +0200 Subject: [PATCH 2/8] enhance testcoverage - added new test modules for modules with missing tests cases --- tests/test_convutils.py | 60 +++ tests/test_email_message.py | 223 +++++++++++ tests/test_messageutils.py | 85 +++++ tests/test_searchlang_compile.py | 636 +++++++++++++++++++++++++++++++ 4 files changed, 1004 insertions(+) create mode 100644 tests/test_convutils.py create mode 100644 tests/test_email_message.py create mode 100644 tests/test_messageutils.py create mode 100644 tests/test_searchlang_compile.py diff --git a/tests/test_convutils.py b/tests/test_convutils.py new file mode 100644 index 00000000..b9ac654c --- /dev/null +++ b/tests/test_convutils.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import pytest + +from typeagent.knowpro.convutils import ( + get_time_range_for_conversation, + get_time_range_prompt_section_for_conversation, +) + +from conftest import FakeConversation, FakeMessage + + +class TestGetTimeRangeForConversation: + @pytest.mark.asyncio + async def test_empty_conversation_returns_none(self) -> None: + conv = FakeConversation(messages=[]) + result = await get_time_range_for_conversation(conv) + assert result is None + + @pytest.mark.asyncio + async def test_message_without_timestamp_returns_none(self) -> None: + msg = FakeMessage("hello") # no message_ordinal → timestamp=None + conv = FakeConversation(messages=[msg]) + result = await get_time_range_for_conversation(conv) + assert result is None + + @pytest.mark.asyncio + async def test_single_message_with_timestamp(self) -> None: + msg = FakeMessage("hello", message_ordinal=0) + conv = FakeConversation(messages=[msg]) + result = await get_time_range_for_conversation(conv) + assert result is not None + assert result.start.isoformat().startswith("2020-01-01T00") + + @pytest.mark.asyncio + async def test_multiple_messages_range_start_end(self) -> None: + msgs = [FakeMessage(f"msg{i}", message_ordinal=i) for i in range(3)] + conv = FakeConversation(messages=msgs) + result = await get_time_range_for_conversation(conv) + assert result is not None + assert result.start < result.end # type: ignore[operator] + + +class TestGetTimeRangePromptSection: + @pytest.mark.asyncio + async def test_no_timestamps_returns_none(self) -> None: + conv = FakeConversation(messages=[FakeMessage("hello")]) + result = await get_time_range_prompt_section_for_conversation(conv) + assert result is None + + @pytest.mark.asyncio + async def test_with_timestamps_returns_prompt_section(self) -> None: + msgs = [FakeMessage(f"msg{i}", message_ordinal=i) for i in range(2)] + conv = FakeConversation(messages=msgs) + result = await get_time_range_prompt_section_for_conversation(conv) + assert result is not None + assert 
result["role"] == "system"
+        assert "CONVERSATION TIME RANGE" in result["content"]
+        assert "2020-01-01" in result["content"]
diff --git a/tests/test_email_message.py b/tests/test_email_message.py
new file mode 100644
index 00000000..17930486
--- /dev/null
+++ b/tests/test_email_message.py
@@ -0,0 +1,223 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from typeagent.emails.email_message import EmailMessage, EmailMessageMeta
+
+
+def make_meta(
+    sender: str = "Alice <alice@example.com>",
+    recipients: list[str] | None = None,
+    cc: list[str] | None = None,
+    bcc: list[str] | None = None,
+    subject: str | None = None,
+) -> EmailMessageMeta:
+    return EmailMessageMeta(
+        sender=sender,
+        recipients=recipients or [],
+        cc=cc or [],
+        bcc=bcc or [],
+        subject=subject,
+    )
+
+
+class TestEmailMessageMetaProperties:
+    def test_source_returns_sender(self) -> None:
+        meta = make_meta(sender="bob@example.com")
+        assert meta.source == "bob@example.com"
+
+    def test_dest_returns_recipients(self) -> None:
+        meta = make_meta(recipients=["a@b.com", "c@d.com"])
+        assert meta.dest == ["a@b.com", "c@d.com"]
+
+    def test_dest_empty_list(self) -> None:
+        meta = make_meta(recipients=[])
+        assert meta.dest == []
+
+
+class TestEmailAddressToEntities:
+    def test_plain_address_no_display_name(self) -> None:
+        meta = make_meta()
+        entities = meta._email_address_to_entities("bob@example.com")
+        names = [e.name for e in entities]
+        assert "bob@example.com" in names
+        assert len(entities) == 1
+
+    def test_address_with_display_name(self) -> None:
+        meta = make_meta()
+        entities = meta._email_address_to_entities("Alice <alice@example.com>")
+        names = [e.name for e in entities]
+        assert "Alice" in names
+        assert "alice@example.com" in names
+        assert len(entities) == 2
+
+    def test_display_name_entity_has_email_facet(self) -> None:
+        meta = make_meta()
+        entities = meta._email_address_to_entities("Alice <alice@example.com>")
+        person_entity = next(e for e in entities if e.name == "Alice")
+        assert person_entity.facets
is not None
+        assert len(person_entity.facets) == 1
+        assert person_entity.facets[0].name == "email_address"
+        assert person_entity.facets[0].value == "alice@example.com"
+
+    def test_display_name_only_no_address(self) -> None:
+        # parseaddr("Alice") returns ("", "Alice") — treated as address only
+        meta = make_meta()
+        entities = meta._email_address_to_entities("Alice")
+        # No display name, just the address "Alice"
+        assert len(entities) == 1
+        assert entities[0].name == "Alice"
+
+
+class TestToEntities:
+    def test_entities_include_sender(self) -> None:
+        meta = make_meta(sender="Alice <alice@example.com>")
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "Alice" in names
+        assert "alice@example.com" in names
+
+    def test_entities_include_recipient(self) -> None:
+        meta = make_meta(
+            sender="alice@example.com",
+            recipients=["Bob <bob@example.com>"],
+        )
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "Bob" in names
+        assert "bob@example.com" in names
+
+    def test_entities_include_cc(self) -> None:
+        meta = make_meta(
+            sender="a@x.com",
+            cc=["cc@example.com"],
+        )
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "cc@example.com" in names
+
+    def test_entities_include_bcc(self) -> None:
+        meta = make_meta(
+            sender="a@x.com",
+            bcc=["bcc@example.com"],
+        )
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "bcc@example.com" in names
+
+    def test_entities_always_include_email_message_entity(self) -> None:
+        meta = make_meta()
+        entities = meta.to_entities()
+        msg_entity = next((e for e in entities if e.name == "email"), None)
+        assert msg_entity is not None
+        assert "message" in msg_entity.type
+
+
+class TestToTopics:
+    def test_no_subject_returns_empty(self) -> None:
+        meta = make_meta(subject=None)
+        assert meta.to_topics() == []
+
+    def test_subject_returned_as_topic(self) -> None:
+        meta = make_meta(subject="Hello World")
+        topics = meta.to_topics()
+        assert topics == ["Hello
World"]
+
+
+class TestToActions:
+    def test_no_recipients_returns_empty(self) -> None:
+        meta = make_meta(sender="alice@example.com", recipients=[])
+        assert meta.to_actions() == []
+
+    def test_sent_and_received_actions_created(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["Bob <bob@example.com>"],
+        )
+        actions = meta.to_actions()
+        verbs = [a.verbs[0] for a in actions]
+        assert "sent" in verbs
+        assert "received" in verbs
+
+    def test_multiple_recipients_produce_actions(self) -> None:
+        meta = make_meta(
+            sender="alice@example.com",
+            recipients=["bob@example.com", "carol@example.com"],
+        )
+        actions = meta.to_actions()
+        assert len(actions) > 0
+
+    def test_action_subject_is_sender(self) -> None:
+        meta = make_meta(
+            sender="alice@example.com",
+            recipients=["bob@example.com"],
+        )
+        actions = meta.to_actions()
+        sent_actions = [a for a in actions if "sent" in a.verbs]
+        assert all(a.subject_entity_name == "alice@example.com" for a in sent_actions)
+
+
+class TestGetKnowledge:
+    def test_get_knowledge_returns_response(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["Bob <bob@example.com>"],
+            subject="Test Subject",
+        )
+        result = meta.get_knowledge()
+        assert result is not None
+        assert len(result.entities) > 0
+        assert len(result.topics) > 0
+        assert len(result.actions) > 0
+
+
+class TestEmailMessage:
+    def test_basic_construction(self) -> None:
+        meta = make_meta(sender="alice@example.com")
+        msg = EmailMessage(
+            text_chunks=["Hello world"],
+            metadata=meta,
+        )
+        assert msg.text_chunks == ["Hello world"]
+        assert msg.metadata is meta
+
+    def test_get_knowledge_delegates_to_metadata(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["bob@example.com"],
+            subject="Hi",
+        )
+        msg = EmailMessage(text_chunks=["body"], metadata=meta)
+        result = msg.get_knowledge()
+        assert result is not None
+
+    def test_add_timestamp(self) -> None:
+        meta = make_meta()
+        msg = EmailMessage(text_chunks=["body"], metadata=meta)
+        
msg.add_timestamp("2025-01-01T00:00:00")
+        assert msg.timestamp == "2025-01-01T00:00:00"
+
+    def test_add_content_empty_chunks(self) -> None:
+        meta = make_meta()
+        msg = EmailMessage(text_chunks=[], metadata=meta)
+        msg.add_content("new content")
+        assert msg.text_chunks == ["new content"]
+
+    def test_add_content_existing_chunk(self) -> None:
+        meta = make_meta()
+        msg = EmailMessage(text_chunks=["existing"], metadata=meta)
+        msg.add_content(" more")
+        assert msg.text_chunks[0] == "existing more"
+
+    def test_serialize_roundtrip(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["bob@example.com"],
+            subject="Hi",
+        )
+        msg = EmailMessage(text_chunks=["Hello"], metadata=meta, tags=["work"])
+        data = msg.serialize()
+        assert isinstance(data, dict)
+        restored = EmailMessage.deserialize(data)
+        assert restored.text_chunks == msg.text_chunks
+        assert restored.metadata.sender == msg.metadata.sender
+        assert restored.tags == msg.tags
diff --git a/tests/test_messageutils.py b/tests/test_messageutils.py
new file mode 100644
index 00000000..37b10d70
--- /dev/null
+++ b/tests/test_messageutils.py
@@ -0,0 +1,85 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+ +import pytest + +from typeagent.knowpro.interfaces import TextLocation, TextRange +from typeagent.knowpro.messageutils import ( + get_message_chunk_batch, + text_range_from_message_chunk, +) +from typeagent.storage.memory.collections import MemoryMessageCollection + +from conftest import FakeMessage + + +class TestTextRangeFromMessageChunk: + def test_default_chunk_ordinal(self) -> None: + tr = text_range_from_message_chunk(message_ordinal=3) + assert tr.start == TextLocation(3, 0) + assert tr.end is None + + def test_explicit_chunk_ordinal(self) -> None: + tr = text_range_from_message_chunk(message_ordinal=5, chunk_ordinal=2) + assert tr.start == TextLocation(5, 2) + assert tr.end is None + + def test_returns_text_range(self) -> None: + tr = text_range_from_message_chunk(0) + assert isinstance(tr, TextRange) + + +class TestGetMessageChunkBatch: + @pytest.mark.asyncio + async def test_empty_collection(self) -> None: + messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection() + batches = await get_message_chunk_batch(messages, 0, 10) + assert batches == [] + + @pytest.mark.asyncio + async def test_single_message_single_chunk(self) -> None: + messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( + [FakeMessage("hello")] + ) + batches = await get_message_chunk_batch(messages, 0, 10) + assert len(batches) == 1 + assert len(batches[0]) == 1 + assert batches[0][0] == TextLocation(0, 0) + + @pytest.mark.asyncio + async def test_message_with_multiple_chunks(self) -> None: + messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( + [FakeMessage(["chunk0", "chunk1", "chunk2"])] + ) + batches = await get_message_chunk_batch(messages, 0, 10) + assert len(batches) == 1 + locs = batches[0] + assert locs == [TextLocation(0, 0), TextLocation(0, 1), TextLocation(0, 2)] + + @pytest.mark.asyncio + async def test_batch_size_splits_across_messages(self) -> None: + messages: MemoryMessageCollection[FakeMessage] = 
MemoryMessageCollection( + [FakeMessage("a"), FakeMessage("b"), FakeMessage("c")] + ) + batches = await get_message_chunk_batch(messages, 0, batch_size=2) + assert len(batches) == 2 + assert len(batches[0]) == 2 + assert len(batches[1]) == 1 + + @pytest.mark.asyncio + async def test_exact_batch_size(self) -> None: + messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( + [FakeMessage("a"), FakeMessage("b")] + ) + batches = await get_message_chunk_batch(messages, 0, batch_size=2) + assert len(batches) == 1 + assert len(batches[0]) == 2 + + @pytest.mark.asyncio + async def test_start_offset_skips_earlier_messages(self) -> None: + messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( + [FakeMessage("skip"), FakeMessage("include")] + ) + batches = await get_message_chunk_batch(messages, 1, batch_size=10) + assert len(batches) == 1 + assert batches[0][0] == TextLocation(1, 0) diff --git a/tests/test_searchlang_compile.py b/tests/test_searchlang_compile.py new file mode 100644 index 00000000..c907393d --- /dev/null +++ b/tests/test_searchlang_compile.py @@ -0,0 +1,636 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Unit tests for searchlang.py — compile_search_query, SearchQueryCompiler, +and related helper functions that don't require a live LLM.""" + +import datetime +from typing import Literal + +from typeagent.knowpro.date_time_schema import DateTime, DateTimeRange, DateVal, TimeVal +from typeagent.knowpro.interfaces import SearchTerm, SearchTermGroup +from typeagent.knowpro.search_query_schema import ( + ActionTerm, + EntityTerm, + FacetTerm, + SearchExpr, + SearchFilter, + SearchQuery, + VerbsTerm, +) +from typeagent.knowpro.searchlang import ( + _compile_fallback_query, + compile_search_filter, + compile_search_query, + date_range_from_datetime_range, + datetime_from_date_time, + is_entity_term_list, + LanguageQueryCompileOptions, + LanguageSearchFilter, + optimize_or_max, + SearchQueryCompiler, +) + +from conftest import FakeConversation + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_entity( + name: str, + types: list[str] | None = None, + facets: list[FacetTerm] | None = None, + is_pronoun: bool = False, +) -> EntityTerm: + return EntityTerm(name=name, is_name_pronoun=is_pronoun, type=types, facets=facets) + + +def make_action( + actor: list[EntityTerm] | Literal["*"] = "*", + verbs: list[str] | None = None, + targets: list[EntityTerm] | None = None, + additional: list[EntityTerm] | None = None, + is_informational: bool = False, +) -> ActionTerm: + return ActionTerm( + actor_entities=actor, + is_informational=is_informational, + action_verbs=VerbsTerm(words=verbs) if verbs else None, + target_entities=targets, + additional_entities=additional, + ) + + +def make_filter( + entities: list[EntityTerm] | None = None, + action: ActionTerm | None = None, + search_terms: list[str] | None = None, + time_range: DateTimeRange | None = None, +) -> SearchFilter: + return SearchFilter( + entity_search_terms=entities, + 
action_search_term=action, + search_terms=search_terms, + time_range=time_range, + ) + + +def make_query(filters: list[SearchFilter]) -> SearchQuery: + expr = SearchExpr( + rewritten_query="test query", + filters=filters, + ) + return SearchQuery(search_expressions=[expr]) + + +def make_compiler( + options: LanguageQueryCompileOptions | None = None, + lang_filter: LanguageSearchFilter | None = None, +) -> SearchQueryCompiler: + conv = FakeConversation() + return SearchQueryCompiler(conv, options, lang_filter) + + +# --------------------------------------------------------------------------- +# is_entity_term_list +# --------------------------------------------------------------------------- + + +class TestIsEntityTermList: + def test_list_returns_true(self) -> None: + terms = [make_entity("Alice")] + assert is_entity_term_list(terms) is True + + def test_empty_list_returns_true(self) -> None: + assert is_entity_term_list([]) is True + + def test_star_returns_false(self) -> None: + assert is_entity_term_list("*") is False + + def test_none_returns_false(self) -> None: + assert is_entity_term_list(None) is False + + +# --------------------------------------------------------------------------- +# optimize_or_max +# --------------------------------------------------------------------------- + + +class TestOptimizeOrMax: + def test_single_term_unwrapped(self) -> None: + inner = SearchTermGroup(boolean_op="and", terms=[]) + group = SearchTermGroup(boolean_op="or_max", terms=[inner]) + result = optimize_or_max(group) + assert result is inner + + def test_multiple_terms_kept_as_group(self) -> None: + inner1 = SearchTermGroup(boolean_op="and", terms=[]) + inner2 = SearchTermGroup(boolean_op="and", terms=[]) + group = SearchTermGroup(boolean_op="or_max", terms=[inner1, inner2]) + result = optimize_or_max(group) + assert result is group + + +# --------------------------------------------------------------------------- +# date_range_from_datetime_range / 
datetime_from_date_time +# --------------------------------------------------------------------------- + + +class TestDatetimeFromDateTime: + def test_date_only_zeros_time(self) -> None: + dt = datetime_from_date_time(DateTime(date=DateVal(day=15, month=6, year=2024))) + assert dt.year == 2024 + assert dt.month == 6 + assert dt.day == 15 + assert dt.hour == 0 + assert dt.minute == 0 + assert dt.second == 0 + assert dt.tzinfo == datetime.timezone.utc + + def test_with_time(self) -> None: + dt = datetime_from_date_time( + DateTime( + date=DateVal(day=1, month=1, year=2020), + time=TimeVal(hour=14, minute=30, seconds=45), + ) + ) + assert dt.hour == 14 + assert dt.minute == 30 + assert dt.second == 45 + + +class TestDateRangeFromDatetimeRange: + def test_start_only(self) -> None: + dtr = DateTimeRange( + start_date=DateTime(date=DateVal(day=1, month=1, year=2023)) + ) + dr = date_range_from_datetime_range(dtr) + assert dr.start.year == 2023 + assert dr.end is None + + def test_start_and_stop(self) -> None: + dtr = DateTimeRange( + start_date=DateTime(date=DateVal(day=1, month=1, year=2023)), + stop_date=DateTime(date=DateVal(day=31, month=12, year=2023)), + ) + dr = date_range_from_datetime_range(dtr) + assert dr.start.year == 2023 + assert dr.end is not None + assert dr.end.year == 2023 + assert dr.end.month == 12 + assert dr.end.day == 31 + + +# --------------------------------------------------------------------------- +# compile_search_query (standalone function) +# --------------------------------------------------------------------------- + + +class TestCompileSearchQuery: + def test_empty_search_expressions(self) -> None: + conv = FakeConversation() + query = SearchQuery(search_expressions=[]) + result = compile_search_query(conv, query) + assert result == [] + + def test_single_search_terms_filter(self) -> None: + conv = FakeConversation() + query = make_query([make_filter(search_terms=["robots", "AI"])]) + result = compile_search_query(conv, query) + assert 
len(result) == 1 + expr = result[0] + assert len(expr.select_expressions) == 1 + terms_in_group = expr.select_expressions[0].search_term_group.terms + assert any( + isinstance(t, SearchTerm) and t.term.text == "robots" for t in terms_in_group + ) + + def test_entity_filter_produces_expr(self) -> None: + conv = FakeConversation() + query = make_query([make_filter(entities=[make_entity("Alice", ["person"])])]) + result = compile_search_query(conv, query) + assert len(result) == 1 + + def test_multiple_filters_produce_multiple_select_exprs(self) -> None: + conv = FakeConversation() + filter1 = make_filter(search_terms=["alpha"]) + filter2 = make_filter(search_terms=["beta"]) + expr = SearchExpr(rewritten_query="test", filters=[filter1, filter2]) + query = SearchQuery(search_expressions=[expr]) + result = compile_search_query(conv, query) + assert len(result) == 1 + assert len(result[0].select_expressions) == 2 + + def test_raw_query_preserved(self) -> None: + conv = FakeConversation() + query = make_query([make_filter(search_terms=["foo"])]) + query.search_expressions[0].rewritten_query = "my rewritten query" + result = compile_search_query(conv, query) + assert result[0].raw_query == "my rewritten query" + + +# --------------------------------------------------------------------------- +# compile_search_filter (standalone function) +# --------------------------------------------------------------------------- + + +class TestCompileSearchFilter: + def test_entity_filter(self) -> None: + conv = FakeConversation() + f = make_filter(entities=[make_entity("Bob")]) + result = compile_search_filter(conv, f) + assert result.search_term_group is not None + + def test_search_terms_filter(self) -> None: + conv = FakeConversation() + f = make_filter(search_terms=["climate", "change"]) + result = compile_search_filter(conv, f) + terms = result.search_term_group.terms + assert len(terms) == 2 + + def test_empty_filter_uses_topic_wildcard(self) -> None: + """A filter with no 
entity, action, or search_terms should produce a topic:* term.""" + conv = FakeConversation() + f = SearchFilter() + result = compile_search_filter(conv, f) + # Should produce a single topic:* property search term + terms = result.search_term_group.terms + assert len(terms) == 1 + + def test_time_range_produces_when(self) -> None: + conv = FakeConversation() + dtr = DateTimeRange( + start_date=DateTime(date=DateVal(day=1, month=1, year=2024)) + ) + f = make_filter(search_terms=["foo"], time_range=dtr) + result = compile_search_filter(conv, f) + assert result.when is not None + assert result.when.date_range is not None + + def test_no_time_range_when_is_none(self) -> None: + conv = FakeConversation() + f = make_filter(search_terms=["foo"]) + result = compile_search_filter(conv, f) + assert result.when is None + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — compile_term_group and related +# --------------------------------------------------------------------------- + + +class TestSearchQueryCompilerTermGroup: + def test_search_terms_added(self) -> None: + compiler = make_compiler() + f = make_filter(search_terms=["hello", "world"]) + group = compiler.compile_term_group(f) + texts = [t.term.text for t in group.terms if isinstance(t, SearchTerm)] + assert "hello" in texts + assert "world" in texts + + def test_entity_name_added_as_property_term(self) -> None: + compiler = make_compiler() + f = make_filter(entities=[make_entity("Ada")]) + group = compiler.compile_term_group(f) + # Should have at least one term + assert len(group.terms) > 0 + + def test_empty_entity_name_ignored(self) -> None: + compiler = make_compiler() + f = make_filter(entities=[make_entity("")]) + group = compiler.compile_term_group(f) + # Empty string is not searchable; fallback to topic:* for empty term group + # (there are topic terms added for entity_terms in compile_entity_terms) + # We just check no crash and group is returned + 
assert group is not None + + def test_star_entity_name_ignored(self) -> None: + compiler = make_compiler() + f = make_filter(entities=[make_entity("*")]) + group = compiler.compile_term_group(f) + assert group is not None + + def test_noise_term_ignored(self) -> None: + compiler = make_compiler() + f = make_filter(search_terms=["thing", "object", "hello"]) + group = compiler.compile_term_group(f) + texts = [t.term.text for t in group.terms if isinstance(t, SearchTerm)] + # noise terms filtered from property groups but not from search_terms path + # search_terms path does NOT call add_property_term_to_group + assert "hello" in texts + + def test_custom_term_filter_excludes_property_terms(self) -> None: + # term_filter applies to add_property_term_to_group, not compile_search_terms. + options = LanguageQueryCompileOptions(term_filter=lambda t: t != "excluded") + compiler = make_compiler(options=options) + group = SearchTermGroup(boolean_op="or", terms=[]) + compiler.add_property_term_to_group("name", "excluded", group) + compiler.add_property_term_to_group("name", "included", group) + assert len(group.terms) == 1 + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — entity terms with facets +# --------------------------------------------------------------------------- + + +class TestEntityTermsWithFacets: + def test_entity_with_type(self) -> None: + compiler = make_compiler() + entity = make_entity("Alice", types=["person"]) + f = make_filter(entities=[entity]) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_entity_with_facet_name_and_value(self) -> None: + compiler = make_compiler() + facet = FacetTerm(facet_name="profession", facet_value="writer") + entity = make_entity("Bob", facets=[facet]) + f = make_filter(entities=[entity]) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_entity_with_wildcard_facet_value(self) -> None: + compiler = 
make_compiler() + facet = FacetTerm(facet_name="profession", facet_value="*") + entity = make_entity("Bob", facets=[facet]) + f = make_filter(entities=[entity]) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_entity_with_wildcard_facet_name(self) -> None: + compiler = make_compiler() + facet = FacetTerm(facet_name="*", facet_value="writer") + entity = make_entity("Bob", facets=[facet]) + f = make_filter(entities=[entity]) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_entity_with_both_wildcards_no_facet_term(self) -> None: + compiler = make_compiler() + facet = FacetTerm(facet_name="*", facet_value="*") + entity = make_entity("Bob", facets=[facet]) + f = make_filter(entities=[entity]) + group = compiler.compile_term_group(f) + # Both wildcards => no facet term added; entity name term still present + assert len(group.terms) >= 0 # Just no crash + + def test_pronoun_entity_skipped(self) -> None: + compiler = make_compiler() + pronoun = make_entity("it", is_pronoun=True) + normal = make_entity("Alice") + f = make_filter(entities=[pronoun, normal]) + group = compiler.compile_term_group(f) + # Only Alice's term should be added + assert len(group.terms) > 0 + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — action terms +# --------------------------------------------------------------------------- + + +class TestActionTerms: + def test_action_with_verbs_adds_verb_terms(self) -> None: + compiler = make_compiler() + actor = make_entity("Alice") + action = make_action(actor=[actor], verbs=["sent", "emailed"]) + f = make_filter(action=action) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_action_with_target_entities(self) -> None: + compiler = make_compiler() + actor = make_entity("Alice") + target = make_entity("Bob") + action = make_action(actor=[actor], verbs=["sent"], targets=[target]) + f = 
make_filter(action=action) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_action_with_additional_entities(self) -> None: + compiler = make_compiler() + actor = make_entity("Alice") + extra = make_entity("Charlie") + action = make_action(actor=[actor], verbs=["spoke"], additional=[extra]) + f = make_filter(action=action) + group = compiler.compile_term_group(f) + assert len(group.terms) > 0 + + def test_action_star_actor_no_scope(self) -> None: + """When actor_entities is '*', scope is not applied.""" + action = make_action(actor="*", verbs=["played"]) + f = make_filter(action=action) + result = compile_search_filter(FakeConversation(), f) + # should have no scope (when is None or when.scope_defining_terms is empty) + when = result.when + assert when is None or ( + when.scope_defining_terms is None + or len(when.scope_defining_terms.terms) == 0 + ) + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — compile_when with scope +# --------------------------------------------------------------------------- + + +class TestCompileWhen: + def test_no_action_no_when(self) -> None: + compiler = make_compiler() + f = make_filter(search_terms=["foo"]) + when = compiler.compile_when(f) + assert when is None + + def test_time_range_produces_date_range(self) -> None: + compiler = make_compiler() + dtr = DateTimeRange( + start_date=DateTime(date=DateVal(day=1, month=3, year=2025)), + stop_date=DateTime(date=DateVal(day=31, month=3, year=2025)), + ) + f = make_filter(search_terms=["foo"], time_range=dtr) + when = compiler.compile_when(f) + assert when is not None + assert when.date_range is not None + assert when.date_range.start.month == 3 + + def test_informational_action_no_scope(self) -> None: + compiler = make_compiler() + actor = make_entity("Alice") + action = make_action(actor=[actor], verbs=["spoke"], is_informational=True) + f = make_filter(action=action) + when = 
compiler.compile_when(f) + # is_informational = True → should_add_scope returns False → no scope in when + assert when is None or ( + when.scope_defining_terms is None + or len(when.scope_defining_terms.terms) == 0 + ) + + def test_actor_entities_list_adds_scope(self) -> None: + compiler = make_compiler() + actor = make_entity("Alice") + action = make_action(actor=[actor], verbs=["sent"]) + f = make_filter(action=action) + when = compiler.compile_when(f) + assert when is not None + assert when.scope_defining_terms is not None + assert len(when.scope_defining_terms.terms) > 0 + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — compile_search_terms +# --------------------------------------------------------------------------- + + +class TestCompileSearchTerms: + def test_returns_search_term_group(self) -> None: + compiler = make_compiler() + group = compiler.compile_search_terms(["alpha", "beta"]) + texts = [t.term.text for t in group.terms if isinstance(t, SearchTerm)] + assert "alpha" in texts + assert "beta" in texts + + def test_appends_to_existing_group(self) -> None: + compiler = make_compiler() + existing = SearchTermGroup(boolean_op="or", terms=[]) + compiler.compile_search_terms(["gamma"], existing) + texts = [t.term.text for t in existing.terms if isinstance(t, SearchTerm)] + assert "gamma" in texts + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — is_searchable_string / is_noise_term +# --------------------------------------------------------------------------- + + +class TestIsSearchableString: + def test_normal_string_is_searchable(self) -> None: + compiler = make_compiler() + assert compiler.is_searchable_string("hello") is True + + def test_empty_string_not_searchable(self) -> None: + compiler = make_compiler() + assert compiler.is_searchable_string("") is False + + def test_star_not_searchable(self) -> None: + compiler = make_compiler() + 
assert compiler.is_searchable_string("*") is False + + def test_term_filter_respected(self) -> None: + options = LanguageQueryCompileOptions(term_filter=lambda t: t != "skip") + compiler = make_compiler(options=options) + assert compiler.is_searchable_string("skip") is False + assert compiler.is_searchable_string("keep") is True + + +class TestIsNoiseTerm: + def test_noise_words(self) -> None: + compiler = make_compiler() + for word in ("thing", "object", "concept", "idea", "entity"): + assert compiler.is_noise_term(word) is True + + def test_non_noise_word(self) -> None: + compiler = make_compiler() + assert compiler.is_noise_term("robot") is False + + def test_case_insensitive(self) -> None: + compiler = make_compiler() + assert compiler.is_noise_term("THING") is True + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — deduplication +# --------------------------------------------------------------------------- + + +class TestDeduplication: + def test_duplicate_property_term_not_added_twice(self) -> None: + compiler = make_compiler() + group = SearchTermGroup(boolean_op="or", terms=[]) + compiler.add_property_term_to_group("name", "Alice", group) + compiler.add_property_term_to_group("name", "Alice", group) + assert len(group.terms) == 1 + + def test_different_property_names_both_added(self) -> None: + compiler = make_compiler() + group = SearchTermGroup(boolean_op="or", terms=[]) + compiler.add_property_term_to_group("name", "Alice", group) + compiler.add_property_term_to_group("topic", "Alice", group) + assert len(group.terms) == 2 + + def test_dedupe_disabled_allows_duplicates(self) -> None: + compiler = make_compiler() + compiler.dedupe = False + group = SearchTermGroup(boolean_op="or", terms=[]) + compiler.add_property_term_to_group("name", "Alice", group) + compiler.add_property_term_to_group("name", "Alice", group) + assert len(group.terms) == 2 + + +# 
--------------------------------------------------------------------------- +# _compile_fallback_query +# --------------------------------------------------------------------------- + + +class TestCompileFallbackQuery: + def test_exact_scope_no_fallback(self) -> None: + conv = FakeConversation() + options = LanguageQueryCompileOptions(exact_scope=True, verb_scope=True) + query = make_query([make_filter(search_terms=["foo"])]) + result = _compile_fallback_query(conv, query, options) + assert result is None + + def test_no_verb_scope_no_fallback(self) -> None: + conv = FakeConversation() + options = LanguageQueryCompileOptions(exact_scope=False, verb_scope=False) + query = make_query([make_filter(search_terms=["foo"])]) + result = _compile_fallback_query(conv, query, options) + assert result is None + + def test_verb_scope_and_not_exact_produces_fallback(self) -> None: + conv = FakeConversation() + options = LanguageQueryCompileOptions(exact_scope=False, verb_scope=True) + query = make_query([make_filter(search_terms=["foo"])]) + result = _compile_fallback_query(conv, query, options) + # Should return a list of SearchQueryExpr (fallback without verb matching) + assert result is not None + assert isinstance(result, list) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# SearchQueryCompiler — compile_action_term_as_search_terms (use_or_max=False) +# --------------------------------------------------------------------------- + + +class TestCompileActionTermAsSearchTerms: + def test_no_verbs_no_actor_empty_group(self) -> None: + compiler = make_compiler() + action = ActionTerm( + actor_entities="*", + is_informational=False, + ) + group = compiler.compile_action_term_as_search_terms(action, use_or_max=False) + # actor is "*" so no actor entities; no verbs; result depends on implementation + assert group is not None + + def test_use_or_max_false_merges_into_same_group(self) -> None: + compiler = make_compiler() 
+ actor = make_entity("Alice") + action = make_action(actor=[actor], verbs=["sent"]) + group = compiler.compile_action_term_as_search_terms(action, use_or_max=False) + assert len(group.terms) > 0 + + def test_empty_or_max_not_appended(self) -> None: + """With use_or_max=True but no verbs/actors, or_max wrapper should not be appended.""" + compiler = make_compiler() + action = ActionTerm( + actor_entities="*", + is_informational=False, + ) + outer = SearchTermGroup(boolean_op="or", terms=[]) + compiler.compile_action_term_as_search_terms(action, outer, use_or_max=True) + # or_max only appended if non-empty + assert len(outer.terms) == 0 From c53e96b2852f02e3616e379bfdc92a2e89dc8b23 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Thu, 23 Apr 2026 19:05:42 +0200 Subject: [PATCH 3/8] added additional tests --- tests/test_convthreads.py | 123 +++++++++++++++++++++ tests/test_memory_semrefindex.py | 180 +++++++++++++++++++++++++++++++ tests/test_search.py | 114 ++++++++++++++++++++ tests/test_searchlang_compile.py | 3 +- tests/test_serialization.py | 172 +++++++++++++++++++++++++++++ tests/test_textlocindex.py | 146 +++++++++++++++++++++++++ 6 files changed, 737 insertions(+), 1 deletion(-) create mode 100644 tests/test_convthreads.py create mode 100644 tests/test_memory_semrefindex.py create mode 100644 tests/test_search.py create mode 100644 tests/test_textlocindex.py diff --git a/tests/test_convthreads.py b/tests/test_convthreads.py new file mode 100644 index 00000000..9cbdaa1f --- /dev/null +++ b/tests/test_convthreads.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for storage/memory/convthreads.py.""" + +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings +from typeagent.knowpro.interfaces import TextLocation, TextRange, Thread +from typeagent.storage.memory.convthreads import ConversationThreads + + +@pytest.fixture +def settings() -> TextEmbeddingIndexSettings: + return TextEmbeddingIndexSettings(create_test_embedding_model()) + + +@pytest.fixture +def threads(settings: TextEmbeddingIndexSettings) -> ConversationThreads: + return ConversationThreads(settings) + + +def make_thread(description: str, start: int = 0, end: int = 1) -> Thread: + return Thread( + description=description, + ranges=[ + TextRange(start=TextLocation(start), end=TextLocation(end)), + ], + ) + + +@pytest.mark.asyncio +async def test_add_thread_appends(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("topic one")) + assert len(threads.threads) == 1 + assert threads.threads[0].description == "topic one" + + +@pytest.mark.asyncio +async def test_add_multiple_threads(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("alpha")) + await threads.add_thread(make_thread("beta")) + await threads.add_thread(make_thread("gamma")) + assert len(threads.threads) == 3 + + +@pytest.mark.asyncio +async def test_clear_resets_state(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("something")) + threads.clear() + assert len(threads.threads) == 0 + assert len(threads.vector_base) == 0 + + +@pytest.mark.asyncio +async def test_build_index_rebuilds_from_threads(threads: ConversationThreads) -> None: + # Manually add threads without building the vector index. + t1 = make_thread("python programming") + t2 = make_thread("data science") + threads.threads.append(t1) + threads.threads.append(t2) + # build_index should embed all existing threads. 
+ await threads.build_index() + assert len(threads.vector_base) == 2 + + +@pytest.mark.asyncio +async def test_serialize_roundtrip(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("episode one", 0, 5)) + await threads.add_thread(make_thread("episode two", 5, 10)) + + data = threads.serialize() + assert "threads" in data + thread_list = data["threads"] + assert thread_list is not None + assert len(thread_list) == 2 + + # Deserialize into a fresh instance. + settings = TextEmbeddingIndexSettings(create_test_embedding_model()) + fresh = ConversationThreads(settings) + fresh.deserialize(data) + assert len(fresh.threads) == 2 + assert fresh.threads[0].description == "episode one" + assert fresh.threads[1].description == "episode two" + + +@pytest.mark.asyncio +async def test_deserialize_empty_data(threads: ConversationThreads) -> None: + from typeagent.knowpro.interfaces import ConversationThreadData + + data: ConversationThreadData = {} # type: ignore[typeddict-item] + threads.deserialize(data) + assert len(threads.threads) == 0 + + +@pytest.mark.asyncio +async def test_serialize_without_embeddings(threads: ConversationThreads) -> None: + # Add a thread without going through add_thread (so no embedding yet). + threads.threads.append(make_thread("bare thread")) + data = threads.serialize() + thread_list = data["threads"] + assert thread_list is not None + assert len(thread_list) == 1 + # Embedding may be None because vector_base has no entries for this slot. + assert thread_list[0]["embedding"] is None or isinstance( + thread_list[0]["embedding"], list + ) + + +@pytest.mark.asyncio +async def test_lookup_thread_returns_matches(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("machine learning and AI")) + await threads.add_thread(make_thread("cooking recipes")) + results = await threads.lookup_thread("artificial intelligence") + # Should return at least one result (exact scoring depends on embedding model). 
+ assert isinstance(results, list) + + +@pytest.mark.asyncio +async def test_lookup_thread_empty_index(threads: ConversationThreads) -> None: + results = await threads.lookup_thread("anything") + assert results == [] diff --git a/tests/test_memory_semrefindex.py b/tests/test_memory_semrefindex.py new file mode 100644 index 00000000..5044d42e --- /dev/null +++ b/tests/test_memory_semrefindex.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for storage/memory/semrefindex.py helper functions.""" + +import pytest + +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.interfaces import Topic +from typeagent.storage.memory import MemorySemanticRefCollection +from typeagent.storage.memory.semrefindex import ( + add_action, + add_entity, + add_facet, + add_term_to_index, + add_topic, +) + +from conftest import FakeTermIndex + + +def make_semrefs() -> MemorySemanticRefCollection: + return MemorySemanticRefCollection([]) + + +def make_index() -> FakeTermIndex: + return FakeTermIndex() + + +# --------------------------------------------------------------------------- +# add_term_to_index +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_term_to_index_basic() -> None: + index = make_index() + terms_added: set[str] = set() + await add_term_to_index(index, "hello", 0, terms_added) + assert "hello" in terms_added + assert await index.size() == 1 + + +@pytest.mark.asyncio +async def test_add_term_to_index_no_terms_added_set() -> None: + index = make_index() + await add_term_to_index(index, "world", 1) + assert await index.size() == 1 + + +@pytest.mark.asyncio +async def test_add_term_empty_string_is_stored() -> None: + """The function does not filter empty terms — delegated to the index.""" + index = make_index() + await add_term_to_index(index, "", 0) + # FakeTermIndex stores empty strings too + assert await index.size() 
== 1 + + +# --------------------------------------------------------------------------- +# add_facet +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_facet_none_does_nothing() -> None: + index = make_index() + await add_facet(None, 0, index) + assert await index.size() == 0 + + +@pytest.mark.asyncio +async def test_add_facet_string_value() -> None: + index = make_index() + facet = kplib.Facet(name="colour", value="red") + await add_facet(facet, 0, index) + terms = await index.get_terms() + assert "colour" in terms + assert "red" in terms + + +@pytest.mark.asyncio +async def test_add_facet_numeric_value() -> None: + index = make_index() + facet = kplib.Facet(name="count", value=42.0) + await add_facet(facet, 0, index) + terms = await index.get_terms() + assert "count" in terms + assert "42.0" in terms + + +# --------------------------------------------------------------------------- +# add_entity +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_entity_registers_name_and_types() -> None: + semrefs = make_semrefs() + index = make_index() + entity = kplib.ConcreteEntity(name="Alice", type=["person", "employee"]) + terms_added: set[str] = set() + await add_entity(entity, semrefs, index, message_ordinal=0, terms_added=terms_added) + assert "Alice" in terms_added + assert "person" in terms_added + assert "employee" in terms_added + assert await semrefs.size() == 1 + + +@pytest.mark.asyncio +async def test_add_entity_with_facets() -> None: + semrefs = make_semrefs() + index = make_index() + entity = kplib.ConcreteEntity( + name="Book", + type=["item"], + facets=[kplib.Facet(name="genre", value="fiction")], + ) + await add_entity(entity, semrefs, index, message_ordinal=1) + terms = await index.get_terms() + assert "genre" in terms + assert "fiction" in terms + + +# 
--------------------------------------------------------------------------- +# add_topic +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_topic_registers_text() -> None: + semrefs = make_semrefs() + index = make_index() + topic = Topic(text="machine learning") + terms_added: set[str] = set() + await add_topic(topic, semrefs, index, message_ordinal=2, terms_added=terms_added) + assert "machine learning" in terms_added + assert await semrefs.size() == 1 + + +# --------------------------------------------------------------------------- +# add_action +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_action_registers_verbs() -> None: + semrefs = make_semrefs() + index = make_index() + action = kplib.Action( + verbs=["run", "execute"], + verb_tense="present", + subject_entity_name="Alice", + object_entity_name="script", + indirect_object_entity_name="none", + ) + terms_added: set[str] = set() + await add_action(action, semrefs, index, message_ordinal=0, terms_added=terms_added) + terms = set(await index.get_terms()) + assert "run execute" in terms + assert "Alice" in terms + assert "script" in terms + assert await semrefs.size() == 1 + + +@pytest.mark.asyncio +async def test_add_action_none_entities_skipped() -> None: + semrefs = make_semrefs() + index = make_index() + action = kplib.Action( + verbs=["go"], + verb_tense="present", + subject_entity_name="none", + object_entity_name="none", + indirect_object_entity_name="none", + ) + await add_action(action, semrefs, index, message_ordinal=0) + terms = await index.get_terms() + assert "none" not in terms + assert "go" in terms diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 00000000..abb94bc2 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for knowpro/search.py — SearchOptions, ConversationSearchResult.""" + +import pytest + +from typeagent.knowpro.interfaces import ( + SearchTerm, + SearchTermGroup, + Term, +) +from typeagent.knowpro.query import is_conversation_searchable +from typeagent.knowpro.search import ( + ConversationSearchResult, + search_conversation_knowledge, + SearchOptions, +) + +from conftest import FakeConversation, FakeMessage, FakeTermIndex + +# --------------------------------------------------------------------------- +# SearchOptions +# --------------------------------------------------------------------------- + + +def test_search_options_defaults() -> None: + opts = SearchOptions() + assert opts.max_knowledge_matches is None + assert opts.exact_match is False + assert opts.max_message_matches is None + assert opts.max_chars_in_budget is None + assert opts.threshold_score is None + + +def test_search_options_repr_empty() -> None: + opts = SearchOptions() + # Only non-None values appear in repr; exact_match=False is still included. 
+ r = repr(opts) + assert r.startswith("SearchOptions(") + + +def test_search_options_repr_with_fields() -> None: + opts = SearchOptions(max_knowledge_matches=5, exact_match=True) + r = repr(opts) + assert "max_knowledge_matches=5" in r + assert "exact_match=True" in r + + +# --------------------------------------------------------------------------- +# ConversationSearchResult +# --------------------------------------------------------------------------- + + +def test_conversation_search_result_basic() -> None: + from typeagent.knowpro.interfaces import ScoredMessageOrdinal + + result = ConversationSearchResult( + message_matches=[ScoredMessageOrdinal(0, 0.9)], + knowledge_matches={}, + raw_query_text="test", + ) + assert len(result.message_matches) == 1 + assert result.raw_query_text == "test" + + +def test_conversation_search_result_defaults() -> None: + result = ConversationSearchResult(message_matches=[], knowledge_matches={}) + assert result.raw_query_text is None + + +# --------------------------------------------------------------------------- +# is_conversation_searchable (from query.py, used heavily in search.py) +# --------------------------------------------------------------------------- + + +def test_is_conversation_searchable_true() -> None: + conv = FakeConversation( + messages=[FakeMessage("hello", 0)], + has_secondary_indexes=False, + ) + conv.semantic_ref_index = FakeTermIndex() + assert is_conversation_searchable(conv) is True + + +def test_is_conversation_searchable_no_index() -> None: + conv = FakeConversation(has_secondary_indexes=False) + conv.semantic_ref_index = None + assert is_conversation_searchable(conv) is False + + +def test_is_conversation_searchable_no_semrefs() -> None: + conv = FakeConversation(has_secondary_indexes=False) + conv.semantic_refs = None # type: ignore[assignment] + assert is_conversation_searchable(conv) is False + + +# --------------------------------------------------------------------------- +# 
search_conversation_knowledge returns None when not searchable +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_search_conversation_knowledge_non_searchable_returns_none() -> None: + """When the conversation has no semantic ref index, result should be None.""" + conv = FakeConversation(has_secondary_indexes=False) + conv.semantic_ref_index = None + + group = SearchTermGroup( + boolean_op="or", + terms=[SearchTerm(term=Term("hello"))], + ) + result = await search_conversation_knowledge(conv, group) + assert result is None diff --git a/tests/test_searchlang_compile.py b/tests/test_searchlang_compile.py index c907393d..797588dc 100644 --- a/tests/test_searchlang_compile.py +++ b/tests/test_searchlang_compile.py @@ -204,7 +204,8 @@ def test_single_search_terms_filter(self) -> None: assert len(expr.select_expressions) == 1 terms_in_group = expr.select_expressions[0].search_term_group.terms assert any( - isinstance(t, SearchTerm) and t.term.text == "robots" for t in terms_in_group + isinstance(t, SearchTerm) and t.term.text == "robots" + for t in terms_in_group ) def test_entity_filter_produces_expr(self) -> None: diff --git a/tests/test_serialization.py b/tests/test_serialization.py index eec08fda..5e1f1a47 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -17,6 +17,7 @@ from typeagent.knowpro.serialization import ( create_file_header, DeserializationError, + deserialize_knowledge, deserialize_object, from_conversation_file_data, serialize_embeddings, @@ -133,3 +134,174 @@ def test_deserialization_error(): """Test that DeserializationError is raised for invalid data.""" with pytest.raises(DeserializationError, match="Pydantic validation failed"): deserialize_object(Quantity, {"invalid_key": "value"}) + + +# --------------------------------------------------------------------------- +# Additional tests for broader coverage +# 
--------------------------------------------------------------------------- + + +def test_from_conversation_file_data_missing_header_raises(): + """from_conversation_file_data raises when fileHeader is absent.""" + from typeagent.knowpro.serialization import ( + ConversationBinaryData, + ConversationFileData, + ConversationJsonData, + ) + + json_data: ConversationJsonData[Any] = ConversationJsonData( + nameTag="x", messages=[], tags=[], semanticRefs=None + ) + file_data: ConversationFileData[Any] = ConversationFileData( + jsonData=json_data, + binaryData=ConversationBinaryData(embeddingsList=[]), + ) + with pytest.raises(DeserializationError, match="Missing file header"): + from_conversation_file_data(file_data) + + +def test_from_conversation_file_data_bad_version_raises(): + """from_conversation_file_data raises on unsupported version.""" + from typeagent.knowpro.serialization import ( + ConversationBinaryData, + ConversationFileData, + ConversationJsonData, + ) + + json_data: ConversationJsonData[Any] = ConversationJsonData( + nameTag="x", + messages=[], + tags=[], + semanticRefs=None, + fileHeader={"version": "99.9"}, + embeddingFileHeader={}, + ) + file_data: ConversationFileData[Any] = ConversationFileData( + jsonData=json_data, + binaryData=ConversationBinaryData(embeddingsList=[]), + ) + with pytest.raises(DeserializationError, match="Unsupported file version"): + from_conversation_file_data(file_data) + + +def test_from_conversation_file_data_missing_embedding_header_raises(): + """from_conversation_file_data raises when embeddingFileHeader is absent.""" + from typeagent.knowpro.serialization import ( + ConversationBinaryData, + ConversationFileData, + ConversationJsonData, + ) + + json_data: ConversationJsonData[Any] = ConversationJsonData( + nameTag="x", + messages=[], + tags=[], + semanticRefs=None, + fileHeader={"version": "0.1"}, + ) + file_data: ConversationFileData[Any] = ConversationFileData( + jsonData=json_data, + 
binaryData=ConversationBinaryData(embeddingsList=[]), + ) + with pytest.raises(DeserializationError, match="Missing embedding file header"): + from_conversation_file_data(file_data) + + +def test_from_conversation_file_data_missing_embeddings_list_raises(): + """from_conversation_file_data raises when embeddingsList is None.""" + from typeagent.knowpro.serialization import ( + ConversationBinaryData, + ConversationFileData, + ConversationJsonData, + ) + + json_data: ConversationJsonData[Any] = ConversationJsonData( + nameTag="x", + messages=[], + tags=[], + semanticRefs=None, + fileHeader={"version": "0.1"}, + embeddingFileHeader={}, + ) + file_data: ConversationFileData[Any] = ConversationFileData( + jsonData=json_data, + binaryData=ConversationBinaryData(embeddingsList=None), + ) + with pytest.raises(DeserializationError, match="Missing embeddings list"): + from_conversation_file_data(file_data) + + +def test_from_conversation_file_data_success_empty(): + """from_conversation_file_data succeeds with minimal valid data.""" + from typeagent.knowpro.serialization import ( + ConversationBinaryData, + ConversationFileData, + ConversationJsonData, + ) + + emb = np.zeros((0, 4), dtype=np.float32) + json_data: ConversationJsonData[Any] = ConversationJsonData( + nameTag="test", + messages=[], + tags=[], + semanticRefs=None, + fileHeader={"version": "0.1"}, + embeddingFileHeader={}, + ) + file_data: ConversationFileData[Any] = ConversationFileData( + jsonData=json_data, + binaryData=ConversationBinaryData(embeddingsList=[emb]), + ) + result = from_conversation_file_data(file_data) + assert result["nameTag"] == "test" + + +def test_is_primitive(): + """Test is_primitive classification.""" + from typeagent.knowpro.serialization import is_primitive + + for t in (int, float, bool, str, type(None)): + assert is_primitive(t), f"Expected {t} to be primitive" + assert not is_primitive(list) + assert not is_primitive(dict) + + +def test_deserialize_object_union_none(): + 
"""deserialize_object handles optional (X | None) type with None input.""" + result = deserialize_object(int | None, None) + assert result is None + + +def test_deserialize_object_list_of_int(): + """deserialize_object can deserialize a list of ints.""" + result = deserialize_object(list[int], [1, 2, 3]) + assert result == [1, 2, 3] + + +def test_deserialize_knowledge_entity(): + """deserialize_knowledge reconstructs a ConcreteEntity.""" + from typeagent.knowpro.knowledge_schema import ConcreteEntity + + obj = {"name": "Bob", "type": ["person"]} + result = deserialize_knowledge("entity", obj) + assert isinstance(result, ConcreteEntity) + assert result.name == "Bob" + + +def test_deserialize_knowledge_topic(): + """deserialize_knowledge reconstructs a Topic.""" + from typeagent.knowpro.interfaces import Topic + + obj = {"text": "AI ethics"} + result = deserialize_knowledge("topic", obj) + assert isinstance(result, Topic) + assert result.text == "AI ethics" + + +def test_deserialize_knowledge_tag(): + """deserialize_knowledge reconstructs a Tag.""" + from typeagent.knowpro.interfaces import Tag + + obj = {"text": "important"} + result = deserialize_knowledge("tag", obj) + assert isinstance(result, Tag) diff --git a/tests/test_textlocindex.py b/tests/test_textlocindex.py new file mode 100644 index 00000000..a9e6454f --- /dev/null +++ b/tests/test_textlocindex.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for knowpro/textlocindex.py (TextToTextLocationIndex).""" + +import numpy as np +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings +from typeagent.knowpro.interfaces import TextLocation, TextToTextLocationIndexData +from typeagent.knowpro.textlocindex import TextToTextLocationIndex + + +@pytest.fixture +def settings() -> TextEmbeddingIndexSettings: + return TextEmbeddingIndexSettings(create_test_embedding_model()) + + +@pytest.fixture +def index(settings: TextEmbeddingIndexSettings) -> TextToTextLocationIndex: + return TextToTextLocationIndex(settings) + + +# --------------------------------------------------------------------------- +# Empty index +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_empty_size(index: TextToTextLocationIndex) -> None: + assert await index.size() == 0 + + +@pytest.mark.asyncio +async def test_empty_is_empty(index: TextToTextLocationIndex) -> None: + assert await index.is_empty() + + +def test_get_out_of_range_returns_default(index: TextToTextLocationIndex) -> None: + assert index.get(0) is None + assert index.get(-1) is None + assert index.get(0, TextLocation(99)) == TextLocation(99) + + +# --------------------------------------------------------------------------- +# clear() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_clear_resets(index: TextToTextLocationIndex) -> None: + loc = TextLocation(message_ordinal=0) + await index.add_text_location("hello world", loc) + assert await index.size() == 1 + index.clear() + assert await index.size() == 0 + assert await index.is_empty() + + +# --------------------------------------------------------------------------- +# serialize / deserialize round-trip (no real embeddings needed) +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_serialize_empty(index: TextToTextLocationIndex) -> None: + data = index.serialize() + assert data["textLocations"] == [] + # embeddings may be None or an empty ndarray + emb = data["embeddings"] + assert emb is None or (hasattr(emb, "shape") and emb.size == 0) + + +def test_deserialize_raises_on_no_embeddings( + index: TextToTextLocationIndex, +) -> None: + data: TextToTextLocationIndexData = { + "textLocations": [{"messageOrdinal": 0, "chunkOrdinal": 0}], + "embeddings": None, + } + with pytest.raises(ValueError, match="No embeddings found"): + index.deserialize(data) + + +def test_deserialize_raises_on_length_mismatch( + index: TextToTextLocationIndex, settings: TextEmbeddingIndexSettings +) -> None: + # The test embedding model uses size 3 by default. + emb_size = 3 + fake_emb = np.zeros((3, emb_size), dtype=np.float32) + data: TextToTextLocationIndexData = { + # 2 locations but 3 embeddings → mismatch + "textLocations": [ + {"messageOrdinal": 0, "chunkOrdinal": 0}, + {"messageOrdinal": 1, "chunkOrdinal": 0}, + ], + "embeddings": fake_emb, + } + with pytest.raises(ValueError): + index.deserialize(data) + + +def test_deserialize_valid_data( + index: TextToTextLocationIndex, settings: TextEmbeddingIndexSettings +) -> None: + emb_size = 3 # default size for create_test_embedding_model() + n = 2 + fake_emb = np.zeros((n, emb_size), dtype=np.float32) + data: TextToTextLocationIndexData = { + "textLocations": [ + {"messageOrdinal": 0, "chunkOrdinal": 0}, + {"messageOrdinal": 1, "chunkOrdinal": 0}, + ], + "embeddings": fake_emb, + } + index.deserialize(data) + assert index.get(0) == TextLocation(0) + assert index.get(1) == TextLocation(1) + assert index.get(2) is None + + +# --------------------------------------------------------------------------- +# get() helper +# --------------------------------------------------------------------------- + + +def 
test_get_returns_correct_location( + index: TextToTextLocationIndex, settings: TextEmbeddingIndexSettings +) -> None: + emb_size = 3 # default size for create_test_embedding_model() + n = 3 + fake_emb = np.zeros((n, emb_size), dtype=np.float32) + data: TextToTextLocationIndexData = { + "textLocations": [ + {"messageOrdinal": 10, "chunkOrdinal": 0}, + {"messageOrdinal": 20, "chunkOrdinal": 1}, + {"messageOrdinal": 30, "chunkOrdinal": 0}, + ], + "embeddings": fake_emb, + } + index.deserialize(data) + assert index.get(0) == TextLocation(10, 0) + assert index.get(1) == TextLocation(20, 1) + assert index.get(2) == TextLocation(30, 0) + assert index.get(3) is None From df22f8d8d6c271f8a7afdc3bb6e39b6bdadb1b5b Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Thu, 23 Apr 2026 19:34:58 +0200 Subject: [PATCH 4/8] fixed import locations Co-authored-by: Copilot --- tests/test_serialization.py | 44 +++++-------------------------------- tests/test_utils.py | 2 +- 2 files changed, 7 insertions(+), 39 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 5e1f1a47..d32b9526 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,16 +10,22 @@ from typeagent.knowpro.interfaces import ( ConversationDataWithIndexes, MessageTextIndexData, + Tag, TermsToRelatedTermsIndexData, TextToTextLocationIndexData, + Topic, ) from typeagent.knowpro.knowledge_schema import ConcreteEntity, Quantity from typeagent.knowpro.serialization import ( + ConversationBinaryData, + ConversationFileData, + ConversationJsonData, create_file_header, DeserializationError, deserialize_knowledge, deserialize_object, from_conversation_file_data, + is_primitive, serialize_embeddings, serialize_object, to_conversation_file_data, @@ -143,12 +149,6 @@ def test_deserialization_error(): def test_from_conversation_file_data_missing_header_raises(): """from_conversation_file_data raises when fileHeader is absent.""" - from typeagent.knowpro.serialization import 
( - ConversationBinaryData, - ConversationFileData, - ConversationJsonData, - ) - json_data: ConversationJsonData[Any] = ConversationJsonData( nameTag="x", messages=[], tags=[], semanticRefs=None ) @@ -162,12 +162,6 @@ def test_from_conversation_file_data_missing_header_raises(): def test_from_conversation_file_data_bad_version_raises(): """from_conversation_file_data raises on unsupported version.""" - from typeagent.knowpro.serialization import ( - ConversationBinaryData, - ConversationFileData, - ConversationJsonData, - ) - json_data: ConversationJsonData[Any] = ConversationJsonData( nameTag="x", messages=[], @@ -186,12 +180,6 @@ def test_from_conversation_file_data_bad_version_raises(): def test_from_conversation_file_data_missing_embedding_header_raises(): """from_conversation_file_data raises when embeddingFileHeader is absent.""" - from typeagent.knowpro.serialization import ( - ConversationBinaryData, - ConversationFileData, - ConversationJsonData, - ) - json_data: ConversationJsonData[Any] = ConversationJsonData( nameTag="x", messages=[], @@ -209,12 +197,6 @@ def test_from_conversation_file_data_missing_embedding_header_raises(): def test_from_conversation_file_data_missing_embeddings_list_raises(): """from_conversation_file_data raises when embeddingsList is None.""" - from typeagent.knowpro.serialization import ( - ConversationBinaryData, - ConversationFileData, - ConversationJsonData, - ) - json_data: ConversationJsonData[Any] = ConversationJsonData( nameTag="x", messages=[], @@ -233,12 +215,6 @@ def test_from_conversation_file_data_missing_embeddings_list_raises(): def test_from_conversation_file_data_success_empty(): """from_conversation_file_data succeeds with minimal valid data.""" - from typeagent.knowpro.serialization import ( - ConversationBinaryData, - ConversationFileData, - ConversationJsonData, - ) - emb = np.zeros((0, 4), dtype=np.float32) json_data: ConversationJsonData[Any] = ConversationJsonData( nameTag="test", @@ -258,8 +234,6 @@ def 
test_from_conversation_file_data_success_empty(): def test_is_primitive(): """Test is_primitive classification.""" - from typeagent.knowpro.serialization import is_primitive - for t in (int, float, bool, str, type(None)): assert is_primitive(t), f"Expected {t} to be primitive" assert not is_primitive(list) @@ -280,8 +254,6 @@ def test_deserialize_object_list_of_int(): def test_deserialize_knowledge_entity(): """deserialize_knowledge reconstructs a ConcreteEntity.""" - from typeagent.knowpro.knowledge_schema import ConcreteEntity - obj = {"name": "Bob", "type": ["person"]} result = deserialize_knowledge("entity", obj) assert isinstance(result, ConcreteEntity) @@ -290,8 +262,6 @@ def test_deserialize_knowledge_entity(): def test_deserialize_knowledge_topic(): """deserialize_knowledge reconstructs a Topic.""" - from typeagent.knowpro.interfaces import Topic - obj = {"text": "AI ethics"} result = deserialize_knowledge("topic", obj) assert isinstance(result, Topic) @@ -300,8 +270,6 @@ def test_deserialize_knowledge_topic(): def test_deserialize_knowledge_tag(): """deserialize_knowledge reconstructs a Tag.""" - from typeagent.knowpro.interfaces import Tag - obj = {"text": "important"} result = deserialize_knowledge("tag", obj) assert isinstance(result, Tag) diff --git a/tests/test_utils.py b/tests/test_utils.py index 22f930ae..d6edf57e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,6 +8,7 @@ from dotenv import load_dotenv import pytest +from openai import AsyncAzureOpenAI, AsyncOpenAI from openai import AsyncAzureOpenAI, AsyncOpenAI import pydantic.dataclasses import typechat @@ -322,7 +323,6 @@ def test_no_keys_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: def test_openai_key_returns_async_openai( self, monkeypatch: pytest.MonkeyPatch ) -> None: - monkeypatch.setenv("OPENAI_API_KEY", "sk-test") monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) client = utils.create_async_openai_client() From 79ba575313146cd22e26e98cda147012bc8e08b9 Mon 
Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Thu, 23 Apr 2026 14:14:04 +0200 Subject: [PATCH 5/8] enhance testcoverage - added new test modules for modules with missing tests cases --- tests/test_searchlang_compile.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_searchlang_compile.py b/tests/test_searchlang_compile.py index 797588dc..c907393d 100644 --- a/tests/test_searchlang_compile.py +++ b/tests/test_searchlang_compile.py @@ -204,8 +204,7 @@ def test_single_search_terms_filter(self) -> None: assert len(expr.select_expressions) == 1 terms_in_group = expr.select_expressions[0].search_term_group.terms assert any( - isinstance(t, SearchTerm) and t.term.text == "robots" - for t in terms_in_group + isinstance(t, SearchTerm) and t.term.text == "robots" for t in terms_in_group ) def test_entity_filter_produces_expr(self) -> None: From 54d676d12e985d84968475e53b3ee8e92ef78900 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Thu, 23 Apr 2026 19:05:42 +0200 Subject: [PATCH 6/8] added additional tests --- tests/test_searchlang_compile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_searchlang_compile.py b/tests/test_searchlang_compile.py index c907393d..797588dc 100644 --- a/tests/test_searchlang_compile.py +++ b/tests/test_searchlang_compile.py @@ -204,7 +204,8 @@ def test_single_search_terms_filter(self) -> None: assert len(expr.select_expressions) == 1 terms_in_group = expr.select_expressions[0].search_term_group.terms assert any( - isinstance(t, SearchTerm) and t.term.text == "robots" for t in terms_in_group + isinstance(t, SearchTerm) and t.term.text == "robots" + for t in terms_in_group ) def test_entity_filter_produces_expr(self) -> None: From 58a40f86d3b3bb692af8055cb0e5e13ce5e0c789 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 24 Apr 2026 20:34:17 +0200 Subject: [PATCH 7/8] implemented changes from PR codereview with @KRRT7 Co-authored-by: Copilot --- tests/test_convthreads.py 
| 7 +++---- tests/test_memory_semrefindex.py | 31 ++++++++++++++++++++++++++----- tests/test_search.py | 3 +-- tests/test_searchlang_compile.py | 5 +++-- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/tests/test_convthreads.py b/tests/test_convthreads.py index 9cbdaa1f..e4d5e2d5 100644 --- a/tests/test_convthreads.py +++ b/tests/test_convthreads.py @@ -8,6 +8,7 @@ from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.interfaces import TextLocation, TextRange, Thread +from typeagent.knowpro.interfaces_serialization import ConversationThreadData from typeagent.storage.memory.convthreads import ConversationThreads @@ -87,8 +88,6 @@ async def test_serialize_roundtrip(threads: ConversationThreads) -> None: @pytest.mark.asyncio async def test_deserialize_empty_data(threads: ConversationThreads) -> None: - from typeagent.knowpro.interfaces import ConversationThreadData - data: ConversationThreadData = {} # type: ignore[typeddict-item] threads.deserialize(data) assert len(threads.threads) == 0 @@ -113,8 +112,8 @@ async def test_lookup_thread_returns_matches(threads: ConversationThreads) -> No await threads.add_thread(make_thread("machine learning and AI")) await threads.add_thread(make_thread("cooking recipes")) results = await threads.lookup_thread("artificial intelligence") - # Should return at least one result (exact scoring depends on embedding model). 
- assert isinstance(results, list) + assert len(results) > 0 + assert results[0].thread_ordinal == 0 # ordinal of the matching thread @pytest.mark.asyncio diff --git a/tests/test_memory_semrefindex.py b/tests/test_memory_semrefindex.py index 5044d42e..723cdd41 100644 --- a/tests/test_memory_semrefindex.py +++ b/tests/test_memory_semrefindex.py @@ -100,7 +100,14 @@ async def test_add_entity_registers_name_and_types() -> None: index = make_index() entity = kplib.ConcreteEntity(name="Alice", type=["person", "employee"]) terms_added: set[str] = set() - await add_entity(entity, semrefs, index, message_ordinal=0, terms_added=terms_added) + await add_entity( + entity, + semrefs, + index, + message_ordinal=0, + chunk_ordinal=0, + terms_added=terms_added, + ) assert "Alice" in terms_added assert "person" in terms_added assert "employee" in terms_added @@ -116,7 +123,7 @@ async def test_add_entity_with_facets() -> None: type=["item"], facets=[kplib.Facet(name="genre", value="fiction")], ) - await add_entity(entity, semrefs, index, message_ordinal=1) + await add_entity(entity, semrefs, index, message_ordinal=1, chunk_ordinal=0) terms = await index.get_terms() assert "genre" in terms assert "fiction" in terms @@ -133,7 +140,14 @@ async def test_add_topic_registers_text() -> None: index = make_index() topic = Topic(text="machine learning") terms_added: set[str] = set() - await add_topic(topic, semrefs, index, message_ordinal=2, terms_added=terms_added) + await add_topic( + topic, + semrefs, + index, + message_ordinal=2, + chunk_ordinal=0, + terms_added=terms_added, + ) assert "machine learning" in terms_added assert await semrefs.size() == 1 @@ -155,7 +169,14 @@ async def test_add_action_registers_verbs() -> None: indirect_object_entity_name="none", ) terms_added: set[str] = set() - await add_action(action, semrefs, index, message_ordinal=0, terms_added=terms_added) + await add_action( + action, + semrefs, + index, + message_ordinal=0, + chunk_ordinal=0, + 
terms_added=terms_added, + ) terms = set(await index.get_terms()) assert "run execute" in terms assert "Alice" in terms @@ -174,7 +195,7 @@ async def test_add_action_none_entities_skipped() -> None: object_entity_name="none", indirect_object_entity_name="none", ) - await add_action(action, semrefs, index, message_ordinal=0) + await add_action(action, semrefs, index, message_ordinal=0, chunk_ordinal=0) terms = await index.get_terms() assert "none" not in terms assert "go" in terms diff --git a/tests/test_search.py b/tests/test_search.py index abb94bc2..7028403d 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -10,6 +10,7 @@ SearchTermGroup, Term, ) +from typeagent.knowpro.interfaces_core import ScoredMessageOrdinal from typeagent.knowpro.query import is_conversation_searchable from typeagent.knowpro.search import ( ConversationSearchResult, @@ -53,8 +54,6 @@ def test_search_options_repr_with_fields() -> None: def test_conversation_search_result_basic() -> None: - from typeagent.knowpro.interfaces import ScoredMessageOrdinal - result = ConversationSearchResult( message_matches=[ScoredMessageOrdinal(0, 0.9)], knowledge_matches={}, diff --git a/tests/test_searchlang_compile.py b/tests/test_searchlang_compile.py index 797588dc..9b208fbb 100644 --- a/tests/test_searchlang_compile.py +++ b/tests/test_searchlang_compile.py @@ -375,8 +375,9 @@ def test_entity_with_both_wildcards_no_facet_term(self) -> None: entity = make_entity("Bob", facets=[facet]) f = make_filter(entities=[entity]) group = compiler.compile_term_group(f) - # Both wildcards => no facet term added; entity name term still present - assert len(group.terms) >= 0 # Just no crash + # Both wildcards => no facet term added, but entity name term (or_max) + # and topic term for "Bob" are still generated — 2 terms total. 
+ assert len(group.terms) == 2 def test_pronoun_entity_skipped(self) -> None: compiler = make_compiler() From a97d4d04b8098bbcd77f8972dfede5b61941b4be Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 24 Apr 2026 21:01:29 +0200 Subject: [PATCH 8/8] Remove duplicate imports in test_utils.py Removed duplicate import statements for AsyncAzureOpenAI and AsyncOpenAI. --- tests/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index d6edf57e..d105dc5b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,6 @@ from dotenv import load_dotenv import pytest -from openai import AsyncAzureOpenAI, AsyncOpenAI from openai import AsyncAzureOpenAI, AsyncOpenAI import pydantic.dataclasses import typechat