Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions tests/test_convthreads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for storage/memory/convthreads.py."""

import pytest

from typeagent.aitools.model_adapters import create_test_embedding_model
from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
from typeagent.knowpro.interfaces import TextLocation, TextRange, Thread
from typeagent.knowpro.interfaces_serialization import ConversationThreadData
from typeagent.storage.memory.convthreads import ConversationThreads


@pytest.fixture
def settings() -> TextEmbeddingIndexSettings:
    """Provide embedding index settings built on the test embedding model."""
    test_model = create_test_embedding_model()
    return TextEmbeddingIndexSettings(test_model)


@pytest.fixture
def threads(settings: TextEmbeddingIndexSettings) -> ConversationThreads:
    """Provide a ``ConversationThreads`` built from the ``settings`` fixture."""
    conversation_threads = ConversationThreads(settings)
    return conversation_threads


def make_thread(description: str, start: int = 0, end: int = 1) -> Thread:
    """Build a ``Thread`` with the given description and a single text range.

    The range runs from ``TextLocation(start)`` to ``TextLocation(end)``.
    """
    span = TextRange(start=TextLocation(start), end=TextLocation(end))
    return Thread(description=description, ranges=[span])


@pytest.mark.asyncio
async def test_add_thread_appends(threads: ConversationThreads) -> None:
    """Adding one thread leaves exactly that thread in ``threads.threads``."""
    new_thread = make_thread("topic one")
    await threads.add_thread(new_thread)

    stored = threads.threads
    assert len(stored) == 1
    assert stored[0].description == "topic one"


@pytest.mark.asyncio
async def test_add_multiple_threads(threads: ConversationThreads) -> None:
    """Adding three threads one after another results in three stored threads."""
    for description in ("alpha", "beta", "gamma"):
        await threads.add_thread(make_thread(description))

    assert len(threads.threads) == 3

Comment thread
bmerkle marked this conversation as resolved.

@pytest.mark.asyncio
async def test_clear_resets_state(threads: ConversationThreads) -> None:
    """After ``clear()``, both the thread list and the vector base are empty."""
    populated_with = make_thread("something")
    await threads.add_thread(populated_with)

    threads.clear()

    thread_count = len(threads.threads)
    vector_count = len(threads.vector_base)
    assert thread_count == 0
    assert vector_count == 0


@pytest.mark.asyncio
async def test_build_index_rebuilds_from_threads(threads: ConversationThreads) -> None:
    """``build_index()`` populates the vector base from already-present threads."""
    # Put threads directly into the list, bypassing add_thread, so the
    # vector index is not touched before build_index runs.
    for description in ("python programming", "data science"):
        threads.threads.append(make_thread(description))

    # build_index is expected to produce one vector entry per existing thread.
    await threads.build_index()

    assert len(threads.vector_base) == 2


@pytest.mark.asyncio
async def test_serialize_roundtrip(threads: ConversationThreads) -> None:
    """Serialized threads can be loaded into a new instance with descriptions intact."""
    first_episode = make_thread("episode one", 0, 5)
    second_episode = make_thread("episode two", 5, 10)
    await threads.add_thread(first_episode)
    await threads.add_thread(second_episode)

    data = threads.serialize()
    assert "threads" in data
    serialized_threads = data["threads"]
    assert serialized_threads is not None
    assert len(serialized_threads) == 2

    # Load the serialized payload into a separately constructed instance.
    fresh_settings = TextEmbeddingIndexSettings(create_test_embedding_model())
    fresh = ConversationThreads(fresh_settings)
    fresh.deserialize(data)

    restored = fresh.threads
    assert len(restored) == 2
    assert restored[0].description == "episode one"
    assert restored[1].description == "episode two"


Comment thread
bmerkle marked this conversation as resolved.
@pytest.mark.asyncio
async def test_deserialize_empty_data(threads: ConversationThreads) -> None:
    """Deserializing an empty mapping leaves the thread list empty."""
    empty_payload: ConversationThreadData = {}  # type: ignore[typeddict-item]

    threads.deserialize(empty_payload)

    thread_count = len(threads.threads)
    assert thread_count == 0


@pytest.mark.asyncio
async def test_serialize_without_embeddings(threads: ConversationThreads) -> None:
    """Serializing a thread that was never embedded still yields one entry."""
    # Bypass add_thread so that no embedding is computed for this thread.
    bare = make_thread("bare thread")
    threads.threads.append(bare)

    serialized = threads.serialize()["threads"]
    assert serialized is not None
    assert len(serialized) == 1

    # The embedding may be None because vector_base has no entry for this
    # slot; a list is accepted as well.
    embedding = serialized[0]["embedding"]
    assert embedding is None or isinstance(embedding, list)


@pytest.mark.asyncio
async def test_lookup_thread_returns_matches(threads: ConversationThreads) -> None:
    """Looking up a related phrase returns matches, with thread 0 ranked first."""
    for description in ("machine learning and AI", "cooking recipes"):
        await threads.add_thread(make_thread(description))

    matches = await threads.lookup_thread("artificial intelligence")

    assert len(matches) > 0
    best_match = matches[0]
    assert best_match.thread_ordinal == 0  # ordinal of the matching thread


@pytest.mark.asyncio
async def test_lookup_thread_empty_index(threads: ConversationThreads) -> None:
    """A lookup with no threads added returns an empty list."""
    matches = await threads.lookup_thread("anything")
    assert matches == []
60 changes: 60 additions & 0 deletions tests/test_convutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import pytest

from typeagent.knowpro.convutils import (
get_time_range_for_conversation,
get_time_range_prompt_section_for_conversation,
)

from conftest import FakeConversation, FakeMessage


class TestGetTimeRangeForConversation:
    """Tests for ``get_time_range_for_conversation``."""

    @pytest.mark.asyncio
    async def test_empty_conversation_returns_none(self) -> None:
        """A conversation with no messages yields ``None``."""
        conversation = FakeConversation(messages=[])
        time_range = await get_time_range_for_conversation(conversation)
        assert time_range is None

    @pytest.mark.asyncio
    async def test_message_without_timestamp_returns_none(self) -> None:
        """A single message lacking a timestamp yields ``None``."""
        # no message_ordinal → timestamp=None
        untimed = FakeMessage("hello")
        conversation = FakeConversation(messages=[untimed])
        time_range = await get_time_range_for_conversation(conversation)
        assert time_range is None

    @pytest.mark.asyncio
    async def test_single_message_with_timestamp(self) -> None:
        """A single message with ordinal 0 yields a range starting 2020-01-01T00."""
        timed = FakeMessage("hello", message_ordinal=0)
        conversation = FakeConversation(messages=[timed])
        time_range = await get_time_range_for_conversation(conversation)
        assert time_range is not None
        start_iso = time_range.start.isoformat()
        assert start_iso.startswith("2020-01-01T00")

    @pytest.mark.asyncio
    async def test_multiple_messages_range_start_end(self) -> None:
        """With three ordered messages the range start precedes its end."""
        messages = [
            FakeMessage(f"msg{ordinal}", message_ordinal=ordinal)
            for ordinal in range(3)
        ]
        conversation = FakeConversation(messages=messages)
        time_range = await get_time_range_for_conversation(conversation)
        assert time_range is not None
        assert time_range.start < time_range.end  # type: ignore[operator]


class TestGetTimeRangePromptSection:
    """Tests for ``get_time_range_prompt_section_for_conversation``."""

    @pytest.mark.asyncio
    async def test_no_timestamps_returns_none(self) -> None:
        """A conversation whose only message has no ordinal yields ``None``."""
        untimed = FakeMessage("hello")
        conversation = FakeConversation(messages=[untimed])
        section = await get_time_range_prompt_section_for_conversation(conversation)
        assert section is None

    @pytest.mark.asyncio
    async def test_with_timestamps_returns_prompt_section(self) -> None:
        """Messages with ordinals yield a system section naming the time range."""
        messages = [
            FakeMessage(f"msg{ordinal}", message_ordinal=ordinal)
            for ordinal in range(2)
        ]
        conversation = FakeConversation(messages=messages)
        section = await get_time_range_prompt_section_for_conversation(conversation)
        assert section is not None
        assert section["role"] == "system"
        content = section["content"]
        assert "CONVERSATION TIME RANGE" in content
        assert "2020-01-01" in content
Loading
Loading