Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion components/model/ark/examples/intent_tool/intent_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func textToolCall(ctx context.Context, chatModel model.ToolCallingChatModel) {
},
{
Role: schema.User,
Content: "My name is zhangsan, and my email is zhangsan@bytedance.com. Please recommend some suitable houses for me.",
Content: "My name is zhangsan, and my email is zhangsan@bytedance.com. Please search my salary.",
},
})

Expand Down
101 changes: 101 additions & 0 deletions components/model/ark/message_extra.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
package ark

import (
"strings"

"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)

const (
keyOfRequestID = "ark-request-id"
keyOfReasoningContent = "ark-reasoning-content"
keyOfReasoningID = "ark-reasoning-id"
keyOfModelName = "ark-model-name"
videoURLFPS = "ark-model-video-url-fps"
keyOfContextID = "ark-context-id"
Expand All @@ -32,6 +35,7 @@ const (
keyOfServiceTier = "ark-service-tier"
keyOfPartial = "ark-partial"
ImageSizeKey = "seedream-image-size"
keyOfOutputItemsOrder = "ark-output-items-order"
)

type arkRequestID string
Expand All @@ -40,8 +44,46 @@ type arkServiceTier string
type arkResponseID string
type arkContextID string
type arkResponseCacheExpireAt int64
type arkOutputItemsOrder string

// outputItemType represents the type of an output item in the responses API.
type outputItemType string

const (
// outputItemTypeMessage represents a message output item.
outputItemTypeMessage outputItemType = "message"
// outputItemTypeReasoning represents a reasoning output item.
outputItemTypeReasoning outputItemType = "reasoning"
// outputItemTypeFunctionCall represents a function call output item.
outputItemTypeFunctionCall outputItemType = "function_call"
)

func init() {
compose.RegisterStreamChunkConcatFunc(func(ts []arkOutputItemsOrder) (arkOutputItemsOrder, error) {
if len(ts) == 0 {
return "", nil
}
if len(ts) == 1 {
return ts[0], nil
}
var ret []outputItemType
var lastType outputItemType
for _, t := range ts {
if len(t) == 0 {
continue
}
itemTypes := parseOutputItemsOrder(t)
for _, it := range itemTypes {
if it != lastType {
ret = append(ret, it)
lastType = it
}
}
}
return encodeOutputItemsOrder(ret), nil
})
schema.RegisterName[arkOutputItemsOrder]("_eino_ext_ark_output_items_order")

compose.RegisterStreamChunkConcatFunc(func(chunks []arkRequestID) (final arkRequestID, err error) {
if len(chunks) == 0 {
return "", nil
Expand Down Expand Up @@ -120,6 +162,14 @@ func setReasoningContent(msg *schema.Message, reasoningContent string) {
setMsgExtra(msg, keyOfReasoningContent, reasoningContent)
}

func getReasoningID(msg *schema.Message) (string, bool) {
return getMsgExtraValue[string](msg, keyOfReasoningID)
}

func setReasoningID(msg *schema.Message, id string) {
setMsgExtra(msg, keyOfReasoningID, id)
}

func GetModelName(msg *schema.Message) (string, bool) {
modelName, ok := getMsgExtraValue[arkModelName](msg, keyOfModelName)
if !ok {
Expand Down Expand Up @@ -392,3 +442,54 @@ func getPartial(msg *schema.Message) bool {
}
return v
}

// getOutputItemsOrder returns the output items order from the message.
// This records the original order of items (message, reasoning, function_call) in the responses API output,
// so that when converting back to InputItem list, the original order can be preserved.
// Only available for ResponsesAPI.
//
// The order is stored as a comma-separated string in Extra, e.g. "message,reasoning,function_call,function_call".
func getOutputItemsOrder(msg *schema.Message) ([]outputItemType, bool) {
orderStr, ok := getMsgExtraValue[arkOutputItemsOrder](msg, keyOfOutputItemsOrder)
if !ok {
// Fallback for deserialized string type.
s, ok := getMsgExtraValue[string](msg, keyOfOutputItemsOrder)
if !ok || s == "" {
return nil, false
}
orderStr = arkOutputItemsOrder(s)
}
result := parseOutputItemsOrder(orderStr)
return result, len(result) > 0
}

func setOutputItemsOrder(msg *schema.Message, order []outputItemType) {
setMsgExtra(msg, keyOfOutputItemsOrder, arkOutputItemsOrder(encodeOutputItemsOrder(order)))
}

// encodeOutputItemsOrder encodes the order entries into a comma-separated string.
// e.g. "message,reasoning,function_call,function_call"
func encodeOutputItemsOrder(order []outputItemType) arkOutputItemsOrder {
parts := make([]string, 0, len(order))
for _, entry := range order {
parts = append(parts, string(entry))
}
return arkOutputItemsOrder(strings.Join(parts, ","))
}

// parseOutputItemsOrder parses a comma-separated order string back into entries.
func parseOutputItemsOrder(s arkOutputItemsOrder) []outputItemType {
if s == "" {
return nil
}
parts := strings.Split(string(s), ",")
entries := make([]outputItemType, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
entries = append(entries, outputItemType(part))
}
return entries
}
92 changes: 80 additions & 12 deletions components/model/ark/responses_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,19 +624,62 @@ func (cm *ResponsesAPIChatModel) populateInput(in []*schema.Message, responseReq
if err != nil {
return err
}
if len(inputMessage.GetContent()) > 0 {
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_InputMessage{InputMessage: inputMessage}})
}

for _, toolCall := range msg.ToolCalls {
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_FunctionToolCall{
FunctionToolCall: &responses.ItemFunctionToolCall{
Type: responses.ItemType_function_call,
CallId: toolCall.ID,
Arguments: toolCall.Function.Arguments,
Name: toolCall.Function.Name,
},
}})
itemsOrder, hasOrder := getOutputItemsOrder(msg)
if hasOrder {
// Reconstruct items in the original order recorded during output conversion.
toolCallIdx := 0
for _, entry := range itemsOrder {
switch entry {
case outputItemTypeMessage:
if len(inputMessage.GetContent()) > 0 {
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_InputMessage{InputMessage: inputMessage}})
}
case outputItemTypeFunctionCall:
if toolCallIdx < len(msg.ToolCalls) {
toolCall := msg.ToolCalls[toolCallIdx]
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_FunctionToolCall{
FunctionToolCall: &responses.ItemFunctionToolCall{
Type: responses.ItemType_function_call,
CallId: toolCall.ID,
Arguments: toolCall.Function.Arguments,
Name: toolCall.Function.Name,
},
}})
toolCallIdx++
}
case outputItemTypeReasoning:
reasoning := &responses.ItemReasoning{
Type: responses.ItemType_reasoning,
Summary: []*responses.ReasoningSummaryPart{{
Type: responses.ContentItemType_input_text,
Text: msg.ReasoningContent,
}},
}
if id, ok := getReasoningID(msg); ok && id != "" {
reasoning.Id = &id
}
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_Reasoning{
Reasoning: reasoning,
}})
}
}
} else {
// Fallback: original behavior when no order metadata is available.
if len(inputMessage.GetContent()) > 0 {
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_InputMessage{InputMessage: inputMessage}})
}

for _, toolCall := range msg.ToolCalls {
itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_FunctionToolCall{
FunctionToolCall: &responses.ItemFunctionToolCall{
Type: responses.ItemType_function_call,
CallId: toolCall.ID,
Arguments: toolCall.Function.Arguments,
Name: toolCall.Function.Name,
},
}})
}
}
case schema.System:
inputMessage, err := cm.toArkSystemRoleItemInputMessage(msg)
Expand Down Expand Up @@ -980,12 +1023,15 @@ func (cm *ResponsesAPIChatModel) toOutputMessage(resp *responses.ResponseObject,
return nil, fmt.Errorf("received empty output from ARK")
}

var itemsOrder []outputItemType

for _, item := range resp.Output {
switch asItem := item.GetUnion().(type) {
case *responses.OutputItem_OutputMessage:
if asItem.OutputMessage == nil {
continue
}
itemsOrder = append(itemsOrder, outputItemTypeMessage)
isMultiContent := len(asItem.OutputMessage.Content) > 1
for _, content := range asItem.OutputMessage.Content {
if content.GetText() == nil {
Expand All @@ -1005,6 +1051,10 @@ func (cm *ResponsesAPIChatModel) toOutputMessage(resp *responses.ResponseObject,
if asItem.Reasoning == nil {
continue
}
itemsOrder = append(itemsOrder, outputItemTypeReasoning)
if asItem.Reasoning.Id != nil {
setReasoningID(msg, *asItem.Reasoning.Id)
}
for _, s := range asItem.Reasoning.GetSummary() {
if s.Text == "" {
continue
Expand All @@ -1020,6 +1070,7 @@ func (cm *ResponsesAPIChatModel) toOutputMessage(resp *responses.ResponseObject,
if asItem.FunctionToolCall == nil {
continue
}
itemsOrder = append(itemsOrder, outputItemTypeFunctionCall)
msg.ToolCalls = append(msg.ToolCalls, schema.ToolCall{
ID: asItem.FunctionToolCall.CallId,
Type: asItem.FunctionToolCall.Type.String(),
Expand All @@ -1031,6 +1082,10 @@ func (cm *ResponsesAPIChatModel) toOutputMessage(resp *responses.ResponseObject,
}
}

if len(itemsOrder) > 0 {
setOutputItemsOrder(msg, itemsOrder)
}

return msg, nil
}

Expand Down Expand Up @@ -1086,6 +1141,7 @@ func (cm *ResponsesAPIChatModel) toCallbackConfig(req *responses.ResponsesReques
func (cm *ResponsesAPIChatModel) receivedStreamResponse(streamReader *utils.ResponsesStreamReader,
config *model.Config, cacheConfig *cacheConfig, sw *schema.StreamWriter[*model.CallbackOutput]) {
var itemFunctionToolCall *responses.ItemFunctionToolCall
var reasoningID *string

for {
event, err := streamReader.Recv()
Expand Down Expand Up @@ -1157,6 +1213,11 @@ func (cm *ResponsesAPIChatModel) receivedStreamResponse(streamReader *utils.Resp
if outputItemFuncCall, ok := ev.Item.GetItem().GetUnion().(*responses.OutputItem_FunctionToolCall); ok {
itemFunctionToolCall = outputItemFuncCall.FunctionToolCall
}
if outputItemReasoning, ok := ev.Item.GetItem().GetUnion().(*responses.OutputItem_Reasoning); ok {
if outputItemReasoning.Reasoning != nil {
reasoningID = outputItemReasoning.Reasoning.Id
}
}

case *responses.Event_FunctionCallArguments:
if ev.FunctionCallArguments == nil {
Expand All @@ -1181,6 +1242,7 @@ func (cm *ResponsesAPIChatModel) receivedStreamResponse(streamReader *utils.Resp
},
},
}
setOutputItemsOrder(msg, []outputItemType{outputItemTypeFunctionCall})
cm.sendCallbackOutput(sw, config, "", msg)
}

Expand All @@ -1194,6 +1256,11 @@ func (cm *ResponsesAPIChatModel) receivedStreamResponse(streamReader *utils.Resp
ReasoningContent: delta,
}
setReasoningContent(msg, delta)
setOutputItemsOrder(msg, []outputItemType{outputItemTypeReasoning})
if reasoningID != nil {
setReasoningID(msg, *reasoningID)
reasoningID = nil // only set once
}
cm.sendCallbackOutput(sw, config, "", msg)

case *responses.Event_Text:
Expand All @@ -1204,6 +1271,7 @@ func (cm *ResponsesAPIChatModel) receivedStreamResponse(streamReader *utils.Resp
Role: schema.Assistant,
Content: *ev.Text.Delta,
}
setOutputItemsOrder(msg, []outputItemType{outputItemTypeMessage})
cm.sendCallbackOutput(sw, config, "", msg)

}
Expand Down
Loading
Loading