Skip to content

Commit

Permalink
Bugfix: main API doesn't expose previous messages cap
Browse files Browse the repository at this point in the history
The recently added capability to supply previously exchanged
messages was only exposed in the individual provider
implementations, not in the main libaiac API.
  • Loading branch information
ido50 committed Jul 1, 2024
1 parent ee8e8d8 commit 8ac377f
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 44 deletions.
37 changes: 24 additions & 13 deletions libaiac/bedrock/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@ import (
// Conversation is a struct used to converse with a Bedrock chat model. It
// maintains all messages sent/received in order to maintain context.
type Conversation struct {
// Messages is the list of all messages exchanged between the user and the
// assistant.
Messages []bedrocktypes.Message

backend *Bedrock
model string
backend *Bedrock
model string
messages []bedrocktypes.Message
}

// Chat initiates a conversation with a Bedrock chat model. A conversation
Expand All @@ -28,20 +25,20 @@ type Conversation struct {
// in the past. This practically allows "loading" previous conversations and
// continuing them.
func (backend *Bedrock) Chat(model string, msgs ...types.Message) types.Conversation {
chat := &Conversation{
conv := &Conversation{
backend: backend,
model: model,
}

if len(msgs) > 0 {
chat.Messages = make([]bedrocktypes.Message, len(msgs))
conv.messages = make([]bedrocktypes.Message, len(msgs))
for i := range msgs {
role := bedrocktypes.ConversationRoleUser
if msgs[i].Role == "assistant" {
role = bedrocktypes.ConversationRoleAssistant
}

chat.Messages[i] = bedrocktypes.Message{
conv.messages[i] = bedrocktypes.Message{
Role: role,
Content: []bedrocktypes.ContentBlock{
&bedrocktypes.ContentBlockMemberText{Value: msgs[i].Content},
Expand All @@ -50,7 +47,7 @@ func (backend *Bedrock) Chat(model string, msgs ...types.Message) types.Conversa
}
}

return chat
return conv
}

// Send sends the provided message to the backend and returns a Response object.
Expand All @@ -61,7 +58,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
res types.Response,
err error,
) {
conv.Messages = append(conv.Messages, bedrocktypes.Message{
conv.messages = append(conv.messages, bedrocktypes.Message{
Role: bedrocktypes.ConversationRoleUser,
Content: []bedrocktypes.ContentBlock{
&bedrocktypes.ContentBlockMemberText{Value: prompt},
Expand All @@ -70,7 +67,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (

input := bedrockruntime.ConverseInput{
ModelId: aws.String(conv.model),
Messages: conv.Messages,
Messages: conv.messages,
InferenceConfig: &bedrocktypes.InferenceConfiguration{
Temperature: aws.Float32(0.2),
},
Expand Down Expand Up @@ -101,11 +98,25 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
res.TokensUsed = int64(*output.Usage.TotalTokens)
res.StopReason = string(output.StopReason)

conv.Messages = append(conv.Messages, outputMsg)
conv.messages = append(conv.messages, outputMsg)

if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
res.Code = res.FullOutput
}

return res, nil
}

// Messages returns all the messages that have been exchanged between the user
// and the assistant up to this point, converted from the Bedrock-specific
// message type to the provider-agnostic types.Message.
func (conv *Conversation) Messages() []types.Message {
	msgs := make([]types.Message, len(conv.messages))
	for i, m := range conv.messages {
		// Guard against empty content and non-text content blocks: the
		// original unchecked assertion would panic with a nil pointer if
		// the first block was not a ContentBlockMemberText.
		var content string
		if len(m.Content) > 0 {
			if text, ok := m.Content[0].(*bedrocktypes.ContentBlockMemberText); ok {
				content = text.Value
			}
		}
		msgs[i] = types.Message{
			Role:    string(m.Role),
			Content: content,
		}
	}
	return msgs
}
16 changes: 10 additions & 6 deletions libaiac/libaiac.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ func (aiac *Aiac) ListModels(ctx context.Context, backendName string) (
// sent and received. If backendName is an empty string, the default backend
// defined in the configuration will be used, if any. If model is an empty
// string, the default model defined in the backend configuration will be used,
// if any.
func (aiac *Aiac) Chat(ctx context.Context, backendName, model string) (
chat types.Conversation,
err error,
) {
// if any. Users can also supply zero or more "previous messages" that may have
// been exchanged in the past. This practically allows "loading" previous
// conversations and continuing them.
func (aiac *Aiac) Chat(
ctx context.Context,
backendName string,
model string,
msgs ...types.Message,
) (chat types.Conversation, err error) {
backend, defaultModel, err := aiac.loadBackend(ctx, backendName)
if err != nil {
return chat, fmt.Errorf("failed loading backend: %w", err)
Expand All @@ -84,7 +88,7 @@ func (aiac *Aiac) Chat(ctx context.Context, backendName, model string) (
model = defaultModel
}

return backend.Chat(model), nil
return backend.Chat(model, msgs...), nil
}

func (aiac *Aiac) loadBackend(ctx context.Context, name string) (
Expand Down
27 changes: 15 additions & 12 deletions libaiac/ollama/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@ import (
// Conversation is a struct used to converse with an Ollama chat model. It
// maintains all messages sent/received in order to maintain context.
type Conversation struct {
// Messages is the list of all messages exchanged between the user and the
// assistant.
Messages []types.Message

backend *Ollama
model string
backend *Ollama
model string
messages []types.Message
}

type chatResponse struct {
Expand All @@ -31,16 +28,16 @@ type chatResponse struct {
// in the past. This practically allows "loading" previous conversations and
// continuing them.
func (backend *Ollama) Chat(model string, msgs ...types.Message) types.Conversation {
chat := &Conversation{
conv := &Conversation{
backend: backend,
model: model,
}

if len(msgs) > 0 {
chat.Messages = msgs
conv.messages = msgs
}

return chat
return conv
}

// Send sends the provided message to the API and returns a Response object.
Expand All @@ -53,15 +50,15 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
) {
var answer chatResponse

conv.Messages = append(conv.Messages, types.Message{
conv.messages = append(conv.messages, types.Message{
Role: "user",
Content: prompt,
})

err = conv.backend.NewRequest("POST", "/chat").
JSONBody(map[string]interface{}{
"model": conv.model,
"messages": conv.Messages,
"messages": conv.messages,
"options": map[string]interface{}{
"temperature": 0.2,
},
Expand All @@ -73,7 +70,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
return res, fmt.Errorf("failed sending prompt: %w", err)
}

conv.Messages = append(conv.Messages, answer.Message)
conv.messages = append(conv.messages, answer.Message)

res.FullOutput = strings.TrimSpace(answer.Message.Content)
if answer.Done {
Expand All @@ -89,3 +86,9 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (

return res, nil
}

// Messages returns all the messages that have been exchanged between the user
// and the assistant up to this point.
func (conv *Conversation) Messages() []types.Message {
	// Return a copy rather than the internal slice so callers cannot
	// mutate the conversation's state; this also matches the behavior of
	// the other provider implementations that build a fresh slice.
	msgs := make([]types.Message, len(conv.messages))
	copy(msgs, conv.messages)
	return msgs
}
27 changes: 15 additions & 12 deletions libaiac/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@ import (
// maintains all messages sent/received in order to maintain context just like
// using ChatGPT.
type Conversation struct {
// Messages is the list of all messages exchanged between the user and the
// assistant.
Messages []types.Message

backend *OpenAI
model string
backend *OpenAI
model string
messages []types.Message
}

type chatResponse struct {
Expand All @@ -38,16 +35,16 @@ type chatResponse struct {
// messages" that may have been exchanged in the past. This practically allows
// "loading" previous conversations and continuing them.
func (backend *OpenAI) Chat(model string, msgs ...types.Message) types.Conversation {
chat := &Conversation{
conv := &Conversation{
backend: backend,
model: model,
}

if len(msgs) > 0 {
chat.Messages = msgs
conv.messages = msgs
}

return chat
return conv
}

// Send sends the provided message to the API and returns a Response object.
Expand All @@ -60,7 +57,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
) {
var answer chatResponse

conv.Messages = append(conv.Messages, types.Message{
conv.messages = append(conv.messages, types.Message{
Role: "user",
Content: prompt,
})
Expand All @@ -74,7 +71,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
NewRequest("POST", fmt.Sprintf("/chat/completions%s", apiVersion)).
JSONBody(map[string]interface{}{
"model": conv.model,
"messages": conv.Messages,
"messages": conv.messages,
"temperature": 0.2,
}).
Into(&answer).
Expand All @@ -87,7 +84,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
return res, types.ErrNoResults
}

conv.Messages = append(conv.Messages, answer.Choices[0].Message)
conv.messages = append(conv.messages, answer.Choices[0].Message)

res.FullOutput = strings.TrimSpace(answer.Choices[0].Message.Content)
res.APIKeyUsed = conv.backend.apiKey
Expand All @@ -101,3 +98,9 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (

return res, nil
}

// Messages returns all the messages that have been exchanged between the user
// and the assistant up to this point.
func (conv *Conversation) Messages() []types.Message {
	// Return a copy rather than the internal slice so callers cannot
	// mutate the conversation's state; this also matches the behavior of
	// the other provider implementations that build a fresh slice.
	msgs := make([]types.Message, len(conv.messages))
	copy(msgs, conv.messages)
	return msgs
}
8 changes: 7 additions & 1 deletion libaiac/types/interfaces.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package types

import "context"
import (
"context"
)

// Backend is an interface that must be implemented in order to support an LLM
// provider.
Expand All @@ -20,4 +22,8 @@ type Backend interface {
// Conversation is an interface implemented by each provider's chat session
// type. It maintains the ordered history of messages exchanged with the model.
type Conversation interface {
	// Send sends a message to the model and returns the response.
	Send(context.Context, string) (Response, error)

	// Messages returns all the messages that have been exchanged between the
	// user and the assistant up to this point.
	Messages() []Message
}

0 comments on commit 8ac377f

Please sign in to comment.