diff --git a/llms/generatecontent.go b/llms/generatecontent.go index d81160463..381349223 100644 --- a/llms/generatecontent.go +++ b/llms/generatecontent.go @@ -14,6 +14,13 @@ import ( type MessageContent struct { Role ChatMessageType Parts []ContentPart + + // ReasoningContent is used with reasoning models (e.g. deepseek-reasoner) + // to preserve the reasoning content in assistant messages for round-trip + // conversations. When the API returns reasoning_content in an assistant + // message (e.g. alongside tool_calls), this field must be included when + // sending the message back as part of conversation history. + ReasoningContent string } // TextPart creates TextContent from a given string. diff --git a/llms/marshaling.go b/llms/marshaling.go index a9297a977..2699ac7ec 100644 --- a/llms/marshaling.go +++ b/llms/marshaling.go @@ -14,25 +14,29 @@ func (mc MessageContent) MarshalJSON() ([]byte, error) { if hasSingleTextPart { tp, _ := mc.Parts[0].(TextContent) return json.Marshal(struct { - Role ChatMessageType `json:"role"` - Text string `json:"text"` - }{Role: mc.Role, Text: tp.Text}) + Role ChatMessageType `json:"role"` + Text string `json:"text"` + ReasoningContent string `json:"reasoning_content,omitempty"` + }{Role: mc.Role, Text: tp.Text, ReasoningContent: mc.ReasoningContent}) } return json.Marshal(struct { - Role ChatMessageType `json:"role"` - Parts []ContentPart `json:"parts"` + Role ChatMessageType `json:"role"` + Parts []ContentPart `json:"parts"` + ReasoningContent string `json:"reasoning_content,omitempty"` }{ - Role: mc.Role, - Parts: mc.Parts, + Role: mc.Role, + Parts: mc.Parts, + ReasoningContent: mc.ReasoningContent, }) } func (mc *MessageContent) UnmarshalJSON(data []byte) error { var m struct { - Role ChatMessageType `json:"role"` - Text string `json:"text"` - Parts []struct { + Role ChatMessageType `json:"role"` + Text string `json:"text"` + ReasoningContent string `json:"reasoning_content"` + Parts []struct { Type string `json:"type"` Text string `json:"text,omitempty"` ImageURL struct { @@ -60,6 +64,7 @@ func (mc *MessageContent) UnmarshalJSON(data []byte) error { return err } mc.Role = m.Role + mc.ReasoningContent = m.ReasoningContent for _, part := range m.Parts { switch part.Type { diff --git a/llms/marshaling_test.go b/llms/marshaling_test.go index 9e3c686b4..d6f8adb9d 100644 --- a/llms/marshaling_test.go +++ b/llms/marshaling_test.go @@ -188,7 +188,7 @@ role: user } } -func TestUnmarshalJSONMessageContent(t *testing.T) { +func TestUnmarshalJSONMessageContent(t *testing.T) { //nolint:funlen // We make an exception given the number of test cases. t.Parallel() tests := []struct { name string @@ -259,6 +259,18 @@ func TestUnmarshalJSONMessageContent(t *testing.T) { }, wantErr: false, }, + { + name: "assistant message with reasoning_content", + input: `{"role":"assistant","text":"final answer","reasoning_content":"step-by-step reasoning"}`, + want: MessageContent{ + Role: "assistant", + Parts: []ContentPart{ + TextContent{Text: "final answer"}, + }, + ReasoningContent: "step-by-step reasoning", + }, + wantErr: false, + }, } for _, tt := range tests { @@ -323,6 +335,29 @@ func TestMarshalJSONMessageContent(t *testing.T) { want: `{"role":"user","parts":[{}]}`, wantErr: false, }, + { + name: "assistant message with reasoning_content", + input: MessageContent{ + Role: "assistant", + Parts: []ContentPart{ + TextContent{Text: "final answer"}, + }, + ReasoningContent: "step-by-step reasoning", + }, + want: `{"role":"assistant","text":"final answer","reasoning_content":"step-by-step reasoning"}`, + wantErr: false, + }, + { + name: "message without reasoning_content omits field", + input: MessageContent{ + Role: "user", + Parts: []ContentPart{ + TextContent{Text: "Hello"}, + }, + }, + want: `{"role":"user","text":"Hello"}`, + wantErr: false, + }, } for _, tt := range tests { @@ -485,6 +520,28 @@ role: assistant }, }, }, + { + name: "assistant message with reasoning_content and tool calls", + in: MessageContent{ + Role: "assistant", + Parts: []ContentPart{ + ToolCall{Type: "function", ID: "tc01", FunctionCall: &FunctionCall{Name: "calculator", Arguments: `{"a":15,"b":28}`}}, + }, + ReasoningContent: "I need to use the calculator to add 15 and 28", + }, + assertedJSON: `{"role":"assistant","parts":[{"type":"tool_call","tool_call":{"function":{"name":"calculator","arguments":"{\"a\":15,\"b\":28}"},"id":"tc01","type":"function"}}],"reasoning_content":"I need to use the calculator to add 15 and 28"}`, + }, + { + name: "assistant message with reasoning_content single text", + in: MessageContent{ + Role: "assistant", + Parts: []ContentPart{ + TextContent{Text: "The answer is 43"}, + }, + ReasoningContent: "I calculated 15 + 28 = 43", + }, + assertedJSON: `{"role":"assistant","text":"The answer is 43","reasoning_content":"I calculated 15 + 28 = 43"}`, + }, } // Round-trip both JSON and YAML: diff --git a/llms/openai/internal/openaiclient/chat_test.go b/llms/openai/internal/openaiclient/chat_test.go index 0a26725b2..545a747fa 100644 --- a/llms/openai/internal/openaiclient/chat_test.go +++ b/llms/openai/internal/openaiclient/chat_test.go @@ -122,3 +122,37 @@ func TestChatMessage_MarshalUnmarshal_WithReasoning(t *testing.T) { require.NoError(t, err) require.Equal(t, msg, msg2) } + +func TestChatMessage_MarshalUnmarshal_WithReasoningAndToolCalls(t *testing.T) { + t.Parallel() + msg := ChatMessage{ + Role: "assistant", + Content: "", + ReasoningContent: "I need to use the calculator to add 15 and 28", + ToolCalls: []ToolCall{ + { + ID: "call_123", + Type: ToolTypeFunction, + Function: ToolFunction{ + Name: "calculator", + Arguments: `{"a":15,"b":28}`, + }, + }, + }, + } + text, err := json.Marshal(msg) + require.NoError(t, err) + + // Verify reasoning_content is present in serialized JSON + assert.Contains(t, string(text), `"reasoning_content"`) + assert.Contains(t, string(text), `"tool_calls"`) + + // Round-trip: unmarshal back + var msg2 ChatMessage + err = json.Unmarshal(text, &msg2) + require.NoError(t, err) + require.Equal(t, msg.ReasoningContent, msg2.ReasoningContent) + require.Equal(t, len(msg.ToolCalls), len(msg2.ToolCalls)) + require.Equal(t, msg.ToolCalls[0].ID, msg2.ToolCalls[0].ID) + require.Equal(t, msg.ToolCalls[0].Function.Name, msg2.ToolCalls[0].Function.Name) +} diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 84690072a..fdad23d7c 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -151,6 +151,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten msg.Role = RoleSystem case llms.ChatMessageTypeAI: msg.Role = RoleAssistant + msg.ReasoningContent = mc.ReasoningContent case llms.ChatMessageTypeHuman: msg.Role = RoleUser // For models without system support, prepend system content to first user message