Commit b209dc1

Allow conversations to have "previous messages"
This commit modifies the `Chat` method (in a backwards-compatible
way) to allow users to start a conversation from previously
exchanged messages, i.e. messages already traded between the user
and the assistant in an earlier session. In practice, this allows
"loading" previous conversations in order to continue them.

To support this, the `Chat` method now accepts zero or more
`types.Message` values as a variadic parameter. The `Messages`
field of each backend's `Conversation` type is also exported now,
so that users may save it.
ido50 committed Jul 1, 2024
1 parent ae51057 commit b209dc1
Showing 4 changed files with 81 additions and 29 deletions.
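For illustration (not part of the commit), here is a minimal sketch of how the new variadic form could be used to resume a saved conversation. The aiac import path, the JSON file layout, and the resume helper are assumptions made for this example:

// Sketch only: the module path below is an assumption, not taken from
// this commit.
package example

import (
	"context"
	"encoding/json"
	"fmt"
	"os"

	"github.com/gofireflyio/aiac/v4/libaiac/types"
)

// resume seeds a new conversation with messages saved from an earlier
// session, then continues it with a fresh prompt.
func resume(ctx context.Context, backend types.Backend, model, path string) error {
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}

	var prev []types.Message
	if err := json.Unmarshal(data, &prev); err != nil {
		return err
	}

	// The variadic parameter added by this commit: previously exchanged
	// messages are passed straight into Chat.
	conv := backend.Chat(model, prev...)

	res, err := conv.Send(ctx, "Please continue where we left off.")
	if err != nil {
		return err
	}

	fmt.Println(res.FullOutput)
	return nil
}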
43 changes: 34 additions & 9 deletions libaiac/bedrock/chat.go
@@ -13,19 +13,44 @@ import (
 // Conversation is a struct used to converse with a Bedrock chat model. It
 // maintains all messages sent/received in order to maintain context.
 type Conversation struct {
-	backend  *Bedrock
-	model    string
-	messages []bedrocktypes.Message
+	// Messages is the list of all messages exchanged between the user and the
+	// assistant.
+	Messages []bedrocktypes.Message
+
+	backend *Bedrock
+	model   string
 }
 
 // Chat initiates a conversation with a Bedrock chat model. A conversation
 // maintains context, allowing to send further instructions to modify the output
-// from previous requests.
-func (backend *Bedrock) Chat(model string) types.Conversation {
-	return &Conversation{
+// from previous requests. The name of the model to use must be provided. Users
+// can also supply zero or more "previous messages" that may have been exchanged
+// in the past. This practically allows "loading" previous conversations and
+// continuing them.
+func (backend *Bedrock) Chat(model string, msgs ...types.Message) types.Conversation {
+	chat := &Conversation{
 		backend: backend,
 		model:   model,
 	}
+
+	if len(msgs) > 0 {
+		chat.Messages = make([]bedrocktypes.Message, len(msgs))
+		for i := range msgs {
+			role := bedrocktypes.ConversationRoleUser
+			if msgs[i].Role == "assistant" {
+				role = bedrocktypes.ConversationRoleAssistant
+			}
+
+			chat.Messages[i] = bedrocktypes.Message{
+				Role: role,
+				Content: []bedrocktypes.ContentBlock{
+					&bedrocktypes.ContentBlockMemberText{Value: msgs[i].Content},
+				},
+			}
+		}
+	}
+
+	return chat
 }
 
 // Send sends the provided message to the backend and returns a Response object.
@@ -36,7 +61,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 	res types.Response,
 	err error,
 ) {
-	conv.messages = append(conv.messages, bedrocktypes.Message{
+	conv.Messages = append(conv.Messages, bedrocktypes.Message{
 		Role: bedrocktypes.ConversationRoleUser,
 		Content: []bedrocktypes.ContentBlock{
 			&bedrocktypes.ContentBlockMemberText{Value: prompt},
@@ -45,7 +70,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 
 	input := bedrockruntime.ConverseInput{
 		ModelId:  aws.String(conv.model),
-		Messages: conv.messages,
+		Messages: conv.Messages,
 		InferenceConfig: &bedrocktypes.InferenceConfiguration{
 			Temperature: aws.Float32(0.2),
 		},
@@ -76,7 +101,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 	res.TokensUsed = int64(*output.Usage.TotalTokens)
 	res.StopReason = string(output.StopReason)
 
-	conv.messages = append(conv.messages, outputMsg)
+	conv.Messages = append(conv.Messages, outputMsg)
 
 	if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
 		res.Code = res.FullOutput
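Note that the Bedrock backend keeps its history as `bedrocktypes.Message` values, so its exported `Messages` field is not in the generic `types.Message` form that `Chat` accepts. A minimal sketch of the reverse mapping a caller might apply before saving; the helper and the aiac import path are hypothetical, and only text content blocks are handled:

package example

import (
	bedrocktypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"

	"github.com/gofireflyio/aiac/v4/libaiac/types"
)

// toGenericMessages converts Bedrock-specific messages back into the generic
// form accepted by Chat, keeping only text content blocks.
func toGenericMessages(msgs []bedrocktypes.Message) []types.Message {
	out := make([]types.Message, 0, len(msgs))
	for _, m := range msgs {
		role := "user"
		if m.Role == bedrocktypes.ConversationRoleAssistant {
			role = "assistant"
		}

		// ContentBlock is a union type; only the text member is read here.
		var text string
		for _, block := range m.Content {
			if t, ok := block.(*bedrocktypes.ContentBlockMemberText); ok {
				text += t.Value
			}
		}

		out = append(out, types.Message{Role: role, Content: text})
	}

	return out
}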
30 changes: 21 additions & 9 deletions libaiac/ollama/chat.go
@@ -11,9 +11,12 @@ import (
 // Conversation is a struct used to converse with an Ollama chat model. It
 // maintains all messages sent/received in order to maintain context.
 type Conversation struct {
-	backend  *Ollama
-	model    string
-	messages []types.Message
+	// Messages is the list of all messages exchanged between the user and the
+	// assistant.
+	Messages []types.Message
+
+	backend *Ollama
+	model   string
 }
 
 type chatResponse struct {
@@ -23,12 +26,21 @@ type chatResponse struct {
 
 // Chat initiates a conversation with an Ollama chat model. A conversation
 // maintains context, allowing to send further instructions to modify the output
-// from previous requests.
-func (backend *Ollama) Chat(model string) types.Conversation {
-	return &Conversation{
+// from previous requests. The name of the model to use must be provided. Users
+// can also supply zero or more "previous messages" that may have been exchanged
+// in the past. This practically allows "loading" previous conversations and
+// continuing them.
+func (backend *Ollama) Chat(model string, msgs ...types.Message) types.Conversation {
+	chat := &Conversation{
 		backend: backend,
 		model:   model,
 	}
+
+	if len(msgs) > 0 {
+		chat.Messages = msgs
+	}
+
+	return chat
 }
 
 // Send sends the provided message to the API and returns a Response object.
@@ -41,15 +53,15 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 ) {
 	var answer chatResponse
 
-	conv.messages = append(conv.messages, types.Message{
+	conv.Messages = append(conv.Messages, types.Message{
 		Role:    "user",
 		Content: prompt,
 	})
 
 	err = conv.backend.NewRequest("POST", "/chat").
 		JSONBody(map[string]interface{}{
 			"model":    conv.model,
-			"messages": conv.messages,
+			"messages": conv.Messages,
 			"options": map[string]interface{}{
 				"temperature": 0.2,
 			},
@@ -61,7 +73,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 		return res, fmt.Errorf("failed sending prompt: %w", err)
 	}
 
-	conv.messages = append(conv.messages, answer.Message)
+	conv.Messages = append(conv.Messages, answer.Message)
 
 	res.FullOutput = strings.TrimSpace(answer.Message.Content)
 	if answer.Done {
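Since the Ollama backend (like the OpenAI one below) already stores its history as generic `types.Message` values, the exported `Messages` field can be serialized as-is. A minimal sketch, assuming the caller type-asserts back to the concrete conversation type; the helper, file name, and import paths are illustrative:

package example

import (
	"encoding/json"
	"errors"
	"os"

	"github.com/gofireflyio/aiac/v4/libaiac/ollama"
	"github.com/gofireflyio/aiac/v4/libaiac/types"
)

// saveConversation writes a conversation's message history to disk so that a
// later session can pass it back into Chat.
func saveConversation(conv types.Conversation, path string) error {
	c, ok := conv.(*ollama.Conversation)
	if !ok {
		return errors.New("not an Ollama conversation")
	}

	data, err := json.MarshalIndent(c.Messages, "", "  ")
	if err != nil {
		return err
	}

	return os.WriteFile(path, data, 0o644)
}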
30 changes: 21 additions & 9 deletions libaiac/openai/chat.go
@@ -12,9 +12,12 @@ import (
 // maintains all messages sent/received in order to maintain context just like
 // using ChatGPT.
 type Conversation struct {
-	backend  *OpenAI
-	model    string
-	messages []types.Message
+	// Messages is the list of all messages exchanged between the user and the
+	// assistant.
+	Messages []types.Message
+
+	backend *OpenAI
+	model   string
 }
 
 type chatResponse struct {
@@ -30,12 +33,21 @@ type chatResponse struct {
 
 // Chat initiates a conversation with an OpenAI chat model. A conversation
 // maintains context, allowing to send further instructions to modify the output
-// from previous requests, just like using the ChatGPT website.
-func (backend *OpenAI) Chat(model string) types.Conversation {
-	return &Conversation{
+// from previous requests, just like using the ChatGPT website. The name of the
+// model to use must be provided. Users can also supply zero or more "previous
+// messages" that may have been exchanged in the past. This practically allows
+// "loading" previous conversations and continuing them.
+func (backend *OpenAI) Chat(model string, msgs ...types.Message) types.Conversation {
+	chat := &Conversation{
 		backend: backend,
 		model:   model,
 	}
+
+	if len(msgs) > 0 {
+		chat.Messages = msgs
+	}
+
+	return chat
 }
 
 // Send sends the provided message to the API and returns a Response object.
@@ -48,7 +60,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 ) {
 	var answer chatResponse
 
-	conv.messages = append(conv.messages, types.Message{
+	conv.Messages = append(conv.Messages, types.Message{
 		Role:    "user",
 		Content: prompt,
 	})
@@ -62,7 +74,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 		NewRequest("POST", fmt.Sprintf("/chat/completions%s", apiVersion)).
 		JSONBody(map[string]interface{}{
 			"model":       conv.model,
-			"messages":    conv.messages,
+			"messages":    conv.Messages,
 			"temperature": 0.2,
 		}).
 		Into(&answer).
@@ -75,7 +87,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 		return res, types.ErrNoResults
 	}
 
-	conv.messages = append(conv.messages, answer.Choices[0].Message)
+	conv.Messages = append(conv.Messages, answer.Choices[0].Message)
 
 	res.FullOutput = strings.TrimSpace(answer.Choices[0].Message.Content)
 	res.APIKeyUsed = conv.backend.apiKey
7 changes: 5 additions & 2 deletions libaiac/types/interfaces.go
@@ -8,8 +8,11 @@ type Backend interface {
 	// ListModels returns a list of all models supported by the backend.
 	ListModels(context.Context) ([]string, error)
 
-	// Chat initiates a conversation with an LLM backend.
-	Chat(string) Conversation
+	// Chat initiates a conversation with an LLM backend. The name of the model
+	// to use must be provided. Users can also supply zero or more "previous
+	// messages" that may have been exchanged in the past. This practically
+	// allows "loading" previous conversations and continuing them.
+	Chat(string, ...Message) Conversation
 }
 
 // Conversation is an interface that must be implemented in order to support
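Because the new parameter is variadic, existing call sites compile unchanged against the updated interface. A short sketch of both call shapes (the model name, `backend`, and `saved` are illustrative; the import path is an assumption):

package example

import "github.com/gofireflyio/aiac/v4/libaiac/types"

// bothForms shows the old and new call shapes side by side.
func bothForms(backend types.Backend, saved []types.Message) {
	conv := backend.Chat("some-model")          // existing form, unchanged
	conv = backend.Chat("some-model", saved...) // new: resume a conversation
	_ = conv
}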
