diff --git a/packages/optimization/src/ldai_optimization/client.py b/packages/optimization/src/ldai_optimization/client.py index 3ee2973..63fdbaa 100644 --- a/packages/optimization/src/ldai_optimization/client.py +++ b/packages/optimization/src/ldai_optimization/client.py @@ -19,6 +19,7 @@ AIJudgeCallConfig, GroundTruthOptimizationOptions, GroundTruthSample, + HandleJudgeCall, JudgeResult, OptimizationContext, OptimizationFromConfigOptions, @@ -228,6 +229,11 @@ def _create_optimization_context( iteration=iteration, ) + @property + def _judge_call(self) -> HandleJudgeCall: + """Return the judge callable, falling back to handle_agent_call when not set.""" + return self._options.handle_judge_call or self._options.handle_agent_call + def _safe_status_update( self, status: Literal[ @@ -569,10 +575,9 @@ async def _evaluate_config_judge( LDMessage(role="user", content=judge_user_input), ] - # Collect model parameters from the judge config, separating out any existing tools - model_name = ( - judge_config.model.name if judge_config.model else self._options.judge_model - ) + # Always use the global judge_model; model parameters (temperature, etc.) from + # the judge flag are still forwarded, but the model name is never overridden. 
+ model_name = self._options.judge_model model_params: Dict[str, Any] = {} tools: List[ToolDefinition] = [] if judge_config.model and judge_config.model._parameters: @@ -615,8 +620,8 @@ async def _evaluate_config_judge( ) _judge_start = time.monotonic() - result = self._options.handle_judge_call( - judge_key, judge_call_config, judge_ctx + result = self._judge_call( + judge_key, judge_call_config, judge_ctx, True ) judge_response: OptimizationResponse = await await_if_needed(result) judge_duration_ms = (time.monotonic() - _judge_start) * 1000 @@ -776,8 +781,8 @@ async def _evaluate_acceptance_judge( ) _judge_start = time.monotonic() - result = self._options.handle_judge_call( - judge_key, judge_call_config, judge_ctx + result = self._judge_call( + judge_key, judge_call_config, judge_ctx, True ) judge_response: OptimizationResponse = await await_if_needed(result) judge_duration_ms = (time.monotonic() - _judge_start) * 1000 @@ -1318,6 +1323,7 @@ async def _generate_new_variation( self._agent_key, agent_config, variation_ctx, + False, ) variation_response: OptimizationResponse = await await_if_needed(result) response_str = variation_response.output @@ -1717,6 +1723,7 @@ async def _execute_agent_turn( self._agent_key, self._build_agent_config_for_context(optimize_context), optimize_context, + False, ) agent_response: OptimizationResponse = await await_if_needed(result) agent_duration_ms = (time.monotonic() - _agent_start) * 1000 diff --git a/packages/optimization/src/ldai_optimization/dataclasses.py b/packages/optimization/src/ldai_optimization/dataclasses.py index f4d2f91..ff151be 100644 --- a/packages/optimization/src/ldai_optimization/dataclasses.py +++ b/packages/optimization/src/ldai_optimization/dataclasses.py @@ -119,6 +119,7 @@ async def handle_llm_call( key: str, config: LLMCallConfig, context: LLMCallContext, + is_evaluation: bool, ) -> OptimizationResponse: model_name = config.model.name if config.model else "gpt-4o" instructions = config.instructions or "" 
@@ -132,9 +133,12 @@ async def handle_llm_call( ) """ - key: str - model: Optional[ModelConfig] - instructions: Optional[str] + @property + def key(self) -> str: ... + @property + def model(self) -> Optional[ModelConfig]: ... + @property + def instructions(self) -> Optional[str]: ... class LLMCallContext(Protocol): @@ -144,8 +148,10 @@ class LLMCallContext(Protocol): ``handle_agent_call`` and ``handle_judge_call``. """ - user_input: Optional[str] - current_variables: Dict[str, Any] + @property + def user_input(self) -> Optional[str]: ... + @property + def current_variables(self) -> Dict[str, Any]: ... @dataclass @@ -282,12 +288,12 @@ class OptimizationJudgeContext: # the concrete types (AIAgentConfig / AIJudgeCallConfig) continue to work # because those types structurally satisfy the Protocols. HandleAgentCall = Union[ - Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse], - Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]], + Callable[[str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse], + Callable[[str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]], ] HandleJudgeCall = Union[ - Callable[[str, LLMCallConfig, LLMCallContext], OptimizationResponse], - Callable[[str, LLMCallConfig, LLMCallContext], Awaitable[OptimizationResponse]], + Callable[[str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse], + Callable[[str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]], ] _StatusLiteral = Literal[ @@ -315,7 +321,8 @@ class OptimizationOptions: ] # choices of interpolated variables to be chosen at random per turn, 1 min required # Actual agent/completion (judge) calls - Required handle_agent_call: HandleAgentCall - handle_judge_call: HandleJudgeCall + # Optional; falls back to handle_agent_call when omitted (both share the same signature) + handle_judge_call: Optional[HandleJudgeCall] = None # Criteria for pass/fail - Optional user_input_options: 
Optional[List[str]] = ( None # optional list of user input messages to randomly select from @@ -401,7 +408,8 @@ class GroundTruthOptimizationOptions: model_choices: List[str] judge_model: str handle_agent_call: HandleAgentCall - handle_judge_call: HandleJudgeCall + # Optional; falls back to handle_agent_call when omitted (both share the same signature) + handle_judge_call: Optional[HandleJudgeCall] = None judges: Optional[Dict[str, OptimizationJudge]] = None on_turn: Optional[Callable[[OptimizationContext], bool]] = None on_sample_result: Optional[Callable[[OptimizationContext], None]] = None @@ -461,7 +469,8 @@ class OptimizationFromConfigOptions: project_key: str handle_agent_call: HandleAgentCall - handle_judge_call: HandleJudgeCall + # Optional; falls back to handle_agent_call when omitted (both share the same signature) + handle_judge_call: Optional[HandleJudgeCall] = None on_turn: Optional[Callable[["OptimizationContext"], bool]] = None on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None diff --git a/packages/optimization/src/ldai_optimization/util.py b/packages/optimization/src/ldai_optimization/util.py index 0f901d5..7cdf300 100644 --- a/packages/optimization/src/ldai_optimization/util.py +++ b/packages/optimization/src/ldai_optimization/util.py @@ -4,7 +4,7 @@ import json import logging import re -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, Union from ldai_optimization.dataclasses import ToolDefinition @@ -156,18 +156,19 @@ def restore_variable_placeholders( return text, warnings -async def await_if_needed( - result: Union[str, Awaitable[str]] -) -> str: +_T = TypeVar("_T") + + +async def await_if_needed(result: Union[_T, Awaitable[_T]]) -> _T: """ Handle both sync and async callable results. 
- :param result: Either a string or an awaitable that returns a string - :return: The string result + :param result: Either a value or an awaitable that returns a value + :return: The resolved value """ - if isinstance(result, str): - return result - return await result + if inspect.isawaitable(result): + return await result # type: ignore[return-value] + return result # type: ignore[return-value] def create_evaluation_tool() -> ToolDefinition: diff --git a/packages/optimization/tests/test_client.py b/packages/optimization/tests/test_client.py index 39d7514..a478230 100644 --- a/packages/optimization/tests/test_client.py +++ b/packages/optimization/tests/test_client.py @@ -395,7 +395,7 @@ async def test_handle_judge_call_receives_correct_key_and_config(self): user_input="What time is it?", ) call_args = self.handle_judge_call.call_args - key, config, ctx = call_args.args + key, config, ctx, _ = call_args.args assert key == "relevance" assert isinstance(config, AIJudgeCallConfig) assert isinstance(ctx, OptimizationJudgeContext) @@ -412,7 +412,7 @@ async def test_messages_has_system_and_user_turns(self): reasoning_history="", user_input="What colour is the sky?", ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args roles = [m.role for m in config.messages] assert roles == ["system", "user"] @@ -428,7 +428,7 @@ async def test_messages_system_content_matches_instructions(self): reasoning_history="", user_input="Is Paris in France?", ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args system_msg = next(m for m in config.messages if m.role == "system") assert system_msg.content == config.instructions @@ -444,7 +444,7 @@ async def test_messages_user_content_matches_context_user_input(self): reasoning_history="", user_input="Capital of France?", ) - _, config, ctx = self.handle_judge_call.call_args.args + _, config, ctx, _ = 
self.handle_judge_call.call_args.args user_msg = next(m for m in config.messages if m.role == "user") assert user_msg.content == ctx.user_input @@ -460,7 +460,7 @@ async def test_acceptance_statement_in_instructions(self): user_input="Tell me about Paris.", ) call_args = self.handle_judge_call.call_args - _, config, _ = call_args.args + _, config, _, _ = call_args.args assert statement in config.instructions async def test_no_structured_output_tool_in_judge_config(self): @@ -475,7 +475,7 @@ async def test_no_structured_output_tool_in_judge_config(self): user_input="Is Paris in France?", ) call_args = self.handle_judge_call.call_args - _, config, _ = call_args.args + _, config, _, _ = call_args.args tools = config.model.get_parameter("tools") or [] assert tools == [] @@ -494,7 +494,7 @@ async def test_agent_tools_included_in_config_tools(self): agent_tools=[agent_tool], ) call_args = self.handle_judge_call.call_args - _, config, _ = call_args.args + _, config, _, _ = call_args.args tools = config.model.get_parameter("tools") or [] tool_names = [t["name"] for t in tools] assert tool_names == ["lookup"] @@ -512,7 +512,7 @@ async def test_variables_in_context(self): variables=variables, ) call_args = self.handle_judge_call.call_args - _, _, ctx = call_args.args + _, _, ctx, _ = call_args.args assert ctx.current_variables == variables async def test_duration_context_added_to_instructions_when_latency_keyword_present(self): @@ -531,7 +531,7 @@ async def test_duration_context_added_to_instructions_when_latency_keyword_prese user_input="Tell me something.", agent_duration_ms=1500.0, ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args assert "1500ms" in config.instructions assert "mention the duration" in config.instructions @@ -561,7 +561,7 @@ async def test_duration_context_includes_baseline_comparison_when_history_presen user_input="Tell me something.", agent_duration_ms=1500.0, ) - _, config, _ = 
self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args assert "1500ms" in config.instructions assert "2000ms" in config.instructions assert "faster" in config.instructions @@ -592,7 +592,7 @@ async def test_duration_context_says_slower_when_candidate_is_slower(self): user_input="Tell me something.", agent_duration_ms=1800.0, ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args assert "slower" in config.instructions async def test_duration_context_not_added_when_no_latency_keyword(self): @@ -610,7 +610,7 @@ async def test_duration_context_not_added_when_no_latency_keyword(self): user_input="Capital of France?", agent_duration_ms=2000.0, ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args assert "2000ms" not in config.instructions assert "duration" not in config.instructions.lower() or "acceptance" in config.instructions.lower() @@ -629,7 +629,7 @@ async def test_duration_context_not_added_when_agent_duration_ms_is_none(self): user_input="Tell me something.", agent_duration_ms=None, ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args assert "mention the duration" not in config.instructions async def test_returns_zero_score_on_missing_acceptance_statement(self): @@ -698,7 +698,7 @@ async def test_calls_handle_judge_call_with_correct_config_type(self): user_input="What is X?", ) call_args = self.handle_judge_call.call_args - key, config, ctx = call_args.args + key, config, ctx, _ = call_args.args assert key == "quality" assert isinstance(config, AIJudgeCallConfig) assert "You are an evaluator." 
in config.instructions @@ -715,7 +715,7 @@ async def test_messages_has_system_and_user_turns(self): reasoning_history="", user_input="What is X?", ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args roles = [m.role for m in config.messages] assert roles == ["system", "user"] @@ -730,7 +730,7 @@ async def test_messages_system_content_matches_instructions(self): reasoning_history="", user_input="What is X?", ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args system_msg = next(m for m in config.messages if m.role == "system") assert system_msg.content == config.instructions @@ -745,7 +745,7 @@ async def test_messages_user_content_matches_context_user_input(self): reasoning_history="", user_input="What is X?", ) - _, config, ctx = self.handle_judge_call.call_args.args + _, config, ctx, _ = self.handle_judge_call.call_args.args user_msg = next(m for m in config.messages if m.role == "user") assert user_msg.content == ctx.user_input @@ -760,7 +760,7 @@ async def test_messages_user_content_contains_ld_user_message(self): reasoning_history="", user_input="What is X?", ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args user_msg = next(m for m in config.messages if m.role == "user") assert "Evaluate this response." 
in user_msg.content @@ -830,7 +830,7 @@ async def test_agent_tools_included_without_evaluation_tool(self): user_input="Q?", agent_tools=[agent_tool], ) - _, config, _ = self.handle_judge_call.call_args.args + _, config, _, _ = self.handle_judge_call.call_args.args tools = config.model.get_parameter("tools") or [] names = [t["name"] for t in tools] assert names == ["search"] @@ -872,7 +872,7 @@ async def test_calls_handle_agent_call_with_config_and_context(self): ctx = self._make_context() await self.client._execute_agent_turn(ctx, iteration=1) self.handle_agent_call.assert_called_once() - key, config, passed_ctx = self.handle_agent_call.call_args.args + key, config, passed_ctx, _ = self.handle_agent_call.call_args.args assert key == "test-agent" assert isinstance(config, AIAgentConfig) assert passed_ctx is ctx @@ -891,7 +891,7 @@ async def test_judge_scores_stored_in_returned_context(self): async def test_variables_interpolated_into_agent_config_instructions(self): ctx = self._make_context() await self.client._execute_agent_turn(ctx, iteration=1) - _, config, _ = self.handle_agent_call.call_args.args + _, config, _, _ = self.handle_agent_call.call_args.args assert "{{language}}" not in config.instructions assert "English" in config.instructions @@ -933,14 +933,14 @@ async def test_updates_current_model(self): async def test_no_structured_output_tool_in_variation_config(self): """Variation turn must not inject the structured-output tool — prompts use plain JSON.""" await self.client._generate_new_variation(iteration=1, variables={}) - _, config, _ = self.handle_agent_call.call_args.args + _, config, _, _ = self.handle_agent_call.call_args.args tools = config.model.get_parameter("tools") or [] assert tools == [] - async def test_variation_call_uses_three_arg_signature(self): - """handle_agent_call receives exactly (key, config, context) — no tools arg.""" + async def test_variation_call_uses_four_arg_signature(self): + """handle_agent_call receives exactly (key, config, context, is_evaluation) — no tools arg.""" await self.client._generate_new_variation(iteration=1, variables={}) - assert
len(self.handle_agent_call.call_args.args) == 3 + assert len(self.handle_agent_call.call_args.args) == 4 async def test_model_not_updated_when_not_in_model_choices(self): bad_response = json.dumps({ @@ -1220,7 +1220,7 @@ async def test_validation_runs_additional_agent_calls(self): """With 8 variable choices, validation runs 2 extra agent calls after the initial pass.""" call_count = [0] - async def counting_agent(key, config, ctx): + async def counting_agent(key, config, ctx, is_evaluation=False): call_count[0] += 1 return OptimizationResponse(output="answer") @@ -1265,7 +1265,7 @@ async def test_validation_does_not_reuse_passing_turn_variable(self): """The variable set used in the initial passing turn must not appear in validation.""" seen_variables = [] - async def capture_agent(key, config, ctx): + async def capture_agent(key, config, ctx, is_evaluation=False): seen_variables.append(ctx.current_variables) return OptimizationResponse(output="answer") @@ -1285,7 +1285,7 @@ async def test_validation_uses_user_input_options_as_pool_when_provided(self): """When user_input_options is provided, validation samples from that pool.""" seen_inputs = [] - async def capture_agent(key, config, ctx): + async def capture_agent(key, config, ctx, is_evaluation=False): seen_inputs.append(ctx.user_input) return OptimizationResponse(output="answer") @@ -1308,7 +1308,7 @@ async def test_pool_exhaustion_caps_validation_at_available_distinct_items(self) """When fewer distinct items remain than validation_count, all available ones are used.""" call_count = [0] - async def counting_agent(key, config, ctx): + async def counting_agent(key, config, ctx, is_evaluation=False): call_count[0] += 1 return OptimizationResponse(output="answer") @@ -1324,7 +1324,7 @@ async def test_single_variable_choice_falls_back_to_repeated_draw(self): """With only 1 variable choice validation still runs 1 sample (repeated draw).""" call_count = [0] - async def counting_agent(key, config, ctx): + async def 
counting_agent(key, config, ctx, is_evaluation=False): call_count[0] += 1 return OptimizationResponse(output="answer") @@ -2780,7 +2780,7 @@ def bad_callback(ctx): async def test_variables_from_samples_used_per_evaluation(self): client = self._make_client() received_contexts = [] - async def capture_agent_call(key, config, ctx): + async def capture_agent_call(key, config, ctx, is_evaluation=False): received_contexts.append(ctx) return OptimizationResponse(output="response") @@ -2802,7 +2802,7 @@ async def test_model_falls_back_to_first_model_choice_when_agent_config_has_no_m client = _make_client(mock_ldai) observed_models = [] - async def capture(key, config, ctx): + async def capture(key, config, ctx, is_evaluation=False): observed_models.append(config.model.name if config.model else None) return OptimizationResponse(output="answer") @@ -2844,7 +2844,7 @@ def setup_method(self): async def test_expected_response_included_in_acceptance_judge_user_message(self): captured_configs = [] - async def capture_judge_call(key, config, ctx): + async def capture_judge_call(key, config, ctx, is_evaluation=False): captured_configs.append(config) return OptimizationResponse(output=JUDGE_PASS_RESPONSE) @@ -2866,7 +2866,7 @@ async def capture_judge_call(key, config, ctx): async def test_expected_response_in_acceptance_judge_user_message(self): captured_configs = [] - async def capture_judge_call(key, config, ctx): + async def capture_judge_call(key, config, ctx, is_evaluation=False): captured_configs.append(config) return OptimizationResponse(output=JUDGE_PASS_RESPONSE) @@ -2891,7 +2891,7 @@ async def capture_judge_call(key, config, ctx): async def test_no_expected_response_leaves_judge_messages_unchanged(self): captured_configs = [] - async def capture_judge_call(key, config, ctx): + async def capture_judge_call(key, config, ctx, is_evaluation=False): captured_configs.append(config) return OptimizationResponse(output=JUDGE_PASS_RESPONSE) diff --git 
a/packages/optimization/tests/test_ld_api_client.py b/packages/optimization/tests/test_ld_api_client.py index da79025..ec0248e 100644 --- a/packages/optimization/tests/test_ld_api_client.py +++ b/packages/optimization/tests/test_ld_api_client.py @@ -11,9 +11,9 @@ from ldai_optimization.ld_api_client import ( AgentOptimizationConfig, + AgentOptimizationResultPost as OptimizationResultPayload, LDApiClient, LDApiError, - OptimizationResultPayload, _parse_agent_optimization, )