Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,31 @@ const client = new OpenAI({
});
```

### AWS Bedrock

Requires the AWS SDK peer dependencies (`npm install @aws-sdk/credential-providers @smithy/signature-v4 @aws-crypto/sha256-js`). Credentials are resolved from the [standard AWS credential chain](https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html).

```ts
import OpenAI from 'openai';
import { awsBedrockTokenProvider } from 'openai/auth';

const client = new OpenAI({
baseURL: 'https://bedrock-mantle.us-east-1.api.aws/v1', // region must match the token provider
apiKey: awsBedrockTokenProvider({
region: 'us-east-1',
profile: 'my-profile', // optional — defaults to the standard AWS credential chain
}),
});

// List models supported by the OpenAI-compatible endpoint
const models = await client.models.list();
for (const model of models.data) {
console.log(model.id);
}
```

> **Note:** The OpenAI SDK works only with Bedrock models that have the [OpenAI-compatible API](https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-mantle.html) enabled. Use `client.models.list()` to see which models are available on your endpoint.

### Custom subject token provider

```ts
Expand Down
14 changes: 13 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,26 @@
},
"peerDependencies": {
"ws": "^8.18.0",
"zod": "^3.25 || ^4.0"
"zod": "^3.25 || ^4.0",
"@aws-sdk/credential-providers": "^3.0.0",
"@smithy/signature-v4": "^3.0.0 || ^4.0.0 || ^5.0.0",
"@aws-crypto/sha256-js": "^3.0.0 || ^4.0.0 || ^5.0.0"
},
"peerDependenciesMeta": {
"ws": {
"optional": true
},
"zod": {
"optional": true
},
"@aws-sdk/credential-providers": {
"optional": true
},
"@smithy/signature-v4": {
"optional": true
},
"@aws-crypto/sha256-js": {
"optional": true
}
}
}
1 change: 1 addition & 0 deletions src/auth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export {
k8sServiceAccountTokenProvider,
azureManagedIdentityTokenProvider,
gcpIDTokenProvider,
awsBedrockTokenProvider,
} from './subject-token-providers';

export { OAuthError, SubjectTokenProviderError } from '../core/error';
118 changes: 118 additions & 0 deletions src/auth/subject-token-providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import type { SubjectTokenProvider } from './types';
import type { Fetch } from '../internal/builtin-types';
import * as Shims from '../internal/shims';
import { SubjectTokenProviderError } from '../core/error';
import { toBase64 } from '../internal/utils/base64';
import { readEnv } from '../internal/utils/env';

const DEFAULT_RESOURCE = 'https://management.azure.com/';
const DEFAULT_AZURE_API_VERSION = '2018-02-01';
Expand Down Expand Up @@ -183,3 +185,119 @@ export function gcpIDTokenProvider(
},
};
}

/**
* Get a token provider for AWS Bedrock using IAM credentials.
*
* Returns an async callable that generates a bearer token from a SigV4 presigned URL.
* Pass it directly to `apiKey` when creating an OpenAI client pointed at a
* Bedrock runtime endpoint. Credentials are resolved from the standard AWS credential chain:
* https://docs.aws.amazon.com/sdkref/latest/guide/standardized-credentials.html
*
* The AWS SDK modules are cached so import resolution is efficient, while the token
* itself is regenerated on each call to ensure it always reflects the latest valid
* credentials (important for short-lived STS/assumed-role sessions).
*
* @param config.region - AWS region. Defaults to `AWS_REGION` or `AWS_DEFAULT_REGION` environment variable.
* @param config.profile - AWS profile name. If not set, credentials are resolved from the standard chain.
* @param config.tokenDuration - Presigned URL expiry in seconds. Defaults to 3600 (1 hour).
*/
export function awsBedrockTokenProvider(config?: {
region?: string;
profile?: string;
tokenDuration?: number;
}): () => Promise<string> {
const tokenDuration = config?.tokenDuration ?? 3600;

let cachedModules: { credProviders: any; SignatureV4Cls: any; Sha256Cls: any } | null = null;

async function getAwsModules() {
if (cachedModules) return cachedModules;

try {
const [credModule, sigV4Module, sha256Module] = await Promise.all([
import('@aws-sdk/credential-providers' as any),
import('@smithy/signature-v4' as any),
import('@aws-crypto/sha256-js' as any),
]);
cachedModules = {
credProviders: credModule,
SignatureV4Cls: sigV4Module.SignatureV4,
Sha256Cls: sha256Module.Sha256,
};
return cachedModules;
} catch (e) {
throw new Error(
'@aws-sdk/credential-providers, @smithy/signature-v4, and @aws-crypto/sha256-js are required ' +
'for AWS Bedrock token generation. Install them with: ' +
'npm install @aws-sdk/credential-providers @smithy/signature-v4 @aws-crypto/sha256-js',
);
}
}

return async (): Promise<string> => {
const { credProviders, SignatureV4Cls, Sha256Cls } = await getAwsModules();

try {
const resolvedRegion = config?.region || readEnv('AWS_REGION') || readEnv('AWS_DEFAULT_REGION');
if (!resolvedRegion) {
throw new SubjectTokenProviderError(
"AWS region must be provided via the 'region' parameter, " +
'or the AWS_REGION / AWS_DEFAULT_REGION environment variable.',
'aws-bedrock',
);
}

const credentialProvider =
config?.profile ?
credProviders.fromIni({ profile: config.profile })
: credProviders.fromNodeProviderChain();

const credentials = await credentialProvider();

const signer = new SignatureV4Cls({
service: 'bedrock',
region: resolvedRegion,
credentials,
sha256: Sha256Cls,
});

const request = {
method: 'POST',
hostname: 'bedrock.amazonaws.com',
path: '/',
query: { Action: 'CallWithBearerToken' },
headers: {
host: 'bedrock.amazonaws.com',
},
protocol: 'https:',
};

const presigned = await signer.presign(request, {
expiresIn: tokenDuration,
});
Comment on lines +276 to +278
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Validate Bedrock token TTL against service limits

tokenDuration is passed directly to signer.presign without bounds checking, so callers can request values above Bedrock’s short-term bearer-token limit (12 hours) and receive tokens that are later rejected at request time. This creates a runtime auth failure path that is avoidable and hard to diagnose from the caller side; validating/clamping to the accepted range before signing would fail fast with a clear provider error.

Useful? React with 👍 / 👎.


// Reconstruct the signed URL from the presigned request
const queryParams = presigned.query as Record<string, string>;
const queryString = Object.entries(queryParams)
.map(([k, v]: [string, string]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`)
.join('&');
const signedUrl = `https://bedrock.amazonaws.com/?${queryString}`;

// Strip https:// prefix, append Version=1, and base64-encode
const urlWithoutScheme = signedUrl.slice('https://'.length);
const encodedToken = toBase64(`${urlWithoutScheme}&Version=1`);

return `bedrock-api-key-${encodedToken}`;
} catch (e) {
if (e instanceof SubjectTokenProviderError) {
throw e;
}
throw new SubjectTokenProviderError(
`Failed to generate AWS Bedrock token: ${e instanceof Error ? e.message : String(e)}`,
'aws-bedrock',
e instanceof Error ? e : undefined,
);
}
};
}
190 changes: 190 additions & 0 deletions tests/auth/subject-token-providers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {
k8sServiceAccountTokenProvider,
azureManagedIdentityTokenProvider,
gcpIDTokenProvider,
awsBedrockTokenProvider,
} from 'openai/auth/subject-token-providers';
import { SubjectTokenProviderError } from 'openai';

Expand Down Expand Up @@ -201,3 +202,192 @@ describe('GCP Metadata Server Token Provider', () => {
await expect(provider.getToken()).rejects.toThrow('Failed to fetch token from GCP Metadata Server');
});
});

function makeMockAwsSdk(opts?: {
accessKeyId?: string;
secretAccessKey?: string;
sessionToken?: string;
noCredentials?: boolean;
}) {
const credentials =
opts?.noCredentials ? null : (
{
accessKeyId: opts?.accessKeyId ?? 'AKIAIOSFODNN7EXAMPLE',
secretAccessKey: opts?.secretAccessKey ?? 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
sessionToken: opts?.sessionToken,
}
);

const mockPresign = jest.fn(async (request: any, options: any) => {
return {
...request,
query: {
...request.query,
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
'X-Amz-Credential': `${credentials?.accessKeyId}/20260428/us-east-1/bedrock/aws4_request`,
'X-Amz-Date': '20260428T000000Z',
'X-Amz-Expires': String(options?.expiresIn ?? 43200),
'X-Amz-SignedHeaders': 'host',
'X-Amz-Signature': 'fakesignature1234567890',
},
};
});

const mockCredentialProvider = jest.fn(async () => {
if (!credentials) {
throw new Error('No AWS credentials found');
}
return credentials;
});

const mockFromNodeProviderChain = jest.fn(() => mockCredentialProvider);
const mockFromIni = jest.fn((_opts: any) => mockCredentialProvider);

const mockSignatureV4 = jest.fn().mockImplementation(() => ({
presign: mockPresign,
}));

const mockSha256 = jest.fn();

return {
credProviders: {
fromNodeProviderChain: mockFromNodeProviderChain,
fromIni: mockFromIni,
},
sigV4: {
SignatureV4: mockSignatureV4,
},
sha256: {
Sha256: mockSha256,
},
mocks: {
presign: mockPresign,
credentialProvider: mockCredentialProvider,
fromNodeProviderChain: mockFromNodeProviderChain,
fromIni: mockFromIni,
SignatureV4: mockSignatureV4,
},
};
}

describe('AWS Bedrock Token Provider', () => {
const originalEnv = process.env;

beforeEach(() => {
jest.clearAllMocks();
jest.resetModules();
process.env = { ...originalEnv };
});

afterEach(() => {
process.env = originalEnv;
});

test('generates a valid bedrock token', async () => {
const aws = makeMockAwsSdk();

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

const getToken = awsBedrockTokenProvider({ region: 'us-east-1' });
const token = await getToken();

expect(token.startsWith('bedrock-api-key-')).toBe(true);

const encodedPart = token.slice('bedrock-api-key-'.length);
const decodedUrl = Buffer.from(encodedPart, 'base64').toString('utf-8');

expect(decodedUrl).toContain('bedrock.amazonaws.com');
expect(decodedUrl).toContain('X-Amz-Signature=');
expect(decodedUrl).toContain('X-Amz-Credential=');
expect(decodedUrl).toContain('Action=CallWithBearerToken');
expect(decodedUrl).toContain('&Version=1');
});

test('uses custom region in the signed request', async () => {
const aws = makeMockAwsSdk();

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

const getToken = awsBedrockTokenProvider({ region: 'eu-west-1' });
await getToken();

expect(aws.mocks.SignatureV4).toHaveBeenCalledWith(
expect.objectContaining({ region: 'eu-west-1', service: 'bedrock' }),
);
});

test('uses profile when provided', async () => {
const aws = makeMockAwsSdk();

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

const getToken = awsBedrockTokenProvider({ region: 'us-east-1', profile: 'my-profile' });
await getToken();

expect(aws.mocks.fromIni).toHaveBeenCalledWith({ profile: 'my-profile' });
expect(aws.mocks.fromNodeProviderChain).not.toHaveBeenCalled();
});

test('throws SubjectTokenProviderError when no credentials found', async () => {
const aws = makeMockAwsSdk({ noCredentials: true });

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

const getToken = awsBedrockTokenProvider({ region: 'us-east-1' });
await expect(getToken()).rejects.toThrow(SubjectTokenProviderError);
await expect(getToken()).rejects.toThrow('Failed to generate AWS Bedrock token');
});

test('throws SubjectTokenProviderError when region is not set', async () => {
const aws = makeMockAwsSdk();

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

delete process.env['AWS_REGION'];
delete process.env['AWS_DEFAULT_REGION'];

const getToken = awsBedrockTokenProvider();
await expect(getToken()).rejects.toThrow(SubjectTokenProviderError);
await expect(getToken()).rejects.toThrow('AWS region must be provided');
});

test('resolves region from AWS_REGION env var', async () => {
const aws = makeMockAwsSdk();

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

process.env['AWS_REGION'] = 'ap-southeast-1';

const getToken = awsBedrockTokenProvider();
await getToken();

expect(aws.mocks.SignatureV4).toHaveBeenCalledWith(expect.objectContaining({ region: 'ap-southeast-1' }));
});

test('regenerates token on each call (no caching)', async () => {
const aws = makeMockAwsSdk();

jest.mock('@aws-sdk/credential-providers', () => aws.credProviders, { virtual: true });
jest.mock('@smithy/signature-v4', () => aws.sigV4, { virtual: true });
jest.mock('@aws-crypto/sha256-js', () => aws.sha256, { virtual: true });

const getToken = awsBedrockTokenProvider({ region: 'us-east-1' });
await getToken();
await getToken();

// presign should be called each time — no token caching
expect(aws.mocks.presign).toHaveBeenCalledTimes(2);
});
});
Loading