Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions packages/optimization/src/ldai_optimization/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AIJudgeCallConfig,
GroundTruthOptimizationOptions,
GroundTruthSample,
HandleJudgeCall,
JudgeResult,
OptimizationContext,
OptimizationFromConfigOptions,
Expand Down Expand Up @@ -228,6 +229,11 @@ def _create_optimization_context(
iteration=iteration,
)

@property
def _judge_call(self) -> HandleJudgeCall:
    """Callable to use for judge invocations.

    Prefers the explicitly configured ``handle_judge_call``; when that
    option is unset, the agent handler doubles as the judge handler
    (both share the same call signature).
    """
    configured = self._options.handle_judge_call
    if configured:
        return configured
    return self._options.handle_agent_call

def _safe_status_update(
self,
status: Literal[
Expand Down Expand Up @@ -569,10 +575,9 @@ async def _evaluate_config_judge(
LDMessage(role="user", content=judge_user_input),
]

# Collect model parameters from the judge config, separating out any existing tools
model_name = (
judge_config.model.name if judge_config.model else self._options.judge_model
)
# Always use the global judge_model; model parameters (temperature, etc.) from
# the judge flag are still forwarded, but the model name is never overridden.
model_name = self._options.judge_model
model_params: Dict[str, Any] = {}
tools: List[ToolDefinition] = []
if judge_config.model and judge_config.model._parameters:
Expand Down Expand Up @@ -615,8 +620,8 @@ async def _evaluate_config_judge(
)

_judge_start = time.monotonic()
result = self._options.handle_judge_call(
judge_key, judge_call_config, judge_ctx
result = self._judge_call(
judge_key, judge_call_config, judge_ctx, True
)
judge_response: OptimizationResponse = await await_if_needed(result)
judge_duration_ms = (time.monotonic() - _judge_start) * 1000
Expand Down Expand Up @@ -776,8 +781,8 @@ async def _evaluate_acceptance_judge(
)

_judge_start = time.monotonic()
result = self._options.handle_judge_call(
judge_key, judge_call_config, judge_ctx
result = self._judge_call(
judge_key, judge_call_config, judge_ctx, True
)
judge_response: OptimizationResponse = await await_if_needed(result)
judge_duration_ms = (time.monotonic() - _judge_start) * 1000
Expand Down Expand Up @@ -1318,6 +1323,7 @@ async def _generate_new_variation(
self._agent_key,
agent_config,
variation_ctx,
False,
)
variation_response: OptimizationResponse = await await_if_needed(result)
response_str = variation_response.output
Expand Down Expand Up @@ -1717,6 +1723,7 @@ async def _execute_agent_turn(
self._agent_key,
self._build_agent_config_for_context(optimize_context),
optimize_context,
False,
)
agent_response: OptimizationResponse = await await_if_needed(result)
agent_duration_ms = (time.monotonic() - _agent_start) * 1000
Expand Down
17 changes: 10 additions & 7 deletions packages/optimization/src/ldai_optimization/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,12 @@ class OptimizationJudgeContext:
# the concrete types (AIAgentConfig / AIJudgeCallConfig) continue to work
# because those types structurally satisfy the Protocols.
# Both handler aliases accept (key, call_config, call_context, flag) and may be
# implemented either synchronously or as a coroutine.  The trailing bool is
# passed as True at judge call sites and False at agent call sites —
# NOTE(review): presumably an "is judge call" marker; confirm the intended
# semantics against the client's call sites.
_SyncCallHandler = Callable[
    [str, LLMCallConfig, LLMCallContext, bool], OptimizationResponse
]
_AsyncCallHandler = Callable[
    [str, LLMCallConfig, LLMCallContext, bool], Awaitable[OptimizationResponse]
]
HandleAgentCall = Union[_SyncCallHandler, _AsyncCallHandler]
HandleJudgeCall = Union[_SyncCallHandler, _AsyncCallHandler]

_StatusLiteral = Literal[
Expand Down Expand Up @@ -315,7 +315,8 @@ class OptimizationOptions:
] # choices of interpolated variables to be chosen at random per turn, 1 min required
# Actual agent/completion (judge) calls - Required
handle_agent_call: HandleAgentCall
handle_judge_call: HandleJudgeCall
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
handle_judge_call: Optional[HandleJudgeCall] = None
# Criteria for pass/fail - Optional
user_input_options: Optional[List[str]] = (
None # optional list of user input messages to randomly select from
Expand Down Expand Up @@ -401,7 +402,8 @@ class GroundTruthOptimizationOptions:
model_choices: List[str]
judge_model: str
handle_agent_call: HandleAgentCall
handle_judge_call: HandleJudgeCall
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
handle_judge_call: Optional[HandleJudgeCall] = None
judges: Optional[Dict[str, OptimizationJudge]] = None
on_turn: Optional[Callable[[OptimizationContext], bool]] = None
on_sample_result: Optional[Callable[[OptimizationContext], None]] = None
Expand Down Expand Up @@ -461,7 +463,8 @@ class OptimizationFromConfigOptions:

project_key: str
handle_agent_call: HandleAgentCall
handle_judge_call: HandleJudgeCall
# Optional; falls back to handle_agent_call when omitted (both share the same signature)
handle_judge_call: Optional[HandleJudgeCall] = None
on_turn: Optional[Callable[["OptimizationContext"], bool]] = None
on_sample_result: Optional[Callable[["OptimizationContext"], None]] = None
on_passing_result: Optional[Callable[["OptimizationContext"], None]] = None
Expand Down
Loading
Loading