diff --git a/agents/openai_functions_agent.go b/agents/openai_functions_agent.go index 0d318aae0..6d44cc2b1 100644 --- a/agents/openai_functions_agent.go +++ b/agents/openai_functions_agent.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/chains" @@ -270,92 +271,104 @@ func (o *OpenAIFunctionsAgent) ParseOutput(contentResp *llms.ContentResponse) ( // Check for new-style tool calls first if len(choice.ToolCalls) > 0 { - // Handle multiple tool calls properly - actions := make([]schema.AgentAction, 0, len(choice.ToolCalls)) - - for _, toolCall := range choice.ToolCalls { - functionName := toolCall.FunctionCall.Name - toolInputStr := toolCall.FunctionCall.Arguments - toolInputMap := make(map[string]any, 0) - err := json.Unmarshal([]byte(toolInputStr), &toolInputMap) - - toolInput := toolInputStr - if err == nil { - // Successfully parsed JSON, check for __arg1 pattern - if arg1, ok := toolInputMap["__arg1"]; ok { - toolInputCheck, ok := arg1.(string) - if ok { - toolInput = toolInputCheck - } - } - } - // If JSON parsing failed, use the raw string as tool input - // This handles cases like calculator expressions + return o.parseToolCalls(choice) + } - contentMsg := "\n" - if choice.Content != "" { - contentMsg = fmt.Sprintf("responded: %s\n", choice.Content) - } + // Check for legacy function call + if choice.FuncCall != nil { + return o.parseLegacyFunctionCall(choice) + } - actions = append(actions, schema.AgentAction{ - Tool: functionName, - ToolInput: toolInput, - Log: fmt.Sprintf("Invoking: %s with %s %s", functionName, toolInputStr, contentMsg), - ToolID: toolCall.ID, - }) - } + // No function/tool call - this is a finish + return nil, &schema.AgentFinish{ + ReturnValues: map[string]any{ + "output": choice.Content, + }, + Log: choice.Content, + }, nil +} - return actions, nil, nil +func (o *OpenAIFunctionsAgent) parseToolCalls(choice *llms.ContentChoice) ([]schema.AgentAction, *schema.AgentFinish, error) { + actions := make([]schema.AgentAction, 0, len(choice.ToolCalls)) + contentMsg := "\n" + if choice.Content != "" { + contentMsg = fmt.Sprintf("responded: %s\n", choice.Content) } - // Check for legacy function call - if choice.FuncCall != nil { - functionCall := choice.FuncCall - functionName := functionCall.Name - toolInputStr := functionCall.Arguments + // Generate a shared log for all tool calls in this choice to ensure they are grouped together + // in constructScratchPad. + var logBuilder strings.Builder + logBuilder.WriteString(contentMsg) + for _, tc := range choice.ToolCalls { + logBuilder.WriteString(fmt.Sprintf("Invoking: %s with %s\n", tc.FunctionCall.Name, tc.FunctionCall.Arguments)) + } + sharedLog := logBuilder.String() + + for _, toolCall := range choice.ToolCalls { + functionName := toolCall.FunctionCall.Name + toolInputStr := toolCall.FunctionCall.Arguments toolInputMap := make(map[string]any, 0) err := json.Unmarshal([]byte(toolInputStr), &toolInputMap) - if err != nil { - // If it's not valid JSON, it might be a raw expression for the calculator - // Try to use it directly as tool input - return []schema.AgentAction{ - { - Tool: functionName, - ToolInput: toolInputStr, - Log: fmt.Sprintf("Invoking: %s with %s\n", functionName, toolInputStr), - ToolID: "", // Legacy function calls don't have tool IDs - }, - }, nil, nil - } toolInput := toolInputStr - if arg1, ok := toolInputMap["__arg1"]; ok { - toolInputCheck, ok := arg1.(string) - if ok { - toolInput = toolInputCheck + if err == nil { + // Successfully parsed JSON, check for __arg1 pattern + if arg1, ok := toolInputMap["__arg1"]; ok { + toolInputCheck, ok := arg1.(string) + if ok { + toolInput = toolInputCheck + } } } - contentMsg := "\n" - if choice.Content != "" { - contentMsg = fmt.Sprintf("responded: %s\n", choice.Content) - } + actions = append(actions, schema.AgentAction{ + Tool: functionName, + ToolInput: toolInput, + Log: sharedLog, + ToolID: toolCall.ID, + }) + } + + return actions, nil, nil +} +func (o *OpenAIFunctionsAgent) parseLegacyFunctionCall(choice *llms.ContentChoice) ([]schema.AgentAction, *schema.AgentFinish, error) { + functionCall := choice.FuncCall + functionName := functionCall.Name + toolInputStr := functionCall.Arguments + toolInputMap := make(map[string]any, 0) + err := json.Unmarshal([]byte(toolInputStr), &toolInputMap) + if err != nil { + // If it's not valid JSON, it might be a raw expression for the calculator return []schema.AgentAction{ { Tool: functionName, - ToolInput: toolInput, - Log: fmt.Sprintf("Invoking: %s with %s \n %s \n", functionName, toolInputStr, contentMsg), - ToolID: "", // Legacy function calls don't have tool IDs + ToolInput: toolInputStr, + Log: fmt.Sprintf("Invoking: %s with %s\n", functionName, toolInputStr), + ToolID: "", }, }, nil, nil } - // No function/tool call - this is a finish - return nil, &schema.AgentFinish{ - ReturnValues: map[string]any{ - "output": choice.Content, + toolInput := toolInputStr + if arg1, ok := toolInputMap["__arg1"]; ok { + toolInputCheck, ok := arg1.(string) + if ok { + toolInput = toolInputCheck + } + } + + contentMsg := "\n" + if choice.Content != "" { + contentMsg = fmt.Sprintf("responded: %s\n", choice.Content) + } + + return []schema.AgentAction{ + { + Tool: functionName, + ToolInput: toolInput, + Log: fmt.Sprintf("Invoking: %s with %s \n %s \n", functionName, toolInputStr, contentMsg), + ToolID: "", }, - Log: choice.Content, - }, nil + }, nil, nil }