Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(chore)APM: Refactor Bedrock Integration #5137

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
/packages/datadog-plugin-langchain/ @DataDog/ml-observability
/packages/datadog-instrumentations/src/openai.js @DataDog/ml-observability
/packages/datadog-instrumentations/src/langchain.js @DataDog/ml-observability
/packages/datadog-plugin-aws-sdk/src/services/bedrockruntime @DataDog/ml-observability
yahya-mouman marked this conversation as resolved.
Show resolved Hide resolved
/packages/datadog-plugin-aws-sdk/test/bedrockruntime.spec.js @DataDog/ml-observability

# CI
/.github/workflows/appsec.yml @DataDog/asm-js
Expand Down
4 changes: 3 additions & 1 deletion packages/datadog-instrumentations/src/aws-sdk.js
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ function getMessage (request, error, result) {
}

function getChannelSuffix (name) {
// some resource identifiers have spaces between ex: bedrock runtime
name = name.split(' ').join('')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
name = name.split(' ').join('')
name = name.replaceAll(' ', '')

This should be quicker as it skips the intermediary array representation.

return [
'cloudwatchlogs',
'dynamodb',
Expand All @@ -168,7 +170,7 @@ function getChannelSuffix (name) {
'sqs',
'states',
'stepfunctions',
'bedrock runtime'
'bedrockruntime'
].includes(name)
? name
: 'default'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
const CompositePlugin = require('../../../../dd-trace/src/plugins/composite')
const BedrockRuntimeTracing = require('./tracing')
class BedrockRuntimePlugin extends CompositePlugin {
static get id () {
return 'bedrockruntime'
}

static get plugins () {
return {
tracing: BedrockRuntimeTracing
}
}
}
module.exports = BedrockRuntimePlugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
'use strict'

const BaseAwsSdkPlugin = require('../../base')
const { parseModelId, extractRequestParams, extractTextAndResponseReason } = require('./utils')

const enabledOperations = ['invokeModel']

class BedrockRuntime extends BaseAwsSdkPlugin {
static get id () { return 'bedrockruntime' }

isEnabled (request) {
const operation = request.operation
if (!enabledOperations.includes(operation)) {
return false
}

return super.isEnabled(request)
}

generateTags (params, operation, response) {
const { modelProvider, modelName } = parseModelId(params.modelId)

const requestParams = extractRequestParams(params, modelProvider)
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName)

const tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)

return tags
}
}

function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
const tags = {}

// add request tags
tags['resource.name'] = operation
tags['aws.bedrock.request.model'] = modelName
tags['aws.bedrock.request.model_provider'] = modelProvider.toLowerCase()
tags['aws.bedrock.request.prompt'] = requestParams.prompt
tags['aws.bedrock.request.temperature'] = requestParams.temperature
tags['aws.bedrock.request.top_p'] = requestParams.topP
tags['aws.bedrock.request.top_k'] = requestParams.topK
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
tags['aws.bedrock.request.input_type'] = requestParams.inputType
tags['aws.bedrock.request.truncate'] = requestParams.truncate
tags['aws.bedrock.request.stream'] = requestParams.stream
tags['aws.bedrock.request.n'] = requestParams.n

// add response tags
if (modelName.includes('embed')) {
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
}
if (textAndResponseReason.choiceId) {
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
}
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason

return tags
}

module.exports = BedrockRuntime
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
'use strict'

const BaseAwsSdkPlugin = require('../base')
const log = require('../../../dd-trace/src/log')
const log = require('../../../../dd-trace/src/log')

const MODEL_TYPE_IDENTIFIERS = [
'foundation-model/',
'custom-model/',
'provisioned-model/',
'imported-module/',
'prompt/',
'endpoint/',
'inference-profile/',
'default-prompt-router/'
]

const PROVIDER = {
AI21: 'AI21',
Expand All @@ -13,44 +23,6 @@ const PROVIDER = {
MISTRAL: 'MISTRAL'
}

const enabledOperations = ['invokeModel']

class BedrockRuntime extends BaseAwsSdkPlugin {
static get id () { return 'bedrock runtime' }

isEnabled (request) {
const operation = request.operation
if (!enabledOperations.includes(operation)) {
return false
}

return super.isEnabled(request)
}

generateTags (params, operation, response) {
let tags = {}
let modelName = ''
let modelProvider = ''
const modelMeta = params.modelId.split('.')
if (modelMeta.length === 2) {
[modelProvider, modelName] = modelMeta
modelProvider = modelProvider.toUpperCase()
} else {
[, modelProvider, modelName] = modelMeta
modelProvider = modelProvider.toUpperCase()
}

const shouldSetChoiceIds = modelProvider === PROVIDER.COHERE && !modelName.includes('embed')

const requestParams = extractRequestParams(params, modelProvider)
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName, shouldSetChoiceIds)

tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)

return tags
}
}

class Generation {
constructor ({ message = '', finishReason = '', choiceId = '' } = {}) {
// stringify message as it could be a single generated message as well as a list of embeddings
Expand All @@ -65,18 +37,19 @@ class RequestParams {
prompt = '',
temperature = undefined,
topP = undefined,
topK = undefined,
maxTokens = undefined,
stopSequences = [],
inputType = '',
truncate = '',
stream = '',
n = undefined
} = {}) {
// TODO: set a truncation limit to prompt
// stringify prompt as it could be a single prompt as well as a list of message objects
this.prompt = typeof prompt === 'string' ? prompt : JSON.stringify(prompt) || ''
this.temperature = temperature !== undefined ? temperature : undefined
this.topP = topP !== undefined ? topP : undefined
this.topK = topK !== undefined ? topK : undefined
this.maxTokens = maxTokens !== undefined ? maxTokens : undefined
this.stopSequences = stopSequences || []
this.inputType = inputType || ''
Expand All @@ -86,11 +59,53 @@ class RequestParams {
}
}

function parseModelId (modelId) {
// Best effort to extract the model provider and model name from the bedrock model ID.
// modelId can be a 1/2 period-separated string or a full AWS ARN, based on the following formats:
// 1. Base model: "{model_provider}.{model_name}"
// 2. Cross-region model: "{region}.{model_provider}.{model_name}"
// 3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:"
// a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}"
// b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}"
// c. Provisioned model: ARN prefix + "provisioned-model/{model-id}"
// d. Imported model: ARN prefix + "imported-module/{model-id}"
// e. Prompt management: ARN prefix + "prompt/{prompt-id}"
// f. Sagemaker: ARN prefix + "endpoint/{model-id}"
// g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}"
// h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}"
// If model provider cannot be inferred from the modelId formatting, then default to "custom"
modelId = modelId.toLowerCase()
if (!modelId.startsWith('arn:aws')) {
const modelMeta = modelId.split('.')
if (modelMeta.length < 2) {
return { modelProvider: 'custom', modelName: modelMeta[0] }
}
return { modelProvider: modelMeta[modelMeta.length - 2], modelName: modelMeta[modelMeta.length - 1] }
}

for (const identifier of MODEL_TYPE_IDENTIFIERS) {
if (!modelId.includes(identifier)) {
continue
}
modelId = modelId.split(identifier).pop()
if (['foundation-model/', 'custom-model/'].includes(identifier)) {
const modelMeta = modelId.split('.')
if (modelMeta.length < 2) {
return { modelProvider: 'custom', modelName: modelId }
}
return { modelProvider: modelMeta[modelMeta.length - 2], modelName: modelMeta[modelMeta.length - 1] }
}
return { modelProvider: 'custom', modelName: modelId }
}

return { modelProvider: 'custom', modelName: 'custom' }
}

function extractRequestParams (params, provider) {
const requestBody = JSON.parse(params.body)
const modelId = params.modelId

switch (provider) {
switch (provider.toUpperCase()) {
case PROVIDER.AI21: {
let userPrompt = requestBody.prompt
if (modelId.includes('jamba')) {
Expand Down Expand Up @@ -176,11 +191,11 @@ function extractRequestParams (params, provider) {
}
}

function extractTextAndResponseReason (response, provider, modelName, shouldSetChoiceIds) {
function extractTextAndResponseReason (response, provider, modelName) {
const body = JSON.parse(Buffer.from(response.body).toString('utf8'))

const shouldSetChoiceIds = provider.toUpperCase() === PROVIDER.COHERE && !modelName.includes('embed')
try {
switch (provider) {
switch (provider.toUpperCase()) {
case PROVIDER.AI21: {
if (modelName.includes('jamba')) {
const generations = body.choices || []
Expand Down Expand Up @@ -262,34 +277,11 @@ function extractTextAndResponseReason (response, provider, modelName, shouldSetC
return new Generation()
}

function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
const tags = {}

// add request tags
tags['resource.name'] = operation
tags['aws.bedrock.request.model'] = modelName
tags['aws.bedrock.request.model_provider'] = modelProvider
tags['aws.bedrock.request.prompt'] = requestParams.prompt
tags['aws.bedrock.request.temperature'] = requestParams.temperature
tags['aws.bedrock.request.top_p'] = requestParams.topP
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
tags['aws.bedrock.request.input_type'] = requestParams.inputType
tags['aws.bedrock.request.truncate'] = requestParams.truncate
tags['aws.bedrock.request.stream'] = requestParams.stream
tags['aws.bedrock.request.n'] = requestParams.n

// add response tags
if (modelName.includes('embed')) {
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
}
if (textAndResponseReason.choiceId) {
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
}
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason

return tags
module.exports = {
Generation,
RequestParams,
parseModelId,
extractRequestParams,
extractTextAndResponseReason,
PROVIDER
}

module.exports = BedrockRuntime
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const PROVIDER = {
}

describe('Plugin', () => {
describe('aws-sdk (bedrock)', function () {
describe('aws-sdk (bedrockruntime)', function () {
setup()

withVersions('aws-sdk', ['@aws-sdk/smithy-client', 'aws-sdk'], '>=3', (version, moduleName) => {
Expand Down Expand Up @@ -217,7 +217,7 @@ describe('Plugin', () => {
expect(span.meta).to.include({
'aws.operation': 'invokeModel',
'aws.bedrock.request.model': model.modelId.split('.')[1],
'aws.bedrock.request.model_provider': model.provider,
'aws.bedrock.request.model_provider': model.provider.toLowerCase(),
'aws.bedrock.request.prompt': model.userPrompt
})
expect(span.metrics).to.include({
Expand Down
Loading