diff --git a/server/api/views/assistant/assistant_prompts.py b/server/api/views/assistant/assistant_prompts.py new file mode 100644 index 00000000..44bf9b9b --- /dev/null +++ b/server/api/views/assistant/assistant_prompts.py @@ -0,0 +1,38 @@ +INSTRUCTIONS = """ +You are an AI assistant that helps users find and understand information about bipolar disorder +from your internal library of bipolar disorder research sources using semantic search. + +IMPORTANT CONTEXT: +- You have access to a library of sources that the user CANNOT see +- The user did not upload these sources and doesn't know about them +- You must explain what information exists in your sources and provide clear references + +TOPIC RESTRICTIONS: +When a prompt is received that is unrelated to bipolar disorder, mental health treatment, +or psychiatric medications, respond by saying you are limited to bipolar-specific conversations. + +SEMANTIC SEARCH STRATEGY: +- Always perform semantic search using the search_documents function when users ask questions +- Use conceptually related terms and synonyms, not just exact keyword matches +- Search for the meaning and context of the user's question, not just literal words +- Consider medical terminology, lay terms, and related conditions when searching + +FUNCTION USAGE: +- When a user asks about information that might be in your source library, ALWAYS use the search_documents function first +- Perform semantic searches using concepts, symptoms, treatments, and related terms from the user's question +- Only provide answers based on information found through your source searches + +RESPONSE FORMAT: +After gathering information through semantic searches, provide responses that: +1. Answer the user's question directly using only the found information +2. Structure responses with clear sections and paragraphs +3. Explain what information you found in your sources and provide context +4. 
Include citations using this exact format: [Name {name}, Page {page_number}] +5. Only cite information that directly supports your statements + +If no relevant information is found in your source library, clearly state that the information +is not available in your current sources. + +REMEMBER: You are working with an internal library of bipolar disorder sources that the user +cannot see. Always search these sources first, explain what you found, and provide proper citations. +""" \ No newline at end of file diff --git a/server/api/views/assistant/assistant_services.py b/server/api/views/assistant/assistant_services.py new file mode 100644 index 00000000..ac339b9f --- /dev/null +++ b/server/api/views/assistant/assistant_services.py @@ -0,0 +1,72 @@ +import os +import logging + +from openai import OpenAI + +from .assistant_prompts import INSTRUCTIONS +from .tool_services import ( + SEARCH_TOOLS_SCHEMA, + make_search_tool_mapping, + handle_tool_calls_with_reasoning, +) + +logger = logging.getLogger(__name__) + + +def run_assistant( + message: str, + user, + previous_response_id: str | None = None, +) -> tuple[str, str]: + """Wire together the OpenAI client, retrieval, and the agentic reasoning loop. + + Parameters + ---------- + message : str + The user's input message. + user : User + The Django user object used for document access control in search_documents. + previous_response_id : str | None + ID of a prior response for multi-turn conversation continuity. + + Returns + ------- + tuple[str, str] + (final_response_output_text, final_response_id) + """ + # TODO: Track total duration, cost metrics, and tool_calls_made count + # and return them from run_assistant for use in eval_assistant.py CSV output + + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + MODEL_DEFAULTS = { + "instructions": INSTRUCTIONS, + "model": "gpt-5-nano", # 400,000 token context window + # A summary of the reasoning performed by the model. 
This can be useful for debugging and understanding the model's reasoning process. + "reasoning": {"effort": "low", "summary": None}, + "tools": SEARCH_TOOLS_SCHEMA, + } + + # TOOLS_SCHEMA tells the model what tools exist and what arguments to generate. + # tool_mapping wires those tool names to the Python functions that execute them. + # They are separate because the model generates arguments (schema concern) but + # cannot supply request-time values like user (mapping concern). + tool_mapping = make_search_tool_mapping(user) + + if not previous_response_id: + response = client.responses.create( + input=[ + {"type": "message", "role": "user", "content": str(message)} + ], + **MODEL_DEFAULTS, + ) + else: + response = client.responses.create( + input=[ + {"type": "message", "role": "user", "content": str(message)} + ], + previous_response_id=str(previous_response_id), + **MODEL_DEFAULTS, + ) + + return handle_tool_calls_with_reasoning(response, client, MODEL_DEFAULTS, tool_mapping) diff --git a/server/api/views/assistant/eval_assistant.py b/server/api/views/assistant/eval_assistant.py new file mode 100644 index 00000000..e0d3c969 --- /dev/null +++ b/server/api/views/assistant/eval_assistant.py @@ -0,0 +1,141 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = "==3.11.11" +# dependencies = [ +# "pandas==2.2.3", +# "openai", +# "django", +# ] +# /// + +# uv script (or plain Python) to generate results to CSV, run from the terminal +# Run from inside the container: docker compose exec backend python eval_assistant.py +# + + +import os +import sys +import logging +import datetime +from concurrent.futures import ThreadPoolExecutor, as_completed + +# Django setup must come before any imports that touch the ORM +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../"))) +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "balancer_backend.settings") + +import django +django.setup() + +import pandas as pd +from 
django.contrib.auth import get_user_model + +from api.views.assistant.assistant_services import run_assistant +from api.views.assistant.assistant_prompts import INSTRUCTIONS + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +# Read model and INSTRUCTIONS from the source file or add a lightweight config endpoint to the backend + +# Read model and INSTRUCTIONS from the source file +# INSTRUCTIONS is imported from assistant_prompts.py +# MODEL is read from assistant_services.py MODEL_DEFAULTS +MODEL = "gpt-5-nano" + +# Set of representative questions to evaluate the assistant +QUESTIONS = [ + "What medications are recommended for bipolar depression?", + "What are the risks of lithium for patients with kidney disease?", + "Which mood stabilizers are safe during pregnancy?", + "What is the evidence for quetiapine in bipolar disorder?", + "How does valproate compare to lithium for mania?", +] + + +def run_one(question: str, user, branch: str) -> dict: + """Run the assistant for a single question and return a result row. + + Uses ThreadPoolExecutor (not asyncio.gather + await run_assistant) for concurrency. + + Concurrency approach comparison: + - ThreadPoolExecutor (this implementation): + - run_assistant stays sync — views.py and the WSGI web app are unaffected + - Each question runs in a thread pool worker, blocking on OpenAI + DB I/O + - Django DB safe when run via `docker compose exec backend python eval_assistant.py`: + this is a synchronous Django process context. Each ThreadPoolExecutor worker + is a real OS thread with its own threading.local() storage, so each thread + gets its own DB connection created lazily on first use. There is no shared + event loop thread, so connections cannot clash or bleed between questions. 
+ The connection isolation concern only arises in ASGI contexts where multiple + coroutines share one thread and therefore one threading.local() connection — + which is not the case here. + - Runtime: bottlenecked by OpenAI rate limits, not thread overhead + - asyncio.gather + await run_assistant (alternative): + - run_assistant becomes async — requires async def post in views.py, + AsyncOpenAI client, and async handle_tool_calls_with_reasoning + - Django DB unsafe if get_closest_embeddings is called directly in an async + context without wrapping: get_closest_embeddings is a sync function that + hits the ORM, so calling it on the event loop thread blocks all other + coroutines until the DB responds. The fix is sync_to_async(get_closest_embeddings), + which runs it in a dedicated worker thread with its own threading.local() + connection. Bare await does not work at all — Django ORM querysets are not + awaitables and raise TypeError immediately. + - Under WSGI (manage.py runserver), async views run in a new event loop + per request — adds overhead to every web request for no benefit + - Cleaner call site in eval_assistant.py but wrong trade-off given WSGI + """ + try: + response_text, response_id = run_assistant(message=question, user=user) + return { + "branch": branch, + "model": MODEL, + "question": question, + "response_output_text": response_text, + "error": None, + } + except Exception as e: + logger.error(f"Error evaluating question '{question}': {e}") + return { + "branch": branch, + "model": MODEL, + "question": question, + "response_output_text": None, + "error": str(e), + } + + +def main(): + branch = os.environ.get("EVAL_BRANCH", "develop") + + User = get_user_model() + user = User.objects.filter(is_superuser=True).first() + if not user: + raise RuntimeError("No superuser found. 
Create one with manage.py createsuperuser.") + + logger.info(f"Starting evaluation: branch={branch}, model={MODEL}, questions={len(QUESTIONS)}") + + # ThreadPoolExecutor runs questions concurrently — see run_one docstring + # for trade-off discussion vs asyncio.gather + await run_assistant. + # max_workers=5 stays safely under OpenAI rate limits for gpt-5-nano. + results = [] + with ThreadPoolExecutor(max_workers=5) as pool: + futures = { + pool.submit(run_one, question, user, branch): question + for question in QUESTIONS + } + for future in as_completed(futures): + results.append(future.result()) + + df = pd.DataFrame(results) + + results_dir = os.path.join(os.path.dirname(__file__), "results") + os.makedirs(results_dir, exist_ok=True) + timestamp = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S") + output_path = os.path.join(results_dir, f"{branch}-{timestamp}.csv") + df.to_csv(output_path, index=False) + + logger.info(f"Results saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/server/api/views/assistant/review.ipynb b/server/api/views/assistant/review.ipynb new file mode 100644 index 00000000..e6da83d0 --- /dev/null +++ b/server/api/views/assistant/review.ipynb @@ -0,0 +1 @@ +# notebook to review and compare the two CSVs \ No newline at end of file diff --git a/server/api/views/assistant/test_assistant_services.py b/server/api/views/assistant/test_assistant_services.py new file mode 100644 index 00000000..3cb70534 --- /dev/null +++ b/server/api/views/assistant/test_assistant_services.py @@ -0,0 +1,129 @@ +from unittest.mock import MagicMock, patch + + +def _make_terminal_response(output_text="Final answer.", response_id="resp-1"): + response = MagicMock() + response.output = [] + response.output_text = output_text + response.id = response_id + return response + + +# --------------------------------------------------------------------------- +# run_assistant tests +# +# run_assistant is responsible for wiring together the OpenAI 
client, +# make_search_tool_mapping (which binds user to search_documents), and +# handle_tool_calls_with_reasoning. +# +# We patch the OpenAI client and handle_tool_calls_with_reasoning to test +# that run_assistant correctly assembles and forwards its arguments. +# --------------------------------------------------------------------------- + +@patch("api.views.assistant.assistant_services.handle_tool_calls_with_reasoning") +@patch("api.views.assistant.assistant_services.OpenAI") +def test_run_assistant_returns_text_and_id(mock_openai_cls, mock_handle): + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + mock_client.responses.create.return_value = _make_terminal_response() + mock_handle.return_value = ("Final answer.", "resp-1") + + from api.views.assistant.assistant_services import run_assistant + + user = MagicMock() + text, resp_id = run_assistant(message="What is lithium?", user=user) + + assert text == "Final answer." + assert resp_id == "resp-1" + + +@patch("api.views.assistant.assistant_services.handle_tool_calls_with_reasoning") +@patch("api.views.assistant.assistant_services.OpenAI") +def test_run_assistant_sends_message_as_user_input(mock_openai_cls, mock_handle): + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + mock_client.responses.create.return_value = _make_terminal_response() + mock_handle.return_value = ("answer", "resp-1") + + from api.views.assistant.assistant_services import run_assistant + + run_assistant(message="Tell me about valproate.", user=MagicMock()) + + call_kwargs = mock_client.responses.create.call_args + input_messages = call_kwargs.kwargs.get("input") or call_kwargs.args[0] + assert any( + item.get("role") == "user" and "valproate" in item.get("content", "") + for item in input_messages + ) + + +@patch("api.views.assistant.assistant_services.handle_tool_calls_with_reasoning") +@patch("api.views.assistant.assistant_services.OpenAI") +def 
test_run_assistant_passes_previous_response_id(mock_openai_cls, mock_handle): + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + mock_client.responses.create.return_value = _make_terminal_response() + mock_handle.return_value = ("answer", "resp-2") + + from api.views.assistant.assistant_services import run_assistant + + run_assistant(message="More info.", user=MagicMock(), previous_response_id="resp-1") + + call_kwargs = mock_client.responses.create.call_args.kwargs + assert call_kwargs.get("previous_response_id") == "resp-1" + + +@patch("api.views.assistant.assistant_services.handle_tool_calls_with_reasoning") +@patch("api.views.assistant.assistant_services.OpenAI") +def test_run_assistant_omits_previous_response_id_when_none(mock_openai_cls, mock_handle): + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + mock_client.responses.create.return_value = _make_terminal_response() + mock_handle.return_value = ("answer", "resp-1") + + from api.views.assistant.assistant_services import run_assistant + + run_assistant(message="First message.", user=MagicMock(), previous_response_id=None) + + call_kwargs = mock_client.responses.create.call_args.kwargs + assert "previous_response_id" not in call_kwargs + + +@patch("api.views.assistant.assistant_services.handle_tool_calls_with_reasoning") +@patch("api.views.assistant.assistant_services.OpenAI") +def test_run_assistant_passes_search_tools_schema_to_model(mock_openai_cls, mock_handle): + from api.views.assistant.assistant_services import run_assistant + from api.views.assistant.tool_services import SEARCH_TOOLS_SCHEMA + + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + mock_client.responses.create.return_value = _make_terminal_response() + mock_handle.return_value = ("answer", "resp-1") + + run_assistant(message="query", user=MagicMock()) + + call_kwargs = mock_client.responses.create.call_args.kwargs + assert call_kwargs.get("tools") == 
SEARCH_TOOLS_SCHEMA + + +@patch("api.views.assistant.tool_services.search_documents") +@patch("api.views.assistant.assistant_services.handle_tool_calls_with_reasoning") +@patch("api.views.assistant.assistant_services.OpenAI") +def test_run_assistant_binds_user_to_search_documents(mock_openai_cls, mock_handle, mock_search): + mock_client = MagicMock() + mock_openai_cls.return_value = mock_client + mock_client.responses.create.return_value = _make_terminal_response() + mock_handle.return_value = ("answer", "resp-1") + + from api.views.assistant.assistant_services import run_assistant + + user = MagicMock() + run_assistant(message="query", user=user) + + # Extract the tool_mapping passed to handle_tool_calls_with_reasoning + tool_mapping = mock_handle.call_args.kwargs.get("tool_mapping") or mock_handle.call_args.args[3] + bound_search = tool_mapping["search_documents"] + + # Calling the bound function should forward user to search_documents + bound_search(query="test query") + mock_search.assert_called_once_with("test query", user) diff --git a/server/api/views/assistant/test_eval_assistant.py b/server/api/views/assistant/test_eval_assistant.py new file mode 100644 index 00000000..e69de29b diff --git a/server/api/views/assistant/test_tool_services.py b/server/api/views/assistant/test_tool_services.py new file mode 100644 index 00000000..ef78a383 --- /dev/null +++ b/server/api/views/assistant/test_tool_services.py @@ -0,0 +1,210 @@ +import json +from unittest.mock import MagicMock, patch + +from api.views.assistant.tool_services import ( + invoke_functions_from_response, + handle_tool_calls_with_reasoning, + make_search_tool_mapping, +) + + +# --------------------------------------------------------------------------- +# make_search_tool_mapping tests +# +# make_search_tool_mapping is responsible for binding user to search_documents +# so the tool dispatcher can call it with only the query argument the model +# generates. 
We test the shape of the returned mapping and that user is +# forwarded correctly to search_documents. +# --------------------------------------------------------------------------- + +@patch("api.views.assistant.tool_services.search_documents") +def test_make_search_tool_mapping_returns_search_documents_key(mock_search): + user = MagicMock() + mapping = make_search_tool_mapping(user) + assert "search_documents" in mapping + + +@patch("api.views.assistant.tool_services.search_documents") +def test_make_search_tool_mapping_bound_fn_forwards_user(mock_search): + mock_search.return_value = "results" + user = MagicMock() + mapping = make_search_tool_mapping(user) + + mapping["search_documents"](query="lithium") + + mock_search.assert_called_once_with("lithium", user) + + +@patch("api.views.assistant.tool_services.search_documents") +def test_make_search_tool_mapping_different_users_are_independent(mock_search): + # Each call to make_search_tool_mapping should capture its own user, + # so two mappings created with different users do not share state. 
+ user_a = MagicMock() + user_b = MagicMock() + mapping_a = make_search_tool_mapping(user_a) + mapping_b = make_search_tool_mapping(user_b) + + mapping_a["search_documents"](query="q") + mapping_b["search_documents"](query="q") + + calls = mock_search.call_args_list + assert calls[0] == ((("q", user_a),), {}) + assert calls[1] == ((("q", user_b),), {}) + + +# --------------------------------------------------------------------------- +# invoke_functions_from_response tests +# --------------------------------------------------------------------------- + +def _make_function_call_item(name, arguments, call_id): + item = MagicMock() + item.type = "function_call" + item.name = name + item.arguments = json.dumps(arguments) + item.call_id = call_id + return item + + +def _make_reasoning_item(summary="reasoning summary"): + item = MagicMock() + item.type = "reasoning" + item.summary = summary + return item + + +def _make_response(output_items): + response = MagicMock() + response.output = output_items + return response + + +def test_invoke_returns_empty_list_when_no_function_calls(): + response = _make_response([_make_reasoning_item()]) + result = invoke_functions_from_response(response, tool_mapping={}) + assert result == [] + + +def test_invoke_calls_tool_and_returns_output(): + mock_tool = MagicMock(return_value="search result") + item = _make_function_call_item("search_documents", {"query": "lithium"}, "call-1") + response = _make_response([item]) + + result = invoke_functions_from_response( + response, tool_mapping={"search_documents": mock_tool} + ) + + mock_tool.assert_called_once_with(query="lithium") + assert result == [ + {"type": "function_call_output", "call_id": "call-1", "output": "search result"} + ] + + +def test_invoke_returns_error_message_when_tool_not_registered(): + item = _make_function_call_item("unknown_tool", {"query": "x"}, "call-2") + response = _make_response([item]) + + result = invoke_functions_from_response(response, tool_mapping={}) + + 
assert result[0]["call_id"] == "call-2" + assert "ERROR" in result[0]["output"] + + +def test_invoke_returns_error_message_when_tool_raises(): + mock_tool = MagicMock(side_effect=Exception("tool exploded")) + item = _make_function_call_item("search_documents", {"query": "x"}, "call-3") + response = _make_response([item]) + + result = invoke_functions_from_response( + response, tool_mapping={"search_documents": mock_tool} + ) + + assert "Error executing function call" in result[0]["output"] + + +def test_invoke_handles_multiple_function_calls(): + mock_tool = MagicMock(return_value="result") + items = [ + _make_function_call_item("search_documents", {"query": "q1"}, "call-4"), + _make_function_call_item("search_documents", {"query": "q2"}, "call-5"), + ] + response = _make_response(items) + + result = invoke_functions_from_response( + response, tool_mapping={"search_documents": mock_tool} + ) + + assert len(result) == 2 + assert mock_tool.call_count == 2 + + +# --------------------------------------------------------------------------- +# handle_tool_calls_with_reasoning tests +# --------------------------------------------------------------------------- + +def _make_terminal_response(output_text, response_id): + """A response with no function calls — terminates the loop.""" + response = MagicMock() + response.output = [] + response.output_text = output_text + response.id = response_id + return response + + +def _make_tool_call_response(response_id, query="lithium"): + """A response with one function call — continues the loop.""" + response = MagicMock() + response.output = [_make_function_call_item("search_documents", {"query": query}, "call-loop")] + response.id = response_id + return response + + +def test_handle_terminates_immediately_when_no_tool_calls(): + response = _make_terminal_response("Final answer.", "resp-1") + client = MagicMock() + + text, resp_id = handle_tool_calls_with_reasoning( + response, client, model_defaults={}, tool_mapping={} + ) + + 
assert text == "Final answer." + assert resp_id == "resp-1" + client.responses.create.assert_not_called() + + +def test_handle_calls_tool_then_terminates(): + mock_search = MagicMock(return_value="doc content") + first_response = _make_tool_call_response("resp-1") + second_response = _make_terminal_response("Final answer.", "resp-2") + + client = MagicMock() + client.responses.create.return_value = second_response + + text, resp_id = handle_tool_calls_with_reasoning( + first_response, + client, + model_defaults={}, + tool_mapping={"search_documents": mock_search}, + ) + + mock_search.assert_called_once_with(query="lithium") + assert text == "Final answer." + assert resp_id == "resp-2" + + +def test_handle_passes_previous_response_id_on_followup(): + mock_search = MagicMock(return_value="doc content") + first_response = _make_tool_call_response("resp-1") + second_response = _make_terminal_response("Done.", "resp-2") + + client = MagicMock() + client.responses.create.return_value = second_response + + handle_tool_calls_with_reasoning( + first_response, + client, + model_defaults={}, + tool_mapping={"search_documents": mock_search}, + ) + + call_kwargs = client.responses.create.call_args.kwargs + assert call_kwargs["previous_response_id"] == "resp-1" diff --git a/server/api/views/assistant/test_views.py b/server/api/views/assistant/test_views.py new file mode 100644 index 00000000..a9511b42 --- /dev/null +++ b/server/api/views/assistant/test_views.py @@ -0,0 +1,66 @@ +from unittest.mock import patch + +from django.test import TestCase +from django.contrib.auth import get_user_model +from rest_framework.test import APIClient + +User = get_user_model() + + +class AssistantViewTest(TestCase): + def setUp(self): + self.client = APIClient() + self.url = "/v1/api/assistant/" + + @patch("api.views.assistant.views.run_assistant") + def test_returns_200_with_response_fields(self, mock_run_assistant): + mock_run_assistant.return_value = ("Lithium is recommended.", 
"resp-abc-123") + + response = self.client.post( + self.url, + {"message": "What medications help with bipolar depression?"}, + format="json", + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["response_output_text"], "Lithium is recommended.") + self.assertEqual(response.data["final_response_id"], "resp-abc-123") + + @patch("api.views.assistant.views.run_assistant") + def test_passes_message_and_user_to_run_assistant(self, mock_run_assistant): + mock_run_assistant.return_value = ("Some response.", "resp-xyz") + + self.client.post( + self.url, + {"message": "Tell me about lithium."}, + format="json", + ) + + call_kwargs = mock_run_assistant.call_args.kwargs + self.assertEqual(call_kwargs["message"], "Tell me about lithium.") + self.assertIn("user", call_kwargs) + self.assertIsNone(call_kwargs["previous_response_id"]) + + @patch("api.views.assistant.views.run_assistant") + def test_passes_previous_response_id_when_provided(self, mock_run_assistant): + mock_run_assistant.return_value = ("Follow-up response.", "resp-456") + + self.client.post( + self.url, + {"message": "Tell me more.", "previous_response_id": "resp-123"}, + format="json", + ) + + call_kwargs = mock_run_assistant.call_args.kwargs + self.assertEqual(call_kwargs["previous_response_id"], "resp-123") + + @patch("api.views.assistant.views.run_assistant", side_effect=Exception("OpenAI error")) + def test_returns_500_on_exception(self, mock_run_assistant): + response = self.client.post( + self.url, + {"message": "What is lithium?"}, + format="json", + ) + + self.assertEqual(response.status_code, 500) + self.assertIn("error", response.data) diff --git a/server/api/views/assistant/tool_services.py b/server/api/views/assistant/tool_services.py new file mode 100644 index 00000000..0fb96cef --- /dev/null +++ b/server/api/views/assistant/tool_services.py @@ -0,0 +1,214 @@ +import json +import logging +from typing import Callable + +from ...services.embedding_services import 
get_closest_embeddings +from ...services.conversions_services import convert_uuids + +logger = logging.getLogger(__name__) + +TOOL_DESCRIPTION = """ +Search the user's uploaded documents for information relevant to answering their question. +Call this function when you need to find specific information from the user's documents +to provide an accurate, citation-backed response. Always search before answering questions +about document content. +""" + +TOOL_PROPERTY_DESCRIPTION = """ +A specific search query to find relevant information in the user's documents. +Use keywords, phrases, or questions related to what the user is asking about. +Be specific rather than generic - use terms that would appear in the relevant documents. +""" + +# SEARCH_TOOLS_SCHEMA defines the search_documents tool for the OpenAI API. +# The model reads this schema to know what tools are available and what +# arguments to generate — it can only generate arguments declared here. +SEARCH_TOOLS_SCHEMA = [ + { + "type": "function", + "name": "search_documents", + "description": TOOL_DESCRIPTION, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": TOOL_PROPERTY_DESCRIPTION, + } + }, + "required": ["query"], + }, + } +] + + +# TODO: Add get_tools_schema() and make_tool_mapping(user) aggregation functions +# that combine all tool schemas and mappings so assistant_services.py never needs +# to change when a new tool is added — only tool_services.py does. + +def make_search_tool_mapping(user) -> dict[str, Callable]: + # make_search_tool_mapping binds user to search_documents at call time. + # user is a request-time value the model cannot generate, so it must be + # captured here and kept out of the schema. + """Return a tool mapping with search_documents bound to the given user. + + Parameters + ---------- + user : User + The Django user object used for document access control. 
+ + Returns + ------- + dict[str, Callable] + Tool mapping ready to pass to invoke_functions_from_response. + """ + def bound_search(query: str) -> str: + return search_documents(query, user) + + return {"search_documents": bound_search} + + +def search_documents(query: str, user) -> str: + """ + Search through user's uploaded documents using semantic similarity. + + This function performs vector similarity search against the user's document corpus + and returns formatted results with context information for the LLM to use. + + Parameters + ---------- + query : str + The search query string + user : User + The authenticated user whose documents to search + + Returns + ------- + str + Formatted search results containing document excerpts with metadata + + Raises + ------ + Exception + If embedding search fails + """ + + try: + embeddings_results = get_closest_embeddings( + user=user, message_data=query.strip() + ) + embeddings_results = convert_uuids(embeddings_results) + + if not embeddings_results: + return "No relevant documents found for your query. Please try different search terms or upload documents first." + + # Format results with clear structure and metadata + prompt_texts = [ + f"[Document {i + 1} - File: {obj['file_id']}, Name: {obj['name']}, Page: {obj['page_number']}, Chunk: {obj['chunk_number']}, Similarity: {1 - obj['distance']:.3f}]\n{obj['text']}\n[End Document {i + 1}]" + for i, obj in enumerate(embeddings_results) + ] + + return "\n\n".join(prompt_texts) + + except Exception as e: + return f"Error searching documents: {str(e)}. Please try again if the issue persists." + + +def invoke_functions_from_response( + response, tool_mapping: dict[str, Callable] +) -> list[dict]: + """Extract all function calls from the response, look up the corresponding tool function(s) and execute them. + (This would be a good place to handle asynchronous tool calls, or ones that take a while to execute.) 
+ This returns a list of messages to be added to the conversation history. + + Parameters + ---------- + response : OpenAI Response + The response object from OpenAI containing output items that may include function calls + tool_mapping : dict[str, Callable] + A dictionary mapping function names (as strings) to their corresponding Python functions. + Keys should match the function names defined in the tools schema. + + Returns + ------- + list[dict] + List of function call output messages formatted for the OpenAI conversation. + Each message contains: + - type: "function_call_output" + - call_id: The unique identifier for the function call + - output: The result returned by the executed function (string or error message) + """ + + # Open AI Cookbook: Handling Function Calls with Reasoning Models + # https://cookbook.openai.com/examples/reasoning_function_calls + + intermediate_messages = [] + for response_item in response.output: + if response_item.type == "function_call": + target_tool = tool_mapping.get(response_item.name) + if target_tool: + try: + arguments = json.loads(response_item.arguments) + logger.info( + f"Invoking tool: {response_item.name} with arguments: {arguments}" + ) + tool_output = target_tool(**arguments) + logger.info(f"Tool {response_item.name} completed successfully") + except Exception as e: + msg = f"Error executing function call: {response_item.name}: {e}" + tool_output = msg + logger.error(msg, exc_info=True) + else: + msg = f"ERROR - No tool registered for function call: {response_item.name}" + tool_output = msg + logger.error(msg) + intermediate_messages.append( + { + "type": "function_call_output", + "call_id": response_item.call_id, + "output": tool_output, + } + ) + elif response_item.type == "reasoning": + logger.info(f"Reasoning step: {response_item.summary}") + return intermediate_messages + +def handle_tool_calls_with_reasoning( + response, client, model_defaults: dict, tool_mapping: dict[str, Callable] +) -> tuple[str, str]: + 
"""Run the agentic loop until the model stops emitting function calls. + + Parameters + ---------- + response : OpenAI Response + The initial response from the model. + client : OpenAI + The OpenAI client instance. + model_defaults : dict + Keyword arguments forwarded to every client.responses.create call. + tool_mapping : dict[str, Callable] + Maps function names to their implementations. + + Returns + ------- + tuple[str, str] + (final_response_output_text, final_response_id) + """ + # Open AI Cookbook: Handling Function Calls with Reasoning Models + # https://cookbook.openai.com/examples/reasoning_function_calls + while True: + # Mapping of the tool names we tell the model about and the functions that implement them + function_responses = invoke_functions_from_response(response, tool_mapping) + if len(function_responses) == 0: # We're done reasoning + logger.info("Reasoning completed") + final_response_output_text = response.output_text + final_response_id = response.id + logger.info(f"Final response: {final_response_output_text}") + return final_response_output_text, final_response_id + else: + logger.info("More reasoning required, continuing...") + response = client.responses.create( + input=function_responses, + previous_response_id=response.id, + **model_defaults, + ) diff --git a/server/api/views/assistant/views.py b/server/api/views/assistant/views.py index e3e8d6f7..d7b10a4b 100644 --- a/server/api/views/assistant/views.py +++ b/server/api/views/assistant/views.py @@ -1,8 +1,4 @@ -import os -import json import logging -import time -from typing import Callable from rest_framework.views import APIView from rest_framework.response import Response @@ -13,103 +9,10 @@ from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework import serializers as drf_serializers -from openai import OpenAI +from .assistant_services import run_assistant -from ...services.embedding_services import get_closest_embeddings -from 
...services.conversions_services import convert_uuids - -# Configure logging logger = logging.getLogger(__name__) -GPT_5_NANO_PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.05, "output": 0.40} - - -def calculate_cost_metrics(token_usage: dict, pricing: dict) -> dict: - """ - Calculate cost metrics based on token usage and pricing - - Args: - token_usage: Dictionary containing input_tokens and output_tokens - pricing: Dictionary containing input and output pricing per million tokens - - Returns: - Dictionary containing input_cost, output_cost, and total_cost in USD - """ - TOKENS_PER_MILLION = 1_000_000 - - # Pricing is in dollars per million tokens - input_cost_dollars = (pricing["input"] / TOKENS_PER_MILLION) * token_usage.get( - "input_tokens", 0 - ) - output_cost_dollars = (pricing["output"] / TOKENS_PER_MILLION) * token_usage.get( - "output_tokens", 0 - ) - total_cost_dollars = input_cost_dollars + output_cost_dollars - - return { - "input_cost": input_cost_dollars, - "output_cost": output_cost_dollars, - "total_cost": total_cost_dollars, - } - - -# Open AI Cookbook: Handling Function Calls with Reasoning Models -# https://cookbook.openai.com/examples/reasoning_function_calls -def invoke_functions_from_response( - response, tool_mapping: dict[str, Callable] -) -> list[dict]: - """Extract all function calls from the response, look up the corresponding tool function(s) and execute them. - (This would be a good place to handle asynchroneous tool calls, or ones that take a while to execute.) - This returns a list of messages to be added to the conversation history. - - Parameters - ---------- - response : OpenAI Response - The response object from OpenAI containing output items that may include function calls - tool_mapping : dict[str, Callable] - A dictionary mapping function names (as strings) to their corresponding Python functions. - Keys should match the function names defined in the tools schema. 
- - Returns - ------- - list[dict] - List of function call output messages formatted for the OpenAI conversation. - Each message contains: - - type: "function_call_output" - - call_id: The unique identifier for the function call - - output: The result returned by the executed function (string or error message) - """ - intermediate_messages = [] - for response_item in response.output: - if response_item.type == "function_call": - target_tool = tool_mapping.get(response_item.name) - if target_tool: - try: - arguments = json.loads(response_item.arguments) - logger.info( - f"Invoking tool: {response_item.name} with arguments: {arguments}" - ) - tool_output = target_tool(**arguments) - logger.info(f"Tool {response_item.name} completed successfully") - except Exception as e: - msg = f"Error executing function call: {response_item.name}: {e}" - tool_output = msg - logger.error(msg, exc_info=True) - else: - msg = f"ERROR - No tool registered for function call: {response_item.name}" - tool_output = msg - logger.error(msg) - intermediate_messages.append( - { - "type": "function_call_output", - "call_id": response_item.call_id, - "output": tool_output, - } - ) - elif response_item.type == "reasoning": - logger.info(f"Reasoning step: {response_item.summary}") - return intermediate_messages - @method_decorator(csrf_exempt, name="dispatch") class Assistant(APIView): @@ -133,208 +36,14 @@ class Assistant(APIView): def post(self, request): try: user = request.user - - client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - - TOOL_DESCRIPTION = """ - Search the user's uploaded documents for information relevant to answering their question. - Call this function when you need to find specific information from the user's documents - to provide an accurate, citation-backed response. Always search before answering questions - about document content. - """ - - TOOL_PROPERTY_DESCRIPTION = """ - A specific search query to find relevant information in the user's documents. 
- Use keywords, phrases, or questions related to what the user is asking about. - Be specific rather than generic - use terms that would appear in the relevant documents. - """ - - tools = [ - { - "type": "function", - "name": "search_documents", - "description": TOOL_DESCRIPTION, - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": TOOL_PROPERTY_DESCRIPTION, - } - }, - "required": ["query"], - }, - } - ] - - def search_documents(query: str, user=user) -> str: - """ - Search through user's uploaded documents using semantic similarity. - - This function performs vector similarity search against the user's document corpus - and returns formatted results with context information for the LLM to use. - - Parameters - ---------- - query : str - The search query string - user : User - The authenticated user whose documents to search - - Returns - ------- - str - Formatted search results containing document excerpts with metadata - - Raises - ------ - Exception - If embedding search fails - """ - - try: - embeddings_results = get_closest_embeddings( - user=user, message_data=query.strip() - ) - embeddings_results = convert_uuids(embeddings_results) - - if not embeddings_results: - return "No relevant documents found for your query. Please try different search terms or upload documents first." - - # Format results with clear structure and metadata - prompt_texts = [ - f"[Document {i + 1} - File: {obj['file_id']}, Name: {obj['name']}, Page: {obj['page_number']}, Chunk: {obj['chunk_number']}, Similarity: {1 - obj['distance']:.3f}]\n{obj['text']}\n[End Document {i + 1}]" - for i, obj in enumerate(embeddings_results) - ] - - return "\n\n".join(prompt_texts) - - except Exception as e: - return f"Error searching documents: {str(e)}. Please try again if the issue persists." 
- - INSTRUCTIONS = """ - You are an AI assistant that helps users find and understand information about bipolar disorder - from your internal library of bipolar disorder research sources using semantic search. - - IMPORTANT CONTEXT: - - You have access to a library of sources that the user CANNOT see - - The user did not upload these sources and doesn't know about them - - You must explain what information exists in your sources and provide clear references - - TOPIC RESTRICTIONS: - When a prompt is received that is unrelated to bipolar disorder, mental health treatment, - or psychiatric medications, respond by saying you are limited to bipolar-specific conversations. - - SEMANTIC SEARCH STRATEGY: - - Always perform semantic search using the search_documents function when users ask questions - - Use conceptually related terms and synonyms, not just exact keyword matches - - Search for the meaning and context of the user's question, not just literal words - - Consider medical terminology, lay terms, and related conditions when searching - - FUNCTION USAGE: - - When a user asks about information that might be in your source library, ALWAYS use the search_documents function first - - Perform semantic searches using concepts, symptoms, treatments, and related terms from the user's question - - Only provide answers based on information found through your source searches - - RESPONSE FORMAT: - After gathering information through semantic searches, provide responses that: - 1. Answer the user's question directly using only the found information - 2. Structure responses with clear sections and paragraphs - 3. Explain what information you found in your sources and provide context - 4. Include citations using this exact format: [Name {name}, Page {page_number}] - 5. Only cite information that directly supports your statements - - If no relevant information is found in your source library, clearly state that the information - is not available in your current sources. 
- - REMEMBER: You are working with an internal library of bipolar disorder sources that the user - cannot see. Always search these sources first, explain what you found, and provide proper citations. - """ - - MODEL_DEFAULTS = { - "instructions": INSTRUCTIONS, - "model": "gpt-5-nano", # 400,000 token context window - # A summary of the reasoning performed by the model. This can be useful for debugging and understanding the model's reasoning process. - "reasoning": {"effort": "low", "summary": None}, - "tools": tools, - } - - # We fetch a response and then kick off a loop to handle the response - + message = request.data.get("message", None) previous_response_id = request.data.get("previous_response_id", None) - - # Track total duration and cost metrics - start_time = time.time() - total_token_usage = {"input_tokens": 0, "output_tokens": 0} - - if not previous_response_id: - response = client.responses.create( - input=[ - {"type": "message", "role": "user", "content": str(message)} - ], - **MODEL_DEFAULTS, - ) - else: - response = client.responses.create( - input=[ - {"type": "message", "role": "user", "content": str(message)} - ], - previous_response_id=str(previous_response_id), - **MODEL_DEFAULTS, - ) - - # Accumulate token usage from initial response - if hasattr(response, "usage"): - total_token_usage["input_tokens"] += getattr( - response.usage, "input_tokens", 0 - ) - total_token_usage["output_tokens"] += getattr( - response.usage, "output_tokens", 0 - ) - - # Open AI Cookbook: Handling Function Calls with Reasoning Models - # https://cookbook.openai.com/examples/reasoning_function_calls - while True: - # Mapping of the tool names we tell the model about and the functions that implement them - function_responses = invoke_functions_from_response( - response, tool_mapping={"search_documents": search_documents} - ) - if len(function_responses) == 0: # We're done reasoning - logger.info("Reasoning completed") - final_response_output_text = response.output_text - 
final_response_id = response.id - logger.info(f"Final response: {final_response_output_text}") - break - else: - logger.info("More reasoning required, continuing...") - response = client.responses.create( - input=function_responses, - previous_response_id=response.id, - **MODEL_DEFAULTS, - ) - # Accumulate token usage from reasoning iterations - if hasattr(response, "usage"): - total_token_usage["input_tokens"] += getattr( - response.usage, "input_tokens", 0 - ) - total_token_usage["output_tokens"] += getattr( - response.usage, "output_tokens", 0 - ) - - # Calculate total duration and cost metrics - total_duration = time.time() - start_time - cost_metrics = calculate_cost_metrics( - total_token_usage, GPT_5_NANO_PRICING_DOLLARS_PER_MILLION_TOKENS - ) - - # Log cost and duration metrics - logger.info( - f"Request completed: " - f"Duration: {total_duration:.2f}s, " - f"Input tokens: {total_token_usage['input_tokens']}, " - f"Output tokens: {total_token_usage['output_tokens']}, " - f"Total cost: ${cost_metrics['total_cost']:.6f}" + + final_response_output_text, final_response_id = run_assistant( + message=message, + user=user, + previous_response_id=previous_response_id, ) return Response(