fix: handle direct model answers in ReACT loop #763
base: main
@@ -7,6 +7,10 @@

```python
history tracking. Raises ``RuntimeError`` if the loop ends without a final answer.
"""

from collections.abc import Awaitable, Callable

import pydantic

# from PIL import Image as PILImage
from mellea.backends.model_options import ModelOption
from mellea.core.backend import Backend, BaseModelSubclass
```
@@ -24,6 +28,14 @@

```python
from mellea.stdlib.context import ChatContext


class TrueOrFalse(pydantic.BaseModel):
    """Response indicating whether the ReACT agent has completed its task."""

    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def react(
    goal: str,
    context: ChatContext,
```

**Contributor** (on `TrueOrFalse`): We have `TrueOrFalse` in three places in this PR (since one is a test and one is an example, I get there's justification), but this particular one isn't used? Do you expect to use it in the framework?
@@ -36,6 +48,19 @@ async def react(

```python
    model_options: dict | None = None,
    tools: list[AbstractMelleaTool] | None,
    loop_budget: int = 10,
    answer_check: Callable[
        [
            str,
            ComputedModelOutputThunk[str],
            ChatContext,
            Backend,
            dict | None,
            int,
            int,
        ],
        Awaitable[bool],
    ]
    | None = None,
) -> tuple[ComputedModelOutputThunk[str], ChatContext]:
    """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools.
```

**Contributor** (on `answer_check`): It's good we have the flexibility, though it's a lot of positional parameters; that may be rationale for a new type (or Protocol). But we do neither elsewhere, given that callables often have only 2 params. Technically it's not an issue: the code works. It's more about usability, given it's an external API. Worth thinking about, but I wouldn't say it's a blocker. Essential will be to get the docstring right at least.
@@ -47,6 +72,11 @@ async def react(

```python
        model_options: additional model options, which will upsert into the model/backend's defaults.
        tools: the list of tools to use
        loop_budget: the number of steps allowed; use -1 for unlimited
        answer_check: optional callable to determine if the agent has completed its task.
            Called every iteration when no tool calls are made and step.value exists (if provided).
            Receives (goal, step, context, backend, model_options, turn_num, loop_budget).
            Returns bool indicating if the task is complete.
            If None, no answer check is performed (loop continues until finalizer or budget exhausted).

    Returns:
        A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`.
```
@@ -106,9 +136,31 @@ async def react(

```python
            if tool_res.name == MELLEA_FINALIZER_TOOL:
                is_final = True

        # Check if the agent has completed its task (runs every iteration if
        # answer_check is provided and there's a value). The answer_check
        # function can decide when to actually check based on turn_num and loop_budget.
        elif not is_final and answer_check and step.value:
            have_answer = await answer_check(
                goal, step, context, backend, model_options, turn_num, loop_budget
            )

            if have_answer:
                # Create a synthetic finalizer tool response to be consistent with the normal loop
                finalizer_response = ToolMessage(
                    role="tool",
                    content=step.value or "",
                    tool_output=step.value or "",
                    name=MELLEA_FINALIZER_TOOL,
                    args={},
                    tool=None,  # type: ignore
                )
                tool_responses = [finalizer_response]
                context = context.add(finalizer_response)
                is_final = True

        if is_final:
            assert len(tool_responses) == 1, "multiple tools were called with 'final'"

        # Apply format if requested
        if format is not None:
            step, next_context = await mfuncs.aact(
                action=ReactThought(),
```
@@ -0,0 +1,108 @@ (new test file)

**Contributor**: This file probably should be in test/stdlib/frameworks to match the source structure?

```python
"""Test ReACT framework handling of direct answers without tool calls."""

import pydantic
import pytest

from mellea.backends.tools import tool
from mellea.stdlib import functional as mfuncs
from mellea.stdlib.context import ChatContext
from mellea.stdlib.frameworks.react import react
from mellea.stdlib.session import start_session


class TrueOrFalse(pydantic.BaseModel):
    """Response indicating whether the ReACT agent has completed its task."""

    answer: bool = pydantic.Field(
        description="True if you have enough information to answer the user's question, False if you need more tool calls"
    )


async def last_loop_completion_check(
    goal, step, context, backend, model_options, turn_num, loop_budget
):
    """Completion check that asks the model if it has the answer on the last iteration.

    Note: step.value is guaranteed to exist when this is called.
    """
    # Only check on the last iteration (and not for an unlimited budget)
    if loop_budget == -1 or turn_num < loop_budget:
        return False

    content = mfuncs.chat(
        content=f"Do you know the answer to the user's original query ({goal})? If so, respond with True. If you need to take more actions, then respond False.",
        context=context,
        backend=backend,
        format=TrueOrFalse,
    )[0].content
    have_answer = TrueOrFalse.model_validate_json(content).answer
    return have_answer
```
```python
@pytest.mark.ollama
@pytest.mark.llm
async def test_react_direct_answer_without_tools():
    """Test that ReACT handles direct answers when model doesn't call tools.

    This tests the case where the model provides a direct answer in step.value
    without making any tool calls. The fix ensures the loop terminates properly
    instead of continuing until loop_budget is exhausted.
    """
    m = start_session()

    # Ask a simple question that doesn't require tools
    # The model should provide a direct answer without calling any tools
    out, _ = await react(
        goal="What is 2 + 2?",
        context=ChatContext(),
        backend=m.backend,
        tools=[],  # No tools provided
        loop_budget=3,  # Should complete in 1 iteration, not exhaust budget
        answer_check=last_loop_completion_check,
    )

    # Verify we got an answer
    assert out.value is not None
    assert len(out.value) > 0

    # The answer should contain "4" or "four"
    answer_lower = out.value.lower()
    assert "4" in answer_lower or "four" in answer_lower
```

**Contributor** (on the markers): Suggested change: `@pytest.mark.llm` → `@pytest.mark.e2e`. We updated our markers (llm->e2e).

**Contributor**: this change is needed for tests

**Contributor** (on `test_react_direct_answer_without_tools`): The control-flow change in `react` is worth covering without a live model. We've been adding more unit/integration tests to reduce the need to run e2e locally with slower times/overhead, so having some mocking like this adds coverage for the relevant code paths. Suggest adding:

```python
@pytest.mark.asyncio
async def test_react_answer_check_terminates_on_direct_response():
    """answer_check returning True on a no-tool-call turn exits the loop."""
    backend = ScriptedBackend([_ScriptedTurn(value="42")])

    async def always_done(goal, step, ctx, backend, opts, turn, budget):
        return True

    result, _ = await react(
        goal="answer",
        context=ChatContext(),
        backend=backend,
        tools=None,
        loop_budget=5,
        answer_check=always_done,
    )
    assert result.value == "42"
```
```python
@pytest.mark.ollama
@pytest.mark.llm
async def test_react_direct_answer_with_unused_tools():
    """Test that ReACT handles direct answers even when tools are available.

    This tests the case where tools are provided but the model chooses to
    answer directly without using them.
    """
    m = start_session()

    # Create a dummy tool that won't be needed
    @tool
    def search_web(query: str) -> str:
        """Search the web for information."""
        return "Search results"

    # Ask a question that doesn't need the tool
    out, _ = await react(
        goal="What is the capital of France?",
        context=ChatContext(),
        backend=m.backend,
        tools=[search_web],
        loop_budget=3,
        answer_check=last_loop_completion_check,
    )

    # Verify we got an answer
    assert out.value is not None
    assert len(out.value) > 0

    # The answer should mention Paris
    answer_lower = out.value.lower()
    assert "paris" in answer_lower


# Made with Bob
```
**Contributor** (on `last_loop_completion_check`): We're calling a sync method from async code. Technically it works in this sample, as there's nothing else happening. But if a user copied it and used it when running multiple react loops, or within something like FastAPI where there's a shared event loop, then this callback would freeze out other handlers while the LLM call is running. Can we await the async variant instead, i.e. `await mfuncs.achat(....)`? Note also the method signature is slightly different (a tuple vs a list). I think this gives a better example, and as these are examples where we show appropriate usage, I do think it's worth changing.
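The reviewer's concern generalizes beyond Mellea: any blocking call inside a coroutine stalls every task sharing the event loop, whereas an awaited variant (such as the suggested `await mfuncs.achat(...)`) yields while waiting. A stdlib-only demonstration, where `time.sleep` stands in for the synchronous LLM call:

```python
import asyncio
import time


async def heartbeat(ticks: list) -> None:
    # A second handler sharing the event loop; it should tick every ~20 ms.
    for _ in range(10):
        ticks.append(time.monotonic())
        await asyncio.sleep(0.02)


async def demo(check) -> float:
    """Run `check` alongside the heartbeat; return the worst gap between ticks."""
    ticks: list = []
    task = asyncio.create_task(heartbeat(ticks))
    await asyncio.sleep(0.05)  # let a few ticks land first
    await check()
    await task
    return max(b - a for a, b in zip(ticks, ticks[1:]))


async def blocking_check() -> None:
    time.sleep(0.2)  # stands in for a synchronous LLM call: freezes the loop


async def yielding_check() -> None:
    await asyncio.sleep(0.2)  # stands in for an awaited async variant


print(f"blocking: worst heartbeat gap {asyncio.run(demo(blocking_check)):.3f}s")
print(f"yielding: worst heartbeat gap {asyncio.run(demo(yielding_check)):.3f}s")
```

With the blocking check the heartbeat stalls for the full 0.2 s; with the yielding check its gaps stay near 0.02 s, which is exactly the FastAPI scenario the reviewer describes.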