Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ as-any = "0.3.2"
assert_fs = "1.1.3"
async-stream = "0.3.6"
aws-config = { version = "1.8.5", default-features = false }
aws-sdk-bedrockruntime = { version = "1.102.0", default-features = false }
aws-sdk-bedrockruntime = { version = "1.124.0", default-features = false }
aws-smithy-types = "1.3.2"
base64 = "0.22.1"
bytes = "1.10.1"
Expand Down
3 changes: 2 additions & 1 deletion rig-integrations/rig-bedrock/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ impl completion::CompletionModel for CompletionModel {
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt()?)
.set_messages(Some(messages));
.set_messages(Some(messages))
.set_output_config(request.output_config());

async move {
let response = converse_builder.send().await.map_err(|sdk_error| {
Expand Down
3 changes: 2 additions & 1 deletion rig-integrations/rig-bedrock/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ impl CompletionModel {
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt()?)
.set_messages(Some(prompt_with_history));
.set_messages(Some(prompt_with_history))
.set_output_config(request.output_config());

let response = converse_builder.send().await.map_err(|sdk_error| {
Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
Expand Down
104 changes: 104 additions & 0 deletions rig-integrations/rig-bedrock/src/types/completion_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,37 @@ impl AwsCompletionRequest {
}
}

/// Maps rig's `output_schema` to Bedrock's `OutputConfig` for structured output.
pub fn output_config(&self) -> Option<aws_bedrock::OutputConfig> {
self.inner.output_schema.as_ref().map(|schema| {
let schema_name = self
.inner
.output_schema_name()
.unwrap_or_else(|| "response_schema".to_string());

let schema_json = serde_json::to_string(&schema.clone().to_value())
.unwrap_or_else(|_| "{}".to_string());

let json_schema_def = aws_bedrock::JsonSchemaDefinition::builder()
.schema(schema_json)
.name(schema_name)
.build()
.expect("schema field is set");

aws_bedrock::OutputConfig::builder()
.text_format(
aws_bedrock::OutputFormat::builder()
.r#type(aws_bedrock::OutputFormatType::JsonSchema)
.structure(aws_bedrock::OutputFormatStructure::JsonSchema(
json_schema_def,
))
.build()
.expect("type field is set"),
)
.build()
})
}

pub fn system_prompt(&self) -> Result<Option<Vec<SystemContentBlock>>, CompletionError> {
let mut system_blocks = Vec::new();

Expand Down Expand Up @@ -544,4 +575,77 @@ mod tests {
Some(aws_bedrock::ContentBlock::CachePoint(_))
));
}

#[test]
fn test_output_config_none_when_no_schema() {
let request = minimal_request();
let aws_request = aws_request(request, false);
assert!(aws_request.output_config().is_none());
}

#[test]
fn test_output_config_with_schema() {
let schema: schemars::Schema = serde_json::from_value(serde_json::json!({
"type": "object",
"title": "WeatherResponse",
"properties": {
"temperature": { "type": "number" },
"unit": { "type": "string", "enum": ["celsius", "fahrenheit"] }
},
"required": ["temperature", "unit"]
}))
.expect("valid schema");

let request = CompletionRequest {
output_schema: Some(schema),
..minimal_request()
};

let aws_request = aws_request(request, false);
let output_config = aws_request.output_config();

assert!(output_config.is_some());
let config = output_config.unwrap();
let text_format = config.text_format().expect("text_format should be set");
assert_eq!(
*text_format.r#type(),
aws_bedrock::OutputFormatType::JsonSchema
);

let structure = text_format.structure().expect("structure should be set");
let json_schema = structure
.as_json_schema()
.expect("should be JsonSchema variant");
assert_eq!(json_schema.name(), Some("WeatherResponse"));

let parsed: serde_json::Value =
serde_json::from_str(json_schema.schema()).expect("schema should be valid JSON");
assert_eq!(parsed["type"], "object");
assert!(parsed["properties"]["temperature"].is_object());
}

#[test]
fn test_output_config_uses_default_name() {
let schema: schemars::Schema = serde_json::from_value(serde_json::json!({
"type": "object",
"properties": {
"result": { "type": "string" }
}
}))
.expect("valid schema");

let request = CompletionRequest {
output_schema: Some(schema),
..minimal_request()
};

let aws_request = aws_request(request, false);
let config = aws_request.output_config().expect("should have config");
let text_format = config.text_format().expect("text_format should be set");
let structure = text_format.structure().expect("structure should be set");
let json_schema = structure
.as_json_schema()
.expect("should be JsonSchema variant");
assert_eq!(json_schema.name(), Some("response_schema"));
}
}