diff --git a/.secrets.baseline b/.secrets.baseline index 95e3cb1a7..7037b19f0 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -158,7 +158,7 @@ "filename": "django_app/redbox_app/settings.py", "hashed_secret": "3a07fa5a98e5eab18d39dd37f33148d45f8091dc", "is_verified": false, - "line_number": 203, + "line_number": 205, "is_secret": false }, { @@ -166,7 +166,7 @@ "filename": "django_app/redbox_app/settings.py", "hashed_secret": "71ecb7a9026acd0b5c75ce1cd543c411eaa78a75", "is_verified": false, - "line_number": 204, + "line_number": 206, "is_secret": false }, { @@ -174,7 +174,7 @@ "filename": "django_app/redbox_app/settings.py", "hashed_secret": "d2aa34aa9083c20fc7eeb1f2e08dfd7f8f6c0022", "is_verified": false, - "line_number": 205, + "line_number": 207, "is_secret": false }, { @@ -182,7 +182,7 @@ "filename": "django_app/redbox_app/settings.py", "hashed_secret": "28c0d52cedc09b9e794d0e1033ca0b2a06def89a", "is_verified": false, - "line_number": 206, + "line_number": 208, "is_secret": false }, { @@ -190,7 +190,7 @@ "filename": "django_app/redbox_app/settings.py", "hashed_secret": "8782f26f9343343d50facf336a0befad925537d1", "is_verified": false, - "line_number": 207, + "line_number": 209, "is_secret": false }, { @@ -198,7 +198,7 @@ "filename": "django_app/redbox_app/settings.py", "hashed_secret": "1bc89245f6e26d516ddc579b00a0d59e4096a765", "is_verified": false, - "line_number": 216, + "line_number": 218, "is_secret": false } ], @@ -281,5 +281,5 @@ } ] }, - "generated_at": "2026-04-09T11:20:11Z" + "generated_at": "2026-04-10T17:06:15Z" } diff --git a/django_app/redbox_app/settings.py b/django_app/redbox_app/settings.py index d5267e92a..515369d81 100644 --- a/django_app/redbox_app/settings.py +++ b/django_app/redbox_app/settings.py @@ -10,6 +10,8 @@ import sentry_sdk from dbt_copilot_python.database import database_from_env from dbt_copilot_python.error_tracking import DatadogErrorTrackingFilter +from ddtrace import patch +from ddtrace.llmobs import LLMObs from 
django.urls import reverse_lazy from django_log_formatter_asim import ASIMFormatter from dotenv import find_dotenv, load_dotenv @@ -505,3 +507,8 @@ def filter_transactions(event, _hint): ) PRODUCT_NAME = env.str("PRODUCT_NAME", "Redbox at DBT") + +# datadog +# enable llm manual instrument +LLMObs.enable(integrations_enabled=False, api_key=env.str("DATADOG_API_KEY", "Fake")) +patch(langchain=True, langgraph=True, mcp=True, botocore=False) diff --git a/django_app/start.sh b/django_app/start.sh index f37ec2c89..afc704140 100644 --- a/django_app/start.sh +++ b/django_app/start.sh @@ -7,5 +7,5 @@ venv/bin/django-admin collectstatic --noinput venv/bin/django-admin create_admin_user echo "Starting daphne on port $PORT" -#venv/bin/daphne --websocket_timeout 86400 -b 0.0.0.0 -p $PORT redbox_app.asgi:application -venv/bin/ddtrace-run venv/bin/daphne --websocket_timeout 86400 -b 0.0.0.0 -p $PORT redbox_app.asgi:application +venv/bin/daphne --websocket_timeout 86400 -b 0.0.0.0 -p $PORT redbox_app.asgi:application +# venv/bin/ddtrace-run venv/bin/daphne --websocket_timeout 86400 -b 0.0.0.0 -p $PORT redbox_app.asgi:application diff --git a/redbox/redbox/chains/runnables.py b/redbox/redbox/chains/runnables.py index 957b393e3..db4d56bb3 100644 --- a/redbox/redbox/chains/runnables.py +++ b/redbox/redbox/chains/runnables.py @@ -2,6 +2,8 @@ import re from typing import Any, Callable, Iterable, Iterator +from ddtrace.llmobs import LLMObs +from ddtrace.trace import tracer from langchain_core.callbacks.manager import CallbackManagerForLLMRun, dispatch_custom_event from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage @@ -344,22 +346,34 @@ def _basic_chat_chain(state: RedboxState): "question": state.request.question, "chat_history": truncated_history if using_chat_history else "", } | _additional_variables - if parser: - if isinstance(parser, StrOutputParser): - prompt = ChatPromptTemplate([(system_prompt)]) - 
else: - format_instructions = parser.get_format_instructions() - prompt = ChatPromptTemplate( - [(system_prompt)], partial_variables={"format_instructions": format_instructions} - ) - if using_only_structure: - chain = prompt | llm - else: - chain = prompt | llm | parser + if parser and not isinstance(parser, StrOutputParser): + format_instructions = parser.get_format_instructions() + prompt = ChatPromptTemplate( + [(system_prompt)], partial_variables={"format_instructions": format_instructions} + ) else: prompt = ChatPromptTemplate([(system_prompt)]) - chain = prompt | llm - return chain.invoke(context) + chain = prompt | llm + + output = chain.invoke(context) + bedrock_span = tracer.current_span() + LLMObs.annotate( + span=bedrock_span, + metadata={ + "max_tokens": (llm._default_config or {}).get("max_tokens", None), + "stop_reason": (output.response_metadata or {}).get("stop_reason", None), + }, + metrics={ + "input_tokens": (output.usage_metadata or {}).get("input_tokens", None), + "output_tokens": (output.usage_metadata or {}).get("output_tokens", None), + "total_tokens": (output.usage_metadata or {}).get("total_tokens", None), + }, + tags={"func": "basic_chat_chain"}, + ) + + if parser and not using_only_structure: + output = parser.invoke(output) + return output return _basic_chat_chain diff --git a/redbox/redbox/graph/agents/workers.py b/redbox/redbox/graph/agents/workers.py index b4e448953..e24749680 100644 --- a/redbox/redbox/graph/agents/workers.py +++ b/redbox/redbox/graph/agents/workers.py @@ -149,5 +149,5 @@ def execute(self): return ( self.reading_task_info() | RunnableParallel(state=self.log_agent_activity(), result=self.core_task() | self.post_processing()) - | (lambda x: x["result"]) # Return only the result - ) + | (lambda x: x["result"]) + ) # Return only the result diff --git a/redbox/redbox/graph/nodes/tools.py b/redbox/redbox/graph/nodes/tools.py index 43a5a8277..4ee404a3c 100644 --- a/redbox/redbox/graph/nodes/tools.py +++ 
b/redbox/redbox/graph/nodes/tools.py @@ -14,13 +14,18 @@ import numpy as np import pandas as pd import requests +from ddtrace.llmobs import LLMObs +from ddtrace.trace import tracer from elasticsearch import Elasticsearch from langchain_community.utilities import WikipediaAPIWrapper from langchain_core.documents import Document from langchain_core.embeddings.embeddings import Embeddings from langchain_core.messages import ToolCall from langchain_core.tools import Tool, tool +from langchain_mcp_adapters.tools import load_mcp_tools from langgraph.prebuilt import InjectedState +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client from mohawk import Sender from opensearchpy import OpenSearch from sklearn.metrics.pairwise import cosine_similarity @@ -41,9 +46,6 @@ ) from redbox.retriever.retrievers import SchematisedTabularChunkRetriever, query_to_documents from redbox.transform import bedrock_tokeniser, merge_documents, sort_documents -from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client -from langchain_mcp_adapters.tools import load_mcp_tools log = logging.getLogger(__name__) @@ -203,9 +205,10 @@ def build_search_documents_tool( ) -> Tool: """Constructs a tool that searches the index and sets state.documents.""" - def search_repo(query, selected_files, permitted_files, ai_settings, start_time=time.time()): + def search_repo(query, selected_files, permitted_files, ai_settings): query_vector = embedding_model.embed_query(query) # Initial pass + start_time = time.time() initial_query = build_document_query( query=query, query_vector=query_vector, @@ -220,12 +223,21 @@ def search_repo(query, selected_files, permitted_files, ai_settings, start_time= "[_search_documents] Initial query using %s seconds", time.time() - start_time, ) + metrics = { + "initial_query_time": None, + "boosted_query_time": None, + "merged_sort_document_time": None, + "no_returned_documents": 0, + } + + 
metrics["initial_query_time"] = time.time() - start_time # Handle nothing found (as when no files are permitted) if not initial_documents: - return "", [] + return "", [], metrics # Adjacent documents + start_time = time.time() with_adjacent_query = add_document_filter_scores_to_query( elasticsearch_query=initial_query, ai_settings=ai_settings, @@ -236,8 +248,10 @@ def search_repo(query, selected_files, permitted_files, ai_settings, start_time= "[_search_documents] Adjacent boosted query using %s seconds", time.time() - start_time, ) + metrics["boosted_query_time"] = time.time() - start_time # Merge and sort + start_time = time.time() merged_documents = merge_documents(initial=initial_documents, adjacent=adjacent_boosted) sorted_documents = sort_documents(documents=merged_documents) log.warning( @@ -245,9 +259,11 @@ def search_repo(query, selected_files, permitted_files, ai_settings, start_time= time.time() - start_time, ) log.warning("[_search_documents] Returning %s documents", len(sorted_documents)) + metrics["merged_sort_docuemnt_time"] = time.time() - start_time + metrics["no_returned_documents"] = len(sorted_documents) # Return as state update - return format_documents(sorted_documents), sorted_documents + return format_documents(sorted_documents), sorted_documents, metrics @tool(response_format="content_and_artifact") def _search_documents(query: str, state: Annotated[RedboxState, InjectedState]) -> tuple[str, list[Document]]: @@ -264,13 +280,24 @@ def _search_documents(query: str, state: Annotated[RedboxState, InjectedState]) Returns: dict[str, Any]: Collection of matching document snippets with metadata: """ - return search_repo( + + document, artifact, metrics = search_repo( query=query, selected_files=state.request.s3_keys, permitted_files=state.request.permitted_s3_keys, ai_settings=state.request.ai_settings, ) + LLMObs.annotate( + span=tracer.current_span(), + input_data=query, + output_data=document, + metrics=metrics, + tags={"func": "hello-world"}, 
+ ) + + return document, artifact + @tool(response_format="content_and_artifact") def _search_knowledge_base(query: str, state: Annotated[RedboxState, InjectedState]) -> tuple[str, list[Document]]: """ @@ -286,13 +313,23 @@ def _search_knowledge_base(query: str, state: Annotated[RedboxState, InjectedSta Returns: dict[str, Any]: Collection of matching document snippets with metadata: """ - return search_repo( + document, artifact, metrics = search_repo( query=query, selected_files=state.request.knowledge_base_s3_keys, permitted_files=state.request.knowledge_base_s3_keys, ai_settings=state.request.ai_settings, ) + LLMObs.annotate( + span=tracer.current_span(), + input_data=query, + output_data=document, + metrics=metrics, + tags={"func": "search_knowledge_base"}, + ) + + return document, artifact + return _search_documents if repository == "user_uploaded" else _search_knowledge_base diff --git a/redbox/redbox/graph/root.py b/redbox/redbox/graph/root.py index 3793101e8..84ba065fa 100644 --- a/redbox/redbox/graph/root.py +++ b/redbox/redbox/graph/root.py @@ -22,8 +22,8 @@ from redbox.graph.nodes.processes import ( build_activity_log_node, build_agent_with_loop, - build_datahub_agent_with_loop, build_chat_pattern, + build_datahub_agent_with_loop, build_error_pattern, build_merge_pattern, build_passthrough_pattern, diff --git a/redbox/redbox/test/data.py b/redbox/redbox/test/data.py index 79998619a..64d829db6 100644 --- a/redbox/redbox/test/data.py +++ b/redbox/redbox/test/data.py @@ -12,11 +12,10 @@ from langchain_core.tools import BaseTool from pydantic.v1 import BaseModel, Field, validator -from redbox.models.chain import RedboxQuery +from redbox.models.chain import MultiAgentPlanBase, RedboxQuery from redbox.models.chat import ChatRoute, ErrorRoute from redbox.models.file import ChunkResolution, TabularSchema, UploadedFileMetadata from redbox.models.graph import RedboxActivityEvent -from redbox.models.chain import MultiAgentPlanBase log = logging.getLogger() @@ -243,6 +242,10 
@@ class GenericFakeChatModelWithTools(GenericFakeChatModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @property + def _default_config(self): + return {} + def bind_tools( self, tools: Sequence[dict[str, Any] | type | Callable | BaseTool], diff --git a/redbox/tests/graph/test_app.py b/redbox/tests/graph/test_app.py index 4f8ba7f37..c78a7d089 100644 --- a/redbox/tests/graph/test_app.py +++ b/redbox/tests/graph/test_app.py @@ -17,7 +17,7 @@ from pytest_mock import MockerFixture from redbox import Redbox -from redbox.graph.nodes.processes import create_or_update_db_from_tabulars, check_if_task_requires_user_feedback +from redbox.graph.nodes.processes import check_if_task_requires_user_feedback, create_or_update_db_from_tabulars from redbox.models.chain import ( AISettings, DocumentState, @@ -25,9 +25,9 @@ RedboxState, RequestMetadata, StructuredResponseWithCitations, + TaskStatus, configure_agent_task_plan, metadata_reducer, - TaskStatus, ) from redbox.models.chat import ChatRoute, ErrorRoute from redbox.models.graph import RedboxActivityEvent @@ -224,7 +224,6 @@ def _search_govuk(query: str) -> dict[str, Any]: # Mock the LLM and relevant tools llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses)) - llm._default_config = {"model": "bedrock"} mocker.patch("redbox.graph.nodes.processes.get_chat_llm", return_value=llm) # Instantiate app diff --git a/redbox/tests/graph/test_multiagents_app.py b/redbox/tests/graph/test_multiagents_app.py index 057a8049d..ccc417df5 100644 --- a/redbox/tests/graph/test_multiagents_app.py +++ b/redbox/tests/graph/test_multiagents_app.py @@ -317,14 +317,11 @@ async def test_newroute_zero_or_one_task( configured_multi_agent_plan = multi_agent_plan().model_copy(update={"tasks": tasks}) planner = configured_multi_agent_plan.model_dump_json() planner_response = GenericFakeChatModelWithTools(messages=iter([planner])) - planner_response._default_config = {"model": "bedrock"} 
evaluator_response = GenericFakeChatModelWithTools(messages=iter([evaluator])) - evaluator_response._default_config = {"model": "bedrock"} # mock response from worker agent worker_response = GenericFakeChatModelWithTools(messages=iter([WORKER_RESPONSE])) - worker_response._default_config = {"model": "bedrock"} mock_chat_chain = mocker.patch("redbox.chains.runnables.get_chat_llm") mock_chat_chain.side_effect = [planner_response, worker_response] @@ -377,7 +374,6 @@ async def test_newroute_more_than_one_tasks( ): # mocking user feedback classification feedback_class_response = GenericFakeChatModelWithTools(messages=iter([AIMessage(content=user_feedback)])) - feedback_class_response._default_config = {"model": "bedrock"} side_effect = [feedback_class_response] # mocking planner agent with tasks @@ -400,14 +396,12 @@ async def test_newroute_more_than_one_tasks( planner_response = GenericFakeChatModelWithTools( messages=iter([configured_multi_agent_plan.model_dump_json()]) ) - planner_response._default_config = {"model": "bedrock"} side_effect += [planner_response] # mock response from worker agents tool_call_side_effect = [] for i in range(len(agents)): worker_response = GenericFakeChatModelWithTools(messages=iter([WORKER_RESPONSE])) - worker_response._default_config = {"model": "bedrock"} side_effect += [worker_response] tool_call_side_effect += [[WORKER_TOOL_RESPONSE]] @@ -420,7 +414,6 @@ async def test_newroute_more_than_one_tasks( # mock evaluator evaluator_response = GenericFakeChatModelWithTools(messages=iter([evaluator])) - evaluator_response._default_config = {"model": "bedrock"} mocker.patch("redbox.graph.nodes.processes.get_chat_llm", return_value=evaluator_response) test_case = self.create_new_route_test( diff --git a/redbox/tests/graph/test_patterns.py b/redbox/tests/graph/test_patterns.py index dac6df798..327fb3b63 100644 --- a/redbox/tests/graph/test_patterns.py +++ b/redbox/tests/graph/test_patterns.py @@ -1,8 +1,7 @@ +import copy from uuid import uuid4 
-import copy import pytest -from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from langchain_core.messages import AIMessage, HumanMessage, ToolCall from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnableLambda @@ -14,8 +13,8 @@ from redbox.chains.runnables import CannedChatLLM, build_chat_prompt_from_messages_runnable, build_llm_chain from redbox.graph.nodes.processes import ( build_agent_with_loop, - build_datahub_agent_with_loop, build_chat_pattern, + build_datahub_agent_with_loop, build_merge_pattern, build_passthrough_pattern, build_retrieve_pattern, @@ -39,6 +38,7 @@ ) from redbox.models.chat import ChatRoute from redbox.test.data import ( + GenericFakeChatModelWithTools, RedboxChatTestCase, RedboxTestData, generate_docs, @@ -122,8 +122,7 @@ def test_build_chat_prompt_from_messages_runnable(test_case: RedboxChatTestCase, @pytest.mark.parametrize(("test_case"), BUILD_LLM_TEST_CASES, ids=[t.test_id for t in BUILD_LLM_TEST_CASES]) def test_build_llm_chain(test_case: RedboxChatTestCase): """Tests a given state can update the data and metadata correctly.""" - llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) - llm._default_config = {"model": "bedrock"} + llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses)) llm_chain = build_llm_chain(PromptSet.Chat, llm) state = RedboxState(request=test_case.query) @@ -163,8 +162,7 @@ def test_build_llm_chain(test_case: RedboxChatTestCase): @pytest.mark.parametrize(("test_case"), CHAT_TEST_CASES, ids=[t.test_id for t in CHAT_TEST_CASES]) def test_build_chat_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture): """Tests a given state["request"] correctly changes state["text"].""" - llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) - llm._default_config = {"model": "bedrock"} + llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses)) state 
= RedboxState(request=test_case.query) chat = build_chat_pattern(prompt_set=PromptSet.Chat, final_response_chain=True) @@ -299,8 +297,7 @@ def test_build_retrieve_pattern(test_case: RedboxChatTestCase, mock_retriever: B @pytest.mark.parametrize(("test_case"), MERGE_TEST_CASES, ids=[t.test_id for t in MERGE_TEST_CASES]) def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture): """Tests a given state["request"] and state["documents"] correctly changes state["documents"].""" - llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) - llm._default_config = {"model": "bedrock"} + llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses)) state = RedboxState(request=test_case.query, documents=structure_documents_by_file_name(test_case.docs)) merge = build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce, final_response_chain=True) @@ -350,8 +347,7 @@ def test_build_merge_pattern(test_case: RedboxChatTestCase, mocker: MockerFixtur @pytest.mark.parametrize(("test_case"), STUFF_TEST_CASES, ids=[t.test_id for t in STUFF_TEST_CASES]) def test_build_stuff_pattern(test_case: RedboxChatTestCase, mocker: MockerFixture): """Tests a given state["request"] and state["documents"] correctly changes state["text"].""" - llm = GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) - llm._default_config = {"model": "bedrock"} + llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses)) state = RedboxState(request=test_case.query, documents=structure_documents_by_file_name(test_case.docs)) stuff = build_stuff_pattern(prompt_set=PromptSet.ChatwithDocs, final_response_chain=True) @@ -570,8 +566,7 @@ async def test_canned_llm_async(): ("test_case"), STRUCTURED_OUTPUT_TEST_CASE, ids=[t.test_id for t in STRUCTURED_OUTPUT_TEST_CASE] ) def test_citation_structured_output(test_case: RedboxChatTestCase, mocker: MockerFixture): - llm = 
GenericFakeChatModel(messages=iter(test_case.test_data.llm_responses)) - llm._default_config = {"model": "bedrock"} + llm = GenericFakeChatModelWithTools(messages=iter(test_case.test_data.llm_responses)) mocker.patch("redbox.graph.nodes.processes.get_chat_llm", return_value=llm) state = RedboxState(request=test_case.query, documents=structure_documents_by_file_name(test_case.docs)) @@ -655,7 +650,7 @@ def test_preprocess_loop( content="test", additional_kwargs={"tool_calls": [{"name": "test_tool", "args": {"is_intermediate_step": True}}]}, ) - llm = GenericFakeChatModel(messages=iter([res])) + llm = GenericFakeChatModelWithTools(messages=iter([res])) mock_llm = mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm) mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel") @@ -735,7 +730,7 @@ def test_llm_response_truncation(self, test_name, max_tokens, actual_tokens, fak content=llm_content.strip(), additional_kwargs={"tool_calls": [{"name": "test_tool", "args": {"is_intermediate_step": True}}]}, ) - llm = GenericFakeChatModel(messages=iter([llm_message])) + llm = GenericFakeChatModelWithTools(messages=iter([llm_message])) mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm) mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel") @@ -864,7 +859,7 @@ async def test_preprocess_loop( {"name": "test_tool", "args": {"is_intermediate_step": True}, "id": "fake-id", "type": "tool_call"} ], ) - llm = GenericFakeChatModel(messages=iter([res])) + llm = GenericFakeChatModelWithTools(messages=iter([res])) mock_llm = mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm) mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel") @@ -962,7 +957,7 @@ async def test_llm_response_truncation( content=llm_content.strip(), additional_kwargs={"tool_calls": [{"name": "test_tool", "args": {"is_intermediate_step": True}}]}, ) - llm = 
GenericFakeChatModel(messages=iter([llm_message])) + llm = GenericFakeChatModelWithTools(messages=iter([llm_message])) mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm) mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel") @@ -1098,7 +1093,7 @@ async def test_feedback_required_vs_not( {"name": "test_tool", "args": {"is_intermediate_step": False}, "id": "fake-id", "type": "tool_call"} ], ) - llm = GenericFakeChatModel(messages=iter([res] * 10)) + llm = GenericFakeChatModelWithTools(messages=iter([res] * 10)) mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm) mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel") diff --git a/redbox/tests/test_ingest.py b/redbox/tests/test_ingest.py index 973b37e7d..a368edea7 100644 --- a/redbox/tests/test_ingest.py +++ b/redbox/tests/test_ingest.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock, patch import pytest -from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from redbox.loader.loaders import MetadataLoader, parse_tabular_schema from redbox.models.chain import GeneratedMetadata @@ -23,6 +22,8 @@ from io import BytesIO import pandas as pd +from redbox.test.data import GenericFakeChatModelWithTools + if TYPE_CHECKING: from mypy_boto3_s3.client import S3Client @@ -78,7 +79,7 @@ def test_extract_metadata_missing_key( ): mock_llm_response = mock_llm.return_value mock_llm_response.status_code = 200 - mock_llm_response.return_value = GenericFakeChatModel(messages=iter(['{"missing_key":""}'])) + mock_llm_response.return_value = GenericFakeChatModelWithTools(messages=iter(['{"missing_key":""}'])) """ LLM replies but without one of the keys @@ -105,7 +106,7 @@ def test_extract_metadata_extra_key( ): mock_llm_response = mock_llm.return_value mock_llm_response.status_code = 200 - mock_llm_response.return_value = GenericFakeChatModel( + mock_llm_response.return_value = 
GenericFakeChatModelWithTools( messages=iter(['{"extra_key": "", "name": "foo", "description": "test", "keywords": ["abc"]}']) )