diff --git a/ROADMAP.md b/ROADMAP.md
index 43407f8..ec834f2 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -1,4 +1,5 @@
 # Roadmap
-- [ ] add support for conversation history
-- [ ] add support for setting a model for a conversation
-- [ ] add support for multiple conversations
\ No newline at end of file
+- [x] add support for conversation history
+- [x] add support for setting a model for a conversation
+- [ ] add support for multiple conversations
+- [ ] implement stopping a stream
\ No newline at end of file
diff --git a/adapters/outbound/ollama_llm_service.go b/adapters/outbound/ollama_llm_service.go
index 32cf2f3..50fb909 100644
--- a/adapters/outbound/ollama_llm_service.go
+++ b/adapters/outbound/ollama_llm_service.go
@@ -4,15 +4,17 @@ import (
 	cont "context"
 	"errors"
 
+	"github.com/ChristianSch/Theta/domain/models"
 	"github.com/ChristianSch/Theta/domain/ports/outbound"
 	"github.com/jmorganca/ollama/api"
 	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/llms/ollama"
+	"github.com/tmc/langchaingo/schema"
 )
 
 type OllamaLlmService struct {
 	client *api.Client
-	llm    *ollama.LLM
+	llm    *ollama.Chat
 	model  *string
 	log    outbound.Log
 }
@@ -52,7 +54,7 @@ func (s *OllamaLlmService) ListModels() ([]string, error) {
 func (s *OllamaLlmService) SetModel(model string) error {
 	s.model = &model
 
-	llm, err := ollama.New(ollama.WithModel(model))
+	llm, err := ollama.NewChat(ollama.WithLLMOptions(ollama.WithModel(model)))
 	if err != nil {
 		return err
 	}
@@ -63,15 +65,29 @@ func (s *OllamaLlmService) SetModel(model string) error {
 	return nil
 }
 
-func (s *OllamaLlmService) SendMessage(prompt string, context []string, resHandler outbound.ResponseHandler) error {
+func (s *OllamaLlmService) SendMessage(prompt string, context []models.Message, resHandler outbound.ResponseHandler) error {
 	if s.llm == nil {
 		return errors.New(ModelNotSetError)
 	}
 
-	ctx := cont.Background()
+	var messages []schema.ChatMessage
+
+	for _, msg := range context {
+		if msg.Type == models.UserMessage {
+			messages = append(messages, schema.HumanChatMessage{
+				Content: msg.Text,
+			})
+		} else {
+			messages = append(messages, schema.AIChatMessage{
+				Content: msg.Text,
+			})
+		}
+	}
+
+	messages = append(messages, schema.HumanChatMessage{
+		Content: prompt,
+	})
 
-	_, err := s.llm.Call(ctx, prompt,
-		llms.WithStreamingFunc(resHandler),
-	)
+	_, err := s.llm.Call(cont.Background(), messages, llms.WithStreamingFunc(resHandler))
 
 	return err
 }
diff --git a/adapters/outbound/repo/in_memory_conversation_repo.go b/adapters/outbound/repo/in_memory_conversation_repo.go
new file mode 100644
index 0000000..601a06d
--- /dev/null
+++ b/adapters/outbound/repo/in_memory_conversation_repo.go
@@ -0,0 +1,68 @@
+package repo
+
+import (
+	"errors"
+	"sync"
+
+	"github.com/ChristianSch/Theta/domain/models"
+	"github.com/google/uuid"
+)
+
+const (
+	ErrConversationNotFound = "conversation not found"
+)
+
+type InMemoryConversationRepo struct {
+	conversations map[string]models.Conversation
+	mu            sync.Mutex
+}
+
+func NewInMemoryConversationRepo() *InMemoryConversationRepo {
+	return &InMemoryConversationRepo{
+		conversations: make(map[string]models.Conversation),
+	}
+}
+
+func nextId() string {
+	return uuid.New().String()
+}
+
+func (r *InMemoryConversationRepo) CreateConversation(model string) (models.Conversation, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	conv := models.Conversation{
+		Id:    nextId(),
+		Model: model,
+	}
+	r.conversations[conv.Id] = conv
+
+	return conv, nil
+}
+
+func (r *InMemoryConversationRepo) GetConversation(id string) (models.Conversation, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	conv, ok := r.conversations[id]
+	if !ok {
+		return models.Conversation{}, errors.New(ErrConversationNotFound)
+	}
+
+	return conv, nil
+}
+
+func (r *InMemoryConversationRepo) AddMessage(id string, message models.Message) (models.Conversation, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	conv, ok := r.conversations[id]
+	if !ok {
+		return models.Conversation{}, errors.New(ErrConversationNotFound)
+	}
+
+	conv.Messages = append(conv.Messages, message)
+	r.conversations[id] = conv
+
+	return conv, nil
+}
diff --git a/domain/models/conversation.go b/domain/models/conversation.go
new file mode 100644
index 0000000..9a094f1
--- /dev/null
+++ b/domain/models/conversation.go
@@ -0,0 +1,12 @@
+package models
+
+import "time"
+
+type Conversation struct {
+	Id                string
+	ConversationStart time.Time
+	Model             string
+	Messages          []Message
+	// Active means that this conversation can be continued (needs the model to be available in the main context)
+	Active bool
+}
diff --git a/domain/ports/outbound/llm_service.go b/domain/ports/outbound/llm_service.go
index 4f67dea..56ada71 100644
--- a/domain/ports/outbound/llm_service.go
+++ b/domain/ports/outbound/llm_service.go
@@ -1,6 +1,10 @@
 package outbound
 
-import "context"
+import (
+	"context"
+
+	"github.com/ChristianSch/Theta/domain/models"
+)
 
 // ResponseHandler is a function that is called for every data chunk that is received. EOF is indicated by an empty chunk.
 type ResponseHandler func(ctx context.Context, chunk []byte) error
@@ -8,5 +12,5 @@ type ResponseHandler func(ctx context.Context, chunk []byte) error
 type LlmService interface {
 	ListModels() ([]string, error)
 	SetModel(model string) error
-	SendMessage(prompt string, context []string, resHandler ResponseHandler) error
+	SendMessage(prompt string, context []models.Message, resHandler ResponseHandler) error
 }
diff --git a/domain/ports/outbound/repo/conversation_repo.go b/domain/ports/outbound/repo/conversation_repo.go
new file mode 100644
index 0000000..910d4c0
--- /dev/null
+++ b/domain/ports/outbound/repo/conversation_repo.go
@@ -0,0 +1,11 @@
+package repo
+
+import "github.com/ChristianSch/Theta/domain/models"
+
+type ConversationRepo interface {
+	// CreateConversation creates a new conversation with the given model
+	CreateConversation(model string) (models.Conversation, error)
+	// GetConversation returns the conversation with the given id
+	GetConversation(id string) (models.Conversation, error)
+	AddMessage(id string, message models.Message) (models.Conversation, error)
+}
diff --git a/domain/usecases/chat/handle_incoming_message.go b/domain/usecases/chat/handle_incoming_message.go
index 3534c38..69f55d6 100644
--- a/domain/usecases/chat/handle_incoming_message.go
+++ b/domain/usecases/chat/handle_incoming_message.go
@@ -8,16 +8,18 @@ import (
 	"github.com/ChristianSch/Theta/domain/models"
 	"github.com/ChristianSch/Theta/domain/ports/outbound"
+	"github.com/ChristianSch/Theta/domain/ports/outbound/repo"
 	"github.com/gofiber/fiber/v2/log"
 	"github.com/google/uuid"
 )
 
 type IncomingMessageHandlerConfig struct {
 	// dependencies
-	Sender         outbound.SendMessageService
-	Formatter      outbound.MessageFormatter
-	Llm            outbound.LlmService
-	PostProcessors []outbound.PostProcessor
+	Sender           outbound.SendMessageService
+	Formatter        outbound.MessageFormatter
+	Llm              outbound.LlmService
+	PostProcessors   []outbound.PostProcessor
+	ConversationRepo repo.ConversationRepo
 }
 
 type IncomingMessageHandler struct {
@@ -43,7 +45,7 @@ func NewIncomingMessageHandler(cfg IncomingMessageHandlerConfig) *IncomingMessag
 	}
 }
 
-func (h *IncomingMessageHandler) Handle(message models.Message, connection interface{}) error {
+func (h *IncomingMessageHandler) Handle(message models.Message, conversation models.Conversation, connection interface{}) error {
 	msgId := fmt.Sprintf("msg-%s", strings.Split(uuid.New().String(), "-")[0])
 	log.Debug("starting processing of message", outbound.LogField{Key: "messageId", Value: msgId})
@@ -119,7 +121,7 @@ func (h *IncomingMessageHandler) Handle(message models.Message, connection inter
 
 	// send message to llm via a goroutine so we can wait for the answer
 	go func() {
-		err = h.cfg.Llm.SendMessage(message.Text, []string{}, fn)
+		err = h.cfg.Llm.SendMessage(message.Text, conversation.Messages, fn)
 		if err != nil {
 			done <- true
 		}
@@ -128,6 +130,13 @@ func (h *IncomingMessageHandler) Handle(message models.Message, connection inter
 	// wait for answer to be finished
 	<-done
 
+	// add message and answer to conversation
+	h.cfg.ConversationRepo.AddMessage(conversation.Id, message)
+	h.cfg.ConversationRepo.AddMessage(conversation.Id, models.Message{
+		Text: string(chunks),
+		Type: models.GptMessage,
+	})
+
 	if err != nil {
 		log.Error("error while receiving answer",
 			outbound.LogField{Key: "component", Value: "handle_incoming_message"},
diff --git a/infrastructure/views/chat.gohtml b/infrastructure/views/chat.gohtml
index a006dc3..17be40d 100644
--- a/infrastructure/views/chat.gohtml
+++ b/infrastructure/views/chat.gohtml
@@ -1,25 +1,30 @@
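
Rough usage sketch of the pieces this change introduces: a conversation is created per model through the new ConversationRepo, each exchange appends the user prompt and the model's answer, and SendMessage replays those messages as chat history. This is illustrative only; the model name and the printed summary are placeholders, not part of the diff.

package main

import (
	"fmt"
	"log"

	"github.com/ChristianSch/Theta/adapters/outbound/repo"
	"github.com/ChristianSch/Theta/domain/models"
)

func main() {
	// the in-memory repo introduced in this change
	conversations := repo.NewInMemoryConversationRepo()

	// a conversation is bound to a model at creation time;
	// "llama2" is a placeholder model name
	conv, err := conversations.CreateConversation("llama2")
	if err != nil {
		log.Fatal(err)
	}

	// the handler records the user message and the answer after each
	// exchange, so the next Handle call sees the history via conv.Messages
	conv, err = conversations.AddMessage(conv.Id, models.Message{
		Text: "Hello!",
		Type: models.UserMessage,
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("conversation %s now has %d message(s)\n", conv.Id, len(conv.Messages))
}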