Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions chains/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/httputil"
"github.com/tmc/langchaingo/internal/httprr"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/googleai"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
Expand Down Expand Up @@ -65,6 +66,31 @@ func TestLLMChain(t *testing.T) {
require.True(t, strings.Contains(result, "Paris"))
}

type errorLanguageModel struct {
err error
}

func (m *errorLanguageModel) Call(_ context.Context, _ string, _ ...llms.CallOption) (string, error) {
return "", m.err
}

func (m *errorLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageContent, _ ...llms.CallOption) (*llms.ContentResponse, error) {
return nil, m.err
}

func TestLLMChainPropagatesContentFilterError(t *testing.T) {
t.Parallel()

chain := NewLLMChain(
&errorLanguageModel{err: llms.NewError(llms.ErrCodeContentFilter, "bedrock", "blocked")},
prompts.NewPromptTemplate("{{.text}}", []string{"text"}),
)

_, err := chain.Call(context.Background(), map[string]any{"text": "unsafe prompt"})
require.Error(t, err)
require.True(t, llms.IsContentFilterError(err))
}

func TestLLMChainWithChatPromptTemplate(t *testing.T) {
ctx := context.Background()
t.Parallel()
Expand Down
17 changes: 17 additions & 0 deletions chains/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ type chainCallOption struct {
RepetitionPenalty float64
repetitionPenaltySet bool

// Safety configures provider-defined safety controls in an LLM call.
SafetyConfig map[string]any
safetyConfigSet bool

// CallbackHandler is the callback handler for Chain
CallbackHandler callbacks.Handler
}
Expand Down Expand Up @@ -146,6 +150,16 @@ func WithRepetitionPenalty(repetitionPenalty float64) ChainCallOption {
}
}

// WithSafetyConfig configures provider-defined safety controls for the LLM call.
func WithSafetyConfig(config map[string]any) ChainCallOption {
return func(o *chainCallOption) {
if config != nil {
o.SafetyConfig = config
o.safetyConfigSet = true
}
}
}

// WithStopWords is an option for setting the stop words for LLM.Call.
func WithStopWords(stopWords []string) ChainCallOption {
return func(o *chainCallOption) {
Expand Down Expand Up @@ -208,6 +222,9 @@ func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint:
if opts.repetitionPenaltySet {
chainCallOption = append(chainCallOption, llms.WithRepetitionPenalty(opts.RepetitionPenalty))
}
if opts.safetyConfigSet {
chainCallOption = append(chainCallOption, llms.WithSafetyConfig(opts.SafetyConfig))
}
chainCallOption = append(chainCallOption, llms.WithStreamingFunc(opts.StreamingFunc))

return chainCallOption
Expand Down
147 changes: 147 additions & 0 deletions llms/bedrock/internal/bedrockclient/guardrail.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package bedrockclient

import (
"encoding/json"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/tmc/langchaingo/llms"
)

const bedrockGuardrailIntervened = "INTERVENED"

type bedrockGuardrailParams struct {
identifier *string
version *string
tagSuffix *string
streamProcessingMode *string
}

func getRequiredGuardrailParam(safety map[string]any, key string) (*string, error) {
value, ok := safety[key]
if !ok {
return nil, fmt.Errorf("bedrock guardrail %s is required when safety config is provided", key)
}

str, ok := value.(string)
if !ok || strings.TrimSpace(str) == "" {
return nil, fmt.Errorf("bedrock guardrail %s must be a non-empty string", key)
}

return aws.String(str), nil
}

func getOptionalGuardrailParam(safety map[string]any, key string) (*string, error) {
value, ok := safety[key]
if !ok {
return nil, nil
}

str, ok := value.(string)
if !ok || strings.TrimSpace(str) == "" {
return nil, fmt.Errorf("bedrock guardrail %s must be a non-empty string if provided", key)
}

return aws.String(str), nil
}

func getGuardrailParams(options llms.CallOptions) (bedrockGuardrailParams, error) {
if options.SafetyConfig == nil {
return bedrockGuardrailParams{}, nil
}

identifier, err := getRequiredGuardrailParam(options.SafetyConfig, "identifier")
if err != nil {
return bedrockGuardrailParams{}, err
}
version, err := getRequiredGuardrailParam(options.SafetyConfig, "version")
if err != nil {
return bedrockGuardrailParams{}, err
}

tagSuffix, err := getOptionalGuardrailParam(options.SafetyConfig, "tagSuffix")
if err != nil {
return bedrockGuardrailParams{}, err
}
streamProcessingMode, err := getOptionalGuardrailParam(options.SafetyConfig, "streamProcessingMode")
if err != nil {
return bedrockGuardrailParams{}, err
}

return bedrockGuardrailParams{
identifier: identifier,
version: version,
tagSuffix: tagSuffix,
streamProcessingMode: streamProcessingMode,
}, nil
}

func handleGuardrailParams(input guardrailConfigurableAdapter, options llms.CallOptions) error {
guardrailParams, err := getGuardrailParams(options)
if err != nil {
return fmt.Errorf("failed to get guardrail parameters: %w", err)
}

if guardrailParams.identifier != nil && guardrailParams.version != nil {
input.setGuardrailIdentifier(guardrailParams.identifier)
input.setGuardrailVersion(guardrailParams.version)
}

if guardrailParams.tagSuffix != nil || guardrailParams.streamProcessingMode != nil {
var m map[string]any
if err := json.Unmarshal(input.body(), &m); err != nil {
return fmt.Errorf("failed to unmarshal input body when adding guardrail parameters: %w", err)
}

var config = make(map[string]string)
if guardrailParams.tagSuffix != nil {
config["tagSuffix"] = *guardrailParams.tagSuffix
}
if guardrailParams.streamProcessingMode != nil {
config["streamProcessingMode"] = *guardrailParams.streamProcessingMode
}
m["amazon-bedrock-guardrailConfig"] = config

newBody, err := json.Marshal(m)
if err != nil {
return fmt.Errorf("failed to marshal input body when adding guardrail parameters: %w", err)
}
input.setBody(newBody)
}
return nil
}

func checkGuardrailError(response any) error {
var action string
var trace json.RawMessage

switch r := response.(type) {
case []byte:
var resp guardrailResponse
if err := json.Unmarshal(r, &resp); err != nil {
// If we can't parse the response, we assume it's not a guardrail error and return nil
return nil
}
action = resp.AmazonBedrockGuardrailAction
trace = resp.AmazonBedrockTrace
case streamingCompletionResponseChunk:
action = r.AmazonBedrockGuardrailAction
trace = r.AmazonBedrockTrace
case *streamingCompletionResponseChunk:
action = r.AmazonBedrockGuardrailAction
trace = r.AmazonBedrockTrace
default:
return nil
}

if action == bedrockGuardrailIntervened {
err := llms.NewError(llms.ErrCodeContentFilter, "bedrock", "content blocked by Bedrock guardrail")
err.WithDetail("guardrail_action", action)
if len(trace) > 0 {
err.WithDetail("trace", string(trace))
}
return err
}
return nil
}
154 changes: 154 additions & 0 deletions llms/bedrock/internal/bedrockclient/invoke.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package bedrockclient

import (
"context"
"encoding/json"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/llms"
)

const (
bedrockAcceptAll = "*/*"
bedrockJSONContentType = "application/json"
)

type guardrailResponse struct {
AmazonBedrockGuardrailAction string `json:"amazon-bedrock-guardrailAction"`
AmazonBedrockTrace json.RawMessage `json:"amazon-bedrock-trace"`
}

type streamingCompletionResponseChunk struct {
Type string `json:"type"`
Index int `json:"index"`
Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason string `json:"stop_reason"`
StopSequence any `json:"stop_sequence"`
} `json:"delta"`
AmazonBedrockInvocationMetrics struct {
InputTokenCount int `json:"inputTokenCount"`
OutputTokenCount int `json:"outputTokenCount"`
InvocationLatency int `json:"invocationLatency"`
FirstByteLatency int `json:"firstByteLatency"`
} `json:"amazon-bedrock-invocationMetrics"`
Usage struct {
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
Message struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []any `json:"content"`
Model string `json:"model"`
StopReason any `json:"stop_reason"`
StopSequence any `json:"stop_sequence"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
} `json:"message"`
guardrailResponse
}

type invokeModelAPI interface {
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
}

type invokeModelWithResponseStreamAPI interface {
InvokeModelWithResponseStream(ctx context.Context, params *bedrockruntime.InvokeModelWithResponseStreamInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelWithResponseStreamOutput, error)
}

type guardrailConfigurableAdapter interface {
body() []byte
setGuardrailIdentifier(*string)
setGuardrailVersion(*string)
setBody([]byte)
}

type guardrailConfigurableInput struct {
*bedrockruntime.InvokeModelInput
}

var _ guardrailConfigurableAdapter = (*guardrailConfigurableInput)(nil)

func (i *guardrailConfigurableInput) setGuardrailIdentifier(identifier *string) {
i.GuardrailIdentifier = identifier
}
func (i *guardrailConfigurableInput) setGuardrailVersion(version *string) {
i.GuardrailVersion = version
}
func (i *guardrailConfigurableInput) body() []byte { return i.Body }
func (i *guardrailConfigurableInput) setBody(body []byte) { i.Body = body }

type guardrailConfigurableStreamInput struct {
*bedrockruntime.InvokeModelWithResponseStreamInput
}

var _ guardrailConfigurableAdapter = (*guardrailConfigurableStreamInput)(nil)

func (i *guardrailConfigurableStreamInput) setGuardrailIdentifier(identifier *string) {
i.GuardrailIdentifier = identifier
}
func (i *guardrailConfigurableStreamInput) setGuardrailVersion(version *string) {
i.GuardrailVersion = version
}
func (i *guardrailConfigurableStreamInput) body() []byte { return i.Body }
func (i *guardrailConfigurableStreamInput) setBody(body []byte) { i.Body = body }

func newInvokeModelInput(modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelInput, error) {
input := &guardrailConfigurableInput{
&bedrockruntime.InvokeModelInput{
ModelId: aws.String(modelID),
Accept: aws.String(bedrockAcceptAll),
ContentType: aws.String(bedrockJSONContentType),
Body: body,
},
}
if err := handleGuardrailParams(input, options); err != nil {
return nil, err
}
return input.InvokeModelInput, nil
}

func newInvokeModelWithResponseStreamInput(modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelWithResponseStreamInput, error) {
input := &guardrailConfigurableStreamInput{
&bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(modelID),
Accept: aws.String(bedrockAcceptAll),
ContentType: aws.String(bedrockJSONContentType),
Body: body,
},
}
if err := handleGuardrailParams(input, options); err != nil {
return nil, err
}
return input.InvokeModelWithResponseStreamInput, nil
}

func invokeModel(ctx context.Context, client invokeModelAPI, modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelOutput, error) {
input, err := newInvokeModelInput(modelID, body, options)
if err != nil {
return nil, err
}

output, err := client.InvokeModel(ctx, input)
if err != nil {
return nil, err
}
if err := checkGuardrailError(output.Body); err != nil {
return nil, err
}

return output, nil
}

func invokeModelWithResponseStream(ctx context.Context, client invokeModelWithResponseStreamAPI, modelID string, body []byte, options llms.CallOptions) (*bedrockruntime.InvokeModelWithResponseStreamOutput, error) {
input, err := newInvokeModelWithResponseStreamInput(modelID, body, options)
if err != nil {
return nil, err
}
return client.InvokeModelWithResponseStream(ctx, input)
}
Loading
Loading