# model_choice_converse_bedrock_streamlit.py
import streamlit as st
from threading import Thread
import boto3
import pandas as pd
from botocore.config import Config

print('Boto3 version:', boto3.__version__)
### Streamlit setup
st.set_page_config(layout="wide")

# Generous read timeout and standard-mode retries: invoking many models in
# parallel can take a while, and individual calls may be throttled.
my_config = Config(
    read_timeout=600,
    retries={
        'max_attempts': 10,
        'mode': 'standard'
    }
)
### Constants
REGION = 'us-east-1'
MODEL_IDS = [
    "amazon.titan-text-premier-v1:0",
    "amazon.titan-text-express-v1",
    "amazon.titan-text-lite-v1",
    "ai21.j2-ultra-v1",
    "ai21.j2-mid-v1",
    "anthropic.claude-3-sonnet-20240229-v1:0",
    "anthropic.claude-3-haiku-20240307-v1:0",
    "cohere.command-r-plus-v1:0",
    "cohere.command-r-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "meta.llama3-8b-instruct-v1:0",
    "mistral.mistral-large-2402-v1:0",
    "mistral.mixtral-8x7b-instruct-v0:1",
    "mistral.mistral-7b-instruct-v0:2",
    "mistral.mistral-small-2402-v1:0"
]
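# Note: this list assumes access to each of these models has been granted in
# the Bedrock console for us-east-1; models without granted access will fail
# and surface as "Model invocation error" in the UI.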

### Function for invoking Bedrock Converse
def invoke_bedrock_model(client, id, prompt, max_tokens=2000, temperature=0, top_p=0.9):
    """Call the Converse API for a single model. Returns a dict with the
    request and the parsed response, or an error string on failure."""
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "text": prompt
                }
            ]
        }
    ]
    try:
        response = client.converse(
            modelId=id,
            messages=messages,
            inferenceConfig={
                "temperature": temperature,
                "maxTokens": max_tokens,
                "topP": top_p
            }
            # additionalModelRequestFields={
            # }
        )
    except Exception as e:
        print(e)
        return "Model invocation error"
    try:
        result = {
            'Request': {
                'modelId': id,
                'messages': messages,
                'inferenceConfig': {
                    'temperature': temperature,
                    'maxTokens': max_tokens,
                    'topP': top_p
                },
            },
            'Response': {
                'output': response['output'],
                'stopReason': response['stopReason'],
                'usage': response['usage'],
                'metrics': response['metrics']
            }
        }
    except Exception as e:
        print(e)
        result = "Output parsing error"
    return result
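
# A minimal sketch of calling invoke_bedrock_model outside Streamlit (assumes
# AWS credentials with Bedrock access are configured; the model id is one of
# the entries in MODEL_IDS above):
#
#   client = boto3.client("bedrock-runtime", region_name=REGION, config=my_config)
#   result = invoke_bedrock_model(client, "anthropic.claude-3-haiku-20240307-v1:0",
#                                 "Summarize the Converse API in one sentence.")
#   if isinstance(result, dict):
#       print(result['Response']['output']['message']['content'][0]['text'])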
### Class for threading calls
class ModelThread(Thread):
    def __init__(self, model_id, prompt):
        Thread.__init__(self)
        self.model_id = model_id
        self.model_response = None
        self.prompt = prompt
        # One client per thread, created up front in the main thread.
        self.client = boto3.client("bedrock-runtime", region_name=REGION, config=my_config)

    def run(self):
        response = invoke_bedrock_model(self.client, self.model_id, self.prompt)
        self.model_response = response
        print(f'{self.model_id} DONE')
### Function for invoking models in parallel
def invokeModelsInParallel(prompt):
    threads = [ModelThread(model_id=m, prompt=prompt) for m in MODEL_IDS]
    for thread in threads:
        thread.start()
    model_responses = {}
    for thread in threads:
        thread.join()
        model_responses[thread.model_id] = thread.model_response
    return model_responses
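
# The returned dict maps model id -> result. Shape is illustrative and
# trimmed; failed invocations hold an error string instead of a dict:
#
#   {
#       "meta.llama3-8b-instruct-v1:0": {
#           "Request":  {...},
#           "Response": {"output": {...}, "stopReason": "end_turn", ...}
#       },
#       "ai21.j2-ultra-v1": "Model invocation error",
#   }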
### Streamlit UI
col1, col2 = st.columns([1, 9])
with col1:
    st.image('./images/bedrock.png', width=60)
with col2:
    st.write("#### Converse API for Amazon Bedrock - Model Choice Demo")

tabs = st.tabs(["Model Responses", "Message Details"])

with tabs[0]:
    st.markdown("Write your prompt...")
    prompt = st.text_input("Input Prompt")
    if st.button('Go') or prompt != '':
        with st.spinner('Generating...'):
            model_responses = invokeModelsInParallel(prompt)
            with tabs[1]:
                for model_id, response in model_responses.items():
                    with st.expander(model_id):
                        st.json(response)
            # Only successfully parsed responses (dicts) carry output text;
            # error strings are skipped in the summary table.
            table = [
                [key, value['Response']['output']['message']['content'][0]['text']]
                for key, value in model_responses.items()
                if isinstance(value, dict)
            ]
            df = pd.DataFrame(table, columns=['Model', 'ModelResponse'])
            df.index += 1
            st.table(df)
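
# To try the app locally (assumes AWS credentials with Bedrock access are
# configured, e.g. via `aws configure`):
#
#   streamlit run model_choice_converse_bedrock_streamlit.py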