Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ type ChatRequest struct {
// Metadata allows you to specify additional information that will be passed to the model.
Metadata map[string]any `json:"metadata,omitempty"`

// ExtraBody allows passing custom parameters directly to the OpenAI API.
// This is useful for beta features or new parameters not yet supported by the library.
// Fields in ExtraBody will be merged into the JSON request body.
ExtraBody map[string]any `json:"-"`

// WebSearchOptions configures web search behavior for search-enabled models
// like gpt-4o-search-preview and gpt-4o-mini-search-preview.
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
Expand All @@ -92,6 +97,7 @@ type ChatRequest struct {
// MarshalJSON ensures that only one of MaxTokens or MaxCompletionTokens is sent.
// OpenAI's API returns an error if both fields are present.
// Also omits temperature for reasoning models (GPT-5, o1, o3) that only accept default temperature.
// Additionally, merges ExtraBody fields into the JSON request.
func (r ChatRequest) MarshalJSON() ([]byte, error) {
type Alias ChatRequest
aux := struct {
Expand Down Expand Up @@ -127,7 +133,29 @@ func (r ChatRequest) MarshalJSON() ([]byte, error) {
aux.MaxCompletionTokens = nil
}

return json.Marshal(&aux)
// Marshal the base request
baseJSON, err := json.Marshal(&aux)
if err != nil {
return nil, err
}

// If no ExtraBody, return the base JSON
if len(r.ExtraBody) == 0 {
return baseJSON, nil
}

// Merge ExtraBody fields into the base JSON
var baseMap map[string]any
if err := json.Unmarshal(baseJSON, &baseMap); err != nil {
return nil, err
}

// Merge ExtraBody fields (ExtraBody takes precedence if there are conflicts)
for k, v := range r.ExtraBody {
baseMap[k] = v
}

return json.Marshal(baseMap)
}

// isReasoningModel returns true if the model is a reasoning model that has temperature constraints.
Expand Down
126 changes: 126 additions & 0 deletions llms/openai/internal/openaiclient/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,129 @@ func TestIsReasoningModel(t *testing.T) {
})
}
}

func TestChatRequest_ExtraBodyMarshalJSON(t *testing.T) {
tests := []struct {
name string
request ChatRequest
wantExtra map[string]interface{}
}{
{
name: "with extra_body",
request: ChatRequest{
Model: "gpt-4",
ExtraBody: map[string]any{
"parallel_tool_calls": false,
"custom_param": "test_value",
},
},
wantExtra: map[string]interface{}{
"parallel_tool_calls": false,
"custom_param": "test_value",
},
},
{
name: "without extra_body",
request: ChatRequest{
Model: "gpt-4",
},
wantExtra: nil,
},
{
name: "empty extra_body",
request: ChatRequest{
Model: "gpt-4",
ExtraBody: map[string]any{},
},
wantExtra: nil,
},
{
name: "extra_body with nested objects",
request: ChatRequest{
Model: "gpt-4",
ExtraBody: map[string]any{
"nested": map[string]interface{}{
"key1": "value1",
"key2": 123,
},
},
},
wantExtra: map[string]interface{}{
"nested": map[string]interface{}{
"key1": "value1",
"key2": float64(123),
},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.request)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}

var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}

if tt.wantExtra == nil {
// Check that extra_body fields are not present
for key := range tt.request.ExtraBody {
if _, exists := result[key]; exists {
t.Errorf("unexpected extra_body field %q in result", key)
}
}
} else {
// Check that all extra_body fields are present in result
for key, wantValue := range tt.wantExtra {
gotValue, exists := result[key]
if !exists {
t.Errorf("missing extra_body field %q in result", key)
continue
}

// For nested maps, need deep comparison
wantJSON, _ := json.Marshal(wantValue)
gotJSON, _ := json.Marshal(gotValue)
if string(wantJSON) != string(gotJSON) {
t.Errorf("extra_body field %q: got %v, want %v", key, gotValue, wantValue)
}
}
}

// Verify model field is still present
if result["model"] != tt.request.Model {
t.Errorf("model field: got %v, want %v", result["model"], tt.request.Model)
}
})
}
}

func TestChatRequest_ExtraBodyOverridesFields(t *testing.T) {
// Test that ExtraBody can override standard fields
request := ChatRequest{
Model: "gpt-4",
Temperature: 0.7,
ExtraBody: map[string]any{
"temperature": 0.9,
},
}

data, err := json.Marshal(request)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}

var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}

// ExtraBody should take precedence
if temp := result["temperature"].(float64); temp != 0.9 {
t.Errorf("temperature: got %v, want 0.9 (from ExtraBody)", temp)
}
}
9 changes: 9 additions & 0 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
}
}

// Extract extra_body if provided
var extraBody map[string]any
if opts.Metadata != nil {
if v, ok := opts.Metadata["openai:extra_body"].(map[string]interface{}); ok {
extraBody = v
}
}

// Extract reasoning effort for thinking models
// Note: OpenAI o1/o3 models have built-in reasoning and don't support reasoning_effort parameter
// This is kept for future models that might support it (like GPT-5)
Expand Down Expand Up @@ -293,6 +301,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
Seed: opts.Seed,
Metadata: apiMetadata,
ExtraBody: extraBody,
WebSearchOptions: webSearchOptionsFromCallOptions(opts.WebSearchOptions),
}
if opts.JSONMode {
Expand Down
23 changes: 23 additions & 0 deletions llms/openai/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,26 @@ func WithLegacyMaxTokensField() llms.CallOption {
opts.Metadata["openai:use_legacy_max_tokens"] = true
}
}

// WithExtraBody allows passing custom parameters directly to the OpenAI API.
// This is useful for beta features or new parameters not yet supported by the library.
// Fields in extraBody will be merged into the JSON request body.
//
// Usage:
//
// llm.GenerateContent(ctx, messages,
// openai.WithExtraBody(map[string]interface{}{
// "parallel_tool_calls": false,
// }),
// )
func WithExtraBody(extraBody map[string]interface{}) llms.CallOption {
return func(opts *llms.CallOptions) {
// Only set if extraBody is not empty
if len(extraBody) > 0 {
if opts.Metadata == nil {
opts.Metadata = make(map[string]interface{})
}
opts.Metadata["openai:extra_body"] = extraBody
}
}
}
45 changes: 45 additions & 0 deletions llms/openai/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,51 @@ func TestWithLegacyMaxTokensField(t *testing.T) {
}
}

func TestWithExtraBody(t *testing.T) {
opts := &llms.CallOptions{}

// Test that WithExtraBody sets the metadata
extraBody := map[string]interface{}{
"parallel_tool_calls": false,
"custom_param": "test_value",
}
WithExtraBody(extraBody)(opts)
if opts.Metadata == nil {
t.Fatal("expected Metadata to be initialized")
}
if v, ok := opts.Metadata["openai:extra_body"].(map[string]interface{}); !ok {
t.Error("expected openai:extra_body to be set")
} else {
if v["parallel_tool_calls"] != false {
t.Error("expected parallel_tool_calls to be false")
}
if v["custom_param"] != "test_value" {
t.Error("expected custom_param to be test_value")
}
}

// Test with empty extra body - should not set metadata
opts2 := &llms.CallOptions{}
emptyExtra := map[string]interface{}{}
WithExtraBody(emptyExtra)(opts2)
// Empty extra body should not create metadata
if opts2.Metadata != nil {
if _, exists := opts2.Metadata["openai:extra_body"]; exists {
t.Error("expected openai:extra_body to not be set for empty map")
}
}

// Test with nil extra body - should not set metadata
opts3 := &llms.CallOptions{}
WithExtraBody(nil)(opts3)
// Nil extra body should not create metadata
if opts3.Metadata != nil {
if _, exists := opts3.Metadata["openai:extra_body"]; exists {
t.Error("expected openai:extra_body to not be set for nil")
}
}
}

func TestWithWebSearch(t *testing.T) {
// Test with nil options (default behavior)
opts := &llms.CallOptions{}
Expand Down
Loading