diff --git a/chains/llm_test.go b/chains/llm_test.go index da0922b8b..fd09940e2 100644 --- a/chains/llm_test.go +++ b/chains/llm_test.go @@ -12,6 +12,7 @@ import ( "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/httputil" "github.com/tmc/langchaingo/internal/httprr" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/googleai" "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/prompts" @@ -65,6 +66,31 @@ func TestLLMChain(t *testing.T) { require.True(t, strings.Contains(result, "Paris")) } +type errorLanguageModel struct { + err error +} + +func (m *errorLanguageModel) Call(_ context.Context, _ string, _ ...llms.CallOption) (string, error) { + return "", m.err +} + +func (m *errorLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageContent, _ ...llms.CallOption) (*llms.ContentResponse, error) { + return nil, m.err +} + +func TestLLMChainPropagatesContentFilterError(t *testing.T) { + t.Parallel() + + chain := NewLLMChain( + &errorLanguageModel{err: llms.NewError(llms.ErrCodeContentFilter, "bedrock", "blocked")}, + prompts.NewPromptTemplate("{{.text}}", []string{"text"}), + ) + + _, err := chain.Call(context.Background(), map[string]any{"text": "unsafe prompt"}) + require.Error(t, err) + require.True(t, llms.IsContentFilterError(err)) +} + func TestLLMChainWithChatPromptTemplate(t *testing.T) { ctx := context.Background() t.Parallel() diff --git a/chains/options.go b/chains/options.go index c6df23b09..5865f4dfa 100644 --- a/chains/options.go +++ b/chains/options.go @@ -63,6 +63,10 @@ type chainCallOption struct { RepetitionPenalty float64 repetitionPenaltySet bool + // Safety configures provider-defined safety controls in an LLM call. + SafetyConfig map[string]any + safetyConfigSet bool + // CallbackHandler is the callback handler for Chain CallbackHandler callbacks.Handler } @@ -146,6 +150,16 @@ func WithRepetitionPenalty(repetitionPenalty float64) ChainCallOption { } } +// WithSafetyConfig configures provider-defined safety controls for the LLM call. +func WithSafetyConfig(config map[string]any) ChainCallOption { + return func(o *chainCallOption) { + if config != nil { + o.SafetyConfig = config + o.safetyConfigSet = true + } + } +} + // WithStopWords is an option for setting the stop words for LLM.Call. func WithStopWords(stopWords []string) ChainCallOption { return func(o *chainCallOption) { @@ -208,6 +222,9 @@ func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: if opts.repetitionPenaltySet { chainCallOption = append(chainCallOption, llms.WithRepetitionPenalty(opts.RepetitionPenalty)) } + if opts.safetyConfigSet { + chainCallOption = append(chainCallOption, llms.WithSafetyConfig(opts.SafetyConfig)) + } chainCallOption = append(chainCallOption, llms.WithStreamingFunc(opts.StreamingFunc)) return chainCallOption diff --git a/llms/bedrock/internal/bedrockclient/guardrail.go b/llms/bedrock/internal/bedrockclient/guardrail.go new file mode 100644 index 000000000..45c3e3f57 --- /dev/null +++ b/llms/bedrock/internal/bedrockclient/guardrail.go @@ -0,0 +1,147 @@ +package bedrockclient + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/tmc/langchaingo/llms" +) + +const bedrockGuardrailIntervened = "INTERVENED" + +type bedrockGuardrailParams struct { + identifier *string + version *string + tagSuffix *string + streamProcessingMode *string +} + +func getRequiredGuardrailParam(safety map[string]any, key string) (*string, error) { + value, ok := safety[key] + if !ok { + return nil, fmt.Errorf("bedrock guardrail %s is required when safety config is provided", key) + } + + str, ok := value.(string) + if !ok || strings.TrimSpace(str) == "" { + return nil, fmt.Errorf("bedrock guardrail %s must be a non-empty string", key) + } + + return aws.String(str), nil +} + +func getOptionalGuardrailParam(safety map[string]any, key string) (*string, error) { + value, ok := safety[key] + if !ok { + return nil, nil + } + + str, ok := value.(string) + if !ok || strings.TrimSpace(str) == "" { + return nil, fmt.Errorf("bedrock guardrail %s must be a non-empty string if provided", key) + } + + return aws.String(str), nil +} + +func getGuardrailParams(options llms.CallOptions) (bedrockGuardrailParams, error) { + if options.SafetyConfig == nil { + return bedrockGuardrailParams{}, nil + } + + identifier, err := getRequiredGuardrailParam(options.SafetyConfig, "identifier") + if err != nil { + return bedrockGuardrailParams{}, err + } + version, err := getRequiredGuardrailParam(options.SafetyConfig, "version") + if err != nil { + return bedrockGuardrailParams{}, err + } + + tagSuffix, err := getOptionalGuardrailParam(options.SafetyConfig, "tagSuffix") + if err != nil { + return bedrockGuardrailParams{}, err + } + streamProcessingMode, err := getOptionalGuardrailParam(options.SafetyConfig, "streamProcessingMode") + if err != nil { + return bedrockGuardrailParams{}, err + } + + return bedrockGuardrailParams{ + identifier: identifier, + version: version, + tagSuffix: tagSuffix, + streamProcessingMode: streamProcessingMode, + }, nil +} + +func handleGuardrailParams(input guardrailConfigurableAdapter, options llms.CallOptions) error { + guardrailParams, err := getGuardrailParams(options) + if err != nil { + return fmt.Errorf("failed to get guardrail parameters: %w", err) + } + + if guardrailParams.identifier != nil && guardrailParams.version != nil { + input.setGuardrailIdentifier(guardrailParams.identifier) + input.setGuardrailVersion(guardrailParams.version) + } + + if guardrailParams.tagSuffix != nil || guardrailParams.streamProcessingMode != nil { + var m map[string]any + if err := json.Unmarshal(input.body(), &m); err != nil { + return fmt.Errorf("failed to unmarshal input body when adding guardrail parameters: %w", err) + } + + var config = make(map[string]string) + if guardrailParams.tagSuffix != nil { + config["tagSuffix"] = *guardrailParams.tagSuffix + } + if guardrailParams.streamProcessingMode != nil { + config["streamProcessingMode"] = *guardrailParams.streamProcessingMode + } + m["amazon-bedrock-guardrailConfig"] = config + + newBody, err := json.Marshal(m) + if err != nil { + return fmt.Errorf("failed to marshal input body when adding guardrail parameters: %w", err) + } + input.setBody(newBody) + } + return nil +} + +func checkGuardrailError(response any) error { + var action string + var trace json.RawMessage + + switch r := response.(type) { + case []byte: + var resp guardrailResponse + if err := json.Unmarshal(r, &resp); err != nil { + // If we can't parse the response, we assume it's not a guardrail error and return nil + return nil + } + action = resp.AmazonBedrockGuardrailAction + trace = resp.AmazonBedrockTrace + case streamingCompletionResponseChunk: + action = r.AmazonBedrockGuardrailAction + trace = r.AmazonBedrockTrace + case *streamingCompletionResponseChunk: + action = r.AmazonBedrockGuardrailAction + trace = r.AmazonBedrockTrace + default: + return nil + } + + if action == bedrockGuardrailIntervened { + err := llms.NewError(llms.ErrCodeContentFilter, "bedrock", "content blocked by Bedrock guardrail") + err.WithDetail("guardrail_action", action) + if len(trace) > 0 { + err.WithDetail("trace", string(trace)) + } + return err + } + return nil +} diff --git a/llms/bedrock/internal/bedrockclient/invoke.go b/llms/bedrock/internal/bedrockclient/invoke.go new file mode 100644 index 000000000..ba48eb92d --- /dev/null +++ b/llms/bedrock/internal/bedrockclient/invoke.go @@ -0,0 +1,154 @@ +package bedrockclient + +import ( + "context" + "encoding/json" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/tmc/langchaingo/llms" +) + +const ( + bedrockAcceptAll = "*/*" + bedrockJSONContentType = "application/json" +) + +type guardrailResponse struct { + AmazonBedrockGuardrailAction string `json:"amazon-bedrock-guardrailAction"` + AmazonBedrockTrace json.RawMessage `json:"amazon-bedrock-trace"` +} + +type streamingCompletionResponseChunk struct { + Type string `json:"type"` + Index int `json:"index"` + Delta struct { + Type string `json:"type"` + Text string `json:"text"` + StopReason string `json:"stop_reason"` + StopSequence any `json:"stop_sequence"` + } `json:"delta"` + AmazonBedrockInvocationMetrics struct { + InputTokenCount int `json:"inputTokenCount"` + OutputTokenCount int `json:"outputTokenCount"` + InvocationLatency int `json:"invocationLatency"` + FirstByteLatency int `json:"firstByteLatency"` + } `json:"amazon-bedrock-invocationMetrics"` + Usage struct { + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + Message struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []any `json:"content"` + Model string `json:"model"` + StopReason any `json:"stop_reason"` + StopSequence any `json:"stop_sequence"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } `json:"message"` + guardrailResponse +} + +type invokeModelAPI interface { + InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) +} + +type invokeModelWithResponseStreamAPI interface { + InvokeModelWithResponseStream(ctx context.Context, params *bedrockruntime.InvokeModelWithResponseStreamInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelWithResponseStreamOutput, error) +} + +type guardrailConfigurableAdapter interface { + body() []byte + setGuardrailIdentifier(*string) + setGuardrailVersion(*string) + setBody([]byte) +} + +type guardrailConfigurableInput struct { + *bedrockruntime.InvokeModelInput +} + +var _ guardrailConfigurableAdapter = (*guardrailConfigurableInput)(nil) + +func (i *guardrailConfigurableInput) setGuardrailIdentifier(identifier *string) { + i.GuardrailIdentifier = identifier +} +func (i *guardrailConfigurableInput) setGuardrailVersion(version *string) { + i.GuardrailVersion = version +} +func (i *guardrailConfigurableInput) body() []byte { return i.Body } +func (i *guardrailConfigurableInput) setBody(body []byte) { i.Body = body } + +type guardrailConfigurableStreamInput struct { + *bedrockruntime.InvokeModelWithResponseStreamInput +} + +var _ guardrailConfigurableAdapter = (*guardrailConfigurableStreamInput)(nil) + +func (i *guardrailConfigurableStreamInput) setGuardrailIdentifier(identifier *string) { + i.GuardrailIdentifier = identifier +} +func (i *guardrailConfigurableStreamInput) setGuardrailVersion(version *string) { + i.GuardrailVersion = version +} +func (i *guardrailConfigurableStreamInput) body() []byte { return i.Body } +func (i *guardrailConfigurableStreamInput) setBody(body []byte) { i.Body = body } + +func newInvokeModelInput(modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelInput, error) { + input := &guardrailConfigurableInput{ + &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(modelID), + Accept: aws.String(bedrockAcceptAll), + ContentType: aws.String(bedrockJSONContentType), + Body: body, + }, + } + if err := handleGuardrailParams(input, options); err != nil { + return nil, err + } + return input.InvokeModelInput, nil +} + +func newInvokeModelWithResponseStreamInput(modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelWithResponseStreamInput, error) { + input := &guardrailConfigurableStreamInput{ + &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(modelID), + Accept: aws.String(bedrockAcceptAll), + ContentType: aws.String(bedrockJSONContentType), + Body: body, + }, + } + if err := handleGuardrailParams(input, options); err != nil { + return nil, err + } + return input.InvokeModelWithResponseStreamInput, nil +} + +func invokeModel(ctx context.Context, client invokeModelAPI, modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelOutput, error) { + input, err := newInvokeModelInput(modelID, body, options) + if err != nil { + return nil, err + } + + output, err := client.InvokeModel(ctx, input) + if err != nil { + return nil, err + } + if err := checkGuardrailError(output.Body); err != nil { + return nil, err + } + + return output, nil +} + +func invokeModelWithResponseStream(ctx context.Context, client invokeModelWithResponseStreamAPI, modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelWithResponseStreamOutput, error) { + input, err := newInvokeModelWithResponseStreamInput(modelID, body, options) + if err != nil { + return nil, err + } + return client.InvokeModelWithResponseStream(ctx, input) +} diff --git a/llms/bedrock/internal/bedrockclient/invoke_test.go b/llms/bedrock/internal/bedrockclient/invoke_test.go new file mode 100644 index 000000000..3f6471dc4 --- /dev/null +++ b/llms/bedrock/internal/bedrockclient/invoke_test.go @@ -0,0 +1,123 @@ +package bedrockclient + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/llms" +) + +func TestInvokeModelInputs(t *testing.T) { + t.Run("without safety config", func(t *testing.T) { + input, err := newInvokeModelInput("anthropic.claude-v2", []byte(`{"prompt":"hello"}`), llms.CallOptions{}) + require.NoError(t, err) + require.NotNil(t, input) + assert.Nil(t, input.GuardrailIdentifier) + assert.Nil(t, input.GuardrailVersion) + + streamInput, err := newInvokeModelWithResponseStreamInput("anthropic.claude-v2", []byte(`{"prompt":"hello"}`), llms.CallOptions{}) + require.NoError(t, err) + require.NotNil(t, streamInput) + assert.Nil(t, streamInput.GuardrailIdentifier) + assert.Nil(t, streamInput.GuardrailVersion) + }) + + t.Run("with safety config", func(t *testing.T) { + var options llms.CallOptions + llms.WithSafetyConfig(map[string]any{ + "identifier": "gr-123", + "version": "1", + })(&options) + + input, err := newInvokeModelInput("anthropic.claude-v2", []byte(`{"prompt":"hello"}`), options) + require.NoError(t, err) + require.NotNil(t, input.GuardrailIdentifier) + require.NotNil(t, input.GuardrailVersion) + assert.Equal(t, "gr-123", *input.GuardrailIdentifier) + assert.Equal(t, "1", *input.GuardrailVersion) + + streamInput, err := newInvokeModelWithResponseStreamInput("anthropic.claude-v2", []byte(`{"prompt":"hello"}`), options) + require.NoError(t, err) + require.NotNil(t, streamInput.GuardrailIdentifier) + require.NotNil(t, streamInput.GuardrailVersion) + assert.Equal(t, "gr-123", *streamInput.GuardrailIdentifier) + assert.Equal(t, "1", *streamInput.GuardrailVersion) + }) + + t.Run("requires safety version", func(t *testing.T) { + _, err := newInvokeModelInput("anthropic.claude-v2", nil, llms.CallOptions{ + SafetyConfig: map[string]any{"identifier": "gr-123"}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "bedrock safety version") + }) + + t.Run("requires safety identifier", func(t *testing.T) { + _, err := newInvokeModelInput("anthropic.claude-v2", nil, llms.CallOptions{ + SafetyConfig: map[string]any{"version": "1"}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "bedrock safety identifier") + }) + + t.Run("rejects invalid safety types", func(t *testing.T) { + _, err := newInvokeModelInput("anthropic.claude-v2", nil, llms.CallOptions{ + SafetyConfig: map[string]any{ + "identifier": 123, + "version": "1", + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "bedrock safety identifier") + }) + + t.Run("returns content filter error when guardrail intervenes", func(t *testing.T) { + resp, err := invokeModel(context.Background(), &mockBedrockClient{ + invokeFunc: func(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { + return &bedrockruntime.InvokeModelOutput{ + Body: []byte(`{"amazon-bedrock-guardrailAction":"INTERVENED","amazon-bedrock-trace":{"rule":"blocked"}}`), + }, nil + }, + }, "anthropic.claude-v2", []byte(`{"prompt":"hello"}`), llms.CallOptions{}) + + require.Nil(t, resp) + require.Error(t, err) + assert.True(t, llms.IsContentFilterError(err)) + + var llmErr *llms.Error + require.ErrorAs(t, err, &llmErr) + assert.Equal(t, "bedrock", llmErr.Provider) + assert.Equal(t, "INTERVENED", llmErr.Details["guardrail_action"]) + assert.Equal(t, map[string]any{"rule": "blocked"}, llmErr.Details["trace"]) + }) + + t.Run("returns content filter error for provider filtered response", func(t *testing.T) { + resp, err := invokeModel(context.Background(), &mockBedrockClient{ + invokeFunc: func(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { + return &bedrockruntime.InvokeModelOutput{ + Body: []byte(`{"results":[{"completionReason":"CONTENT_FILTERED"}]}`), + }, nil + }, + }, "amazon.titan-text-express-v1", []byte(`{"prompt":"hello"}`), llms.CallOptions{}) + + require.Nil(t, resp) + require.Error(t, err) + assert.True(t, llms.IsContentFilterError(err)) + }) + + t.Run("returns response when guardrail does not intervene", func(t *testing.T) { + resp, err := invokeModel(context.Background(), &mockBedrockClient{ + invokeFunc: func(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { + return &bedrockruntime.InvokeModelOutput{ + Body: []byte(`{"amazon-bedrock-guardrailAction":"NONE","outputText":"ok"}`), + }, nil + }, + }, "anthropic.claude-v2", []byte(`{"prompt":"hello"}`), llms.CallOptions{}) + + require.NoError(t, err) + require.NotNil(t, resp) + }) +} diff --git a/llms/bedrock/internal/bedrockclient/provider_ai21.go b/llms/bedrock/internal/bedrockclient/provider_ai21.go index 5ff482b15..a9aef212d 100644 --- a/llms/bedrock/internal/bedrockclient/provider_ai21.go +++ b/llms/bedrock/internal/bedrockclient/provider_ai21.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/tmc/langchaingo/llms" ) @@ -100,14 +99,7 @@ func createAi21Completion(ctx context.Context, client *bedrockruntime.Client, mo return nil, err } - modelInput := bedrockruntime.InvokeModelInput{ - ModelId: aws.String(modelID), - Body: body, - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - } - - resp, err := client.InvokeModel(ctx, &modelInput) + resp, err := invokeModel(ctx, client, modelID, body, options) if err != nil { return nil, err } diff --git a/llms/bedrock/internal/bedrockclient/provider_amazon.go b/llms/bedrock/internal/bedrockclient/provider_amazon.go index bae6eb825..36200c16f 100644 --- a/llms/bedrock/internal/bedrockclient/provider_amazon.go +++ b/llms/bedrock/internal/bedrockclient/provider_amazon.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/tmc/langchaingo/llms" ) @@ -79,13 +78,7 @@ func createAmazonCompletion(ctx context.Context, return nil, err } - modelInput := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(modelID), - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - Body: body, - } - resp, err := client.InvokeModel(ctx, modelInput) + resp, err := invokeModel(ctx, client, modelID, body, options) if err != nil { return nil, err } diff --git a/llms/bedrock/internal/bedrockclient/provider_anthropic.go b/llms/bedrock/internal/bedrockclient/provider_anthropic.go index c6e389808..02be3b449 100644 --- a/llms/bedrock/internal/bedrockclient/provider_anthropic.go +++ b/llms/bedrock/internal/bedrockclient/provider_anthropic.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" "github.com/tmc/langchaingo/llms" @@ -81,6 +80,11 @@ type anthropicTextGenerationInput struct { Tools []BedrockTool `json:"tools,omitempty"` // Tool choice configuration. Optional ToolChoice *BedrockToolChoice `json:"tool_choice,omitempty"` + // Guardrail configuration. Optional + GuardrailConfig *struct { + TagSuffix string `json:"tagSuffix,omitempty"` + StreamProcessingMode string `json:"streamProcessingMode,omitempty"` + } `json:"amazon-bedrock-guardrailConfig,omitempty"` } // anthropicTextGenerationOutput is the generated output. @@ -182,28 +186,38 @@ func createAnthropicCompletion(ctx context.Context, } } + // Add guardrail tag suffix and stream processing mode if provided in safety config + if options.SafetyConfig != nil { + input.GuardrailConfig = &struct { + TagSuffix string `json:"tagSuffix,omitempty"` + StreamProcessingMode string `json:"streamProcessingMode,omitempty"` + }{} + if tagSuffix, ok := options.SafetyConfig["tag_suffix"]; ok { + if str, ok := tagSuffix.(string); ok { + input.GuardrailConfig.TagSuffix = str + } else { + return nil, errors.New("tag_suffix in safety config must be a string") + } + } + if streamProcessingMode, ok := options.SafetyConfig["stream_processing_mode"]; ok { + if str, ok := streamProcessingMode.(string); ok { + input.GuardrailConfig.StreamProcessingMode = str + } else { + return nil, errors.New("stream_processing_mode in safety config must be a string") + } + } + } + body, err := json.Marshal(input) if err != nil { return nil, err } if options.StreamingFunc != nil { - modelInput := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(modelID), - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - Body: body, - } - return parseStreamingCompletionResponse(ctx, client, modelInput, options) + return parseStreamingCompletionResponse(ctx, client, modelID, body, options) } - modelInput := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(modelID), - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - Body: body, - } - resp, err := client.InvokeModel(ctx, modelInput) + resp, err := invokeModel(ctx, client, modelID, body, options) if err != nil { return nil, err } @@ -263,41 +277,8 @@ func createAnthropicCompletion(ctx context.Context, }, nil } -type streamingCompletionResponseChunk struct { - Type string `json:"type"` - Index int `json:"index"` - Delta struct { - Type string `json:"type"` - Text string `json:"text"` - StopReason string `json:"stop_reason"` - StopSequence any `json:"stop_sequence"` - } `json:"delta"` - AmazonBedrockInvocationMetrics struct { - InputTokenCount int `json:"inputTokenCount"` - OutputTokenCount int `json:"outputTokenCount"` - InvocationLatency int `json:"invocationLatency"` - FirstByteLatency int `json:"firstByteLatency"` - } `json:"amazon-bedrock-invocationMetrics"` - Usage struct { - OutputTokens int `json:"output_tokens"` - } `json:"usage"` - Message struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []any `json:"content"` - Model string `json:"model"` - StopReason any `json:"stop_reason"` - StopSequence any `json:"stop_sequence"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"usage"` - } `json:"message"` -} - -func parseStreamingCompletionResponse(ctx context.Context, client *bedrockruntime.Client, modelInput *bedrockruntime.InvokeModelWithResponseStreamInput, options llms.CallOptions) (*llms.ContentResponse, error) { - output, err := client.InvokeModelWithResponseStream(ctx, modelInput) +func parseStreamingCompletionResponse(ctx context.Context, client invokeModelWithResponseStreamAPI, modelID string, body []byte, options llms.CallOptions) (*llms.ContentResponse, error) { + output, err := invokeModelWithResponseStream(ctx, client, modelID, body, options) if err != nil { return nil, err } @@ -315,8 +296,10 @@ func parseStreamingCompletionResponse(ctx context.Context, client *bedrockruntim if v, ok := e.(*types.ResponseStreamMemberChunk); ok { var resp streamingCompletionResponseChunk - err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp) - if err != nil { + if err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp); err != nil { + return nil, err + } + if err := checkGuardrailError(&resp); err != nil { return nil, err } diff --git a/llms/bedrock/internal/bedrockclient/provider_cohere.go b/llms/bedrock/internal/bedrockclient/provider_cohere.go index 1ededb2f1..511b3a996 100644 --- a/llms/bedrock/internal/bedrockclient/provider_cohere.go +++ b/llms/bedrock/internal/bedrockclient/provider_cohere.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/tmc/langchaingo/llms" ) @@ -84,13 +83,7 @@ func createCohereCompletion(ctx context.Context, return nil, err } - modelInput := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(modelID), - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - Body: body, - } - resp, err := client.InvokeModel(ctx, modelInput) + resp, err := invokeModel(ctx, client, modelID, body, options) if err != nil { return nil, err } diff --git a/llms/bedrock/internal/bedrockclient/provider_meta.go b/llms/bedrock/internal/bedrockclient/provider_meta.go index 737918712..fe10ab8da 100644 --- a/llms/bedrock/internal/bedrockclient/provider_meta.go +++ b/llms/bedrock/internal/bedrockclient/provider_meta.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/tmc/langchaingo/llms" ) @@ -64,14 +63,7 @@ func createMetaCompletion(ctx context.Context, return nil, err } - modelInput := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(modelID), - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - Body: body, - } - - resp, err := client.InvokeModel(ctx, modelInput) + resp, err := invokeModel(ctx, client, modelID, body, options) if err != nil { return nil, err } diff --git a/llms/bedrock/internal/bedrockclient/provider_nova.go b/llms/bedrock/internal/bedrockclient/provider_nova.go index d452a139b..ba2ceed4c 100644 --- a/llms/bedrock/internal/bedrockclient/provider_nova.go +++ b/llms/bedrock/internal/bedrockclient/provider_nova.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/tmc/langchaingo/llms" ) @@ -161,13 +160,7 @@ func createNovaCompletion(ctx context.Context, return nil, errors.New("streaming not implemented for nova") } - modelInput := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(modelID), - Accept: aws.String("*/*"), - ContentType: aws.String("application/json"), - Body: body, - } - resp, err := client.InvokeModel(ctx, modelInput) + resp, err := invokeModel(ctx, client, modelID, body, options) if err != nil { return nil, err } diff --git a/llms/options.go b/llms/options.go index 87b35359c..42cc6dc0f 100644 --- a/llms/options.go +++ b/llms/options.go @@ -73,6 +73,9 @@ type CallOptions struct { // WebSearchOptions configures web search behavior for models that support it. // Currently supported by OpenAI models like gpt-4o-search-preview. WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` + + // Safety configures provider-defined safety controls for the invocation. + SafetyConfig map[string]any `json:"safety,omitempty"` } // Tool is a tool that can be used by the model. diff --git a/llms/options_test.go b/llms/options_test.go index 86cc88fcb..9bc039873 100644 --- a/llms/options_test.go +++ b/llms/options_test.go @@ -160,6 +160,24 @@ func TestCallOptions(t *testing.T) { //nolint:funlen // comprehensive test } }, }, + { + name: "WithSafetyConfig", + option: llms.WithSafetyConfig(map[string]any{ + "identifier": "safety-profile", + "version": "v1", + }), + verify: func(t *testing.T, opts llms.CallOptions) { + if opts.SafetyConfig == nil { + t.Fatal("SafetyConfig = nil, want non-nil") + } + if opts.SafetyConfig["identifier"] != "safety-profile" { + t.Errorf("SafetyConfig[identifier] = %v, want %v", opts.SafetyConfig["identifier"], "safety-profile") + } + if opts.SafetyConfig["version"] != "v1" { + t.Errorf("SafetyConfig[version] = %v, want %v", opts.SafetyConfig["version"], "v1") + } + }, + }, } for _, tt := range tests { diff --git a/llms/safety.go b/llms/safety.go new file mode 100644 index 000000000..53be94400 --- /dev/null +++ b/llms/safety.go @@ -0,0 +1,10 @@ +package llms + +// WithSafetyConfig adds safety configuration to call options. +func WithSafetyConfig(config map[string]any) CallOption { + return func(opts *CallOptions) { + if config != nil { + opts.SafetyConfig = config + } + } +} diff --git a/llms/safety_test.go b/llms/safety_test.go new file mode 100644 index 000000000..9dd8e45b0 --- /dev/null +++ b/llms/safety_test.go @@ -0,0 +1,60 @@ +package llms_test + +import ( + "testing" + + "github.com/tmc/langchaingo/llms" +) + +func TestWithSafetyConfig(t *testing.T) { + var opts llms.CallOptions + + llms.WithSafetyConfig(map[string]any{ + "identifier": "safe-profile", + "version": "v1", + })(&opts) + + if opts.SafetyConfig == nil { + t.Fatal("expected safety config to be present") + } + if opts.SafetyConfig["identifier"] != "safe-profile" { + t.Fatalf("identifier = %q, want %q", opts.SafetyConfig["identifier"], "safe-profile") + } + if opts.SafetyConfig["version"] != "v1" { + t.Fatalf("version = %q, want %q", opts.SafetyConfig["version"], "v1") + } +} + +func TestWithSafetyConfig_Nil(t *testing.T) { + var opts llms.CallOptions + + llms.WithSafetyConfig(nil)(&opts) + + if opts.SafetyConfig != nil { + t.Fatal("expected safety config to be nil") + } +} + +func TestWithSafetyConfig_Overwrite(t *testing.T) { + var opts llms.CallOptions + + llms.WithSafetyConfig(map[string]any{ + "identifier": "safe-profile", + "version": "v1", + })(&opts) + + llms.WithSafetyConfig(map[string]any{ + "identifier": "new-profile", + "version": "v2", + })(&opts) + + if opts.SafetyConfig == nil { + t.Fatal("expected safety config to be present") + } + if opts.SafetyConfig["identifier"] != "new-profile" { + t.Fatalf("identifier = %q, want %q", opts.SafetyConfig["identifier"], "new-profile") + } + if opts.SafetyConfig["version"] != "v2" { + t.Fatalf("version = %q, want %q", opts.SafetyConfig["version"], "v2") + } +}