Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 80 additions & 67 deletions agents/openai_functions_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
Expand Down Expand Up @@ -270,92 +271,104 @@ func (o *OpenAIFunctionsAgent) ParseOutput(contentResp *llms.ContentResponse) (

// Check for new-style tool calls first
if len(choice.ToolCalls) > 0 {
// Handle multiple tool calls properly
actions := make([]schema.AgentAction, 0, len(choice.ToolCalls))

for _, toolCall := range choice.ToolCalls {
functionName := toolCall.FunctionCall.Name
toolInputStr := toolCall.FunctionCall.Arguments
toolInputMap := make(map[string]any, 0)
err := json.Unmarshal([]byte(toolInputStr), &toolInputMap)

toolInput := toolInputStr
if err == nil {
// Successfully parsed JSON, check for __arg1 pattern
if arg1, ok := toolInputMap["__arg1"]; ok {
toolInputCheck, ok := arg1.(string)
if ok {
toolInput = toolInputCheck
}
}
}
// If JSON parsing failed, use the raw string as tool input
// This handles cases like calculator expressions
return o.parseToolCalls(choice)
}

contentMsg := "\n"
if choice.Content != "" {
contentMsg = fmt.Sprintf("responded: %s\n", choice.Content)
}
// Check for legacy function call
if choice.FuncCall != nil {
return o.parseLegacyFunctionCall(choice)
}

actions = append(actions, schema.AgentAction{
Tool: functionName,
ToolInput: toolInput,
Log: fmt.Sprintf("Invoking: %s with %s %s", functionName, toolInputStr, contentMsg),
ToolID: toolCall.ID,
})
}
// No function/tool call - this is a finish
return nil, &schema.AgentFinish{
ReturnValues: map[string]any{
"output": choice.Content,
},
Log: choice.Content,
}, nil
}

return actions, nil, nil
func (o *OpenAIFunctionsAgent) parseToolCalls(choice *llms.ContentChoice) ([]schema.AgentAction, *schema.AgentFinish, error) {
actions := make([]schema.AgentAction, 0, len(choice.ToolCalls))
contentMsg := "\n"
if choice.Content != "" {
contentMsg = fmt.Sprintf("responded: %s\n", choice.Content)
}

// Check for legacy function call
if choice.FuncCall != nil {
functionCall := choice.FuncCall
functionName := functionCall.Name
toolInputStr := functionCall.Arguments
// Generate a shared log for all tool calls in this choice to ensure they are grouped together
// in constructScratchPad.
var logBuilder strings.Builder
logBuilder.WriteString(contentMsg)
for _, tc := range choice.ToolCalls {
logBuilder.WriteString(fmt.Sprintf("Invoking: %s with %s\n", tc.FunctionCall.Name, tc.FunctionCall.Arguments))
}
sharedLog := logBuilder.String()

for _, toolCall := range choice.ToolCalls {
functionName := toolCall.FunctionCall.Name
toolInputStr := toolCall.FunctionCall.Arguments
toolInputMap := make(map[string]any, 0)
err := json.Unmarshal([]byte(toolInputStr), &toolInputMap)
if err != nil {
// If it's not valid JSON, it might be a raw expression for the calculator
// Try to use it directly as tool input
return []schema.AgentAction{
{
Tool: functionName,
ToolInput: toolInputStr,
Log: fmt.Sprintf("Invoking: %s with %s\n", functionName, toolInputStr),
ToolID: "", // Legacy function calls don't have tool IDs
},
}, nil, nil
}

toolInput := toolInputStr
if arg1, ok := toolInputMap["__arg1"]; ok {
toolInputCheck, ok := arg1.(string)
if ok {
toolInput = toolInputCheck
if err == nil {
// Successfully parsed JSON, check for __arg1 pattern
if arg1, ok := toolInputMap["__arg1"]; ok {
toolInputCheck, ok := arg1.(string)
if ok {
toolInput = toolInputCheck
}
}
}

contentMsg := "\n"
if choice.Content != "" {
contentMsg = fmt.Sprintf("responded: %s\n", choice.Content)
}
actions = append(actions, schema.AgentAction{
Tool: functionName,
ToolInput: toolInput,
Log: sharedLog,
ToolID: toolCall.ID,
})
}

return actions, nil, nil
}

func (o *OpenAIFunctionsAgent) parseLegacyFunctionCall(choice *llms.ContentChoice) ([]schema.AgentAction, *schema.AgentFinish, error) {
functionCall := choice.FuncCall
functionName := functionCall.Name
toolInputStr := functionCall.Arguments
toolInputMap := make(map[string]any, 0)
err := json.Unmarshal([]byte(toolInputStr), &toolInputMap)
if err != nil {
// If it's not valid JSON, it might be a raw expression for the calculator
return []schema.AgentAction{
{
Tool: functionName,
ToolInput: toolInput,
Log: fmt.Sprintf("Invoking: %s with %s \n %s \n", functionName, toolInputStr, contentMsg),
ToolID: "", // Legacy function calls don't have tool IDs
ToolInput: toolInputStr,
Log: fmt.Sprintf("Invoking: %s with %s\n", functionName, toolInputStr),
ToolID: "",
},
}, nil, nil
}

// No function/tool call - this is a finish
return nil, &schema.AgentFinish{
ReturnValues: map[string]any{
"output": choice.Content,
toolInput := toolInputStr
if arg1, ok := toolInputMap["__arg1"]; ok {
toolInputCheck, ok := arg1.(string)
if ok {
toolInput = toolInputCheck
}
}

contentMsg := "\n"
if choice.Content != "" {
contentMsg = fmt.Sprintf("responded: %s\n", choice.Content)
}

return []schema.AgentAction{
{
Tool: functionName,
ToolInput: toolInput,
Log: fmt.Sprintf("Invoking: %s with %s \n %s \n", functionName, toolInputStr, contentMsg),
ToolID: "",
},
Log: choice.Content,
}, nil
}, nil, nil
}
Loading