diff --git a/ROADMAP.md b/ROADMAP.md
index 43407f8..ec834f2 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -1,4 +1,5 @@
 # Roadmap
-- [ ] add support for conversation history
-- [ ] add support for setting a model for a conversation
-- [ ] add support for multiple conversations
\ No newline at end of file
+- [x] add support for conversation history
+- [x] add support for setting a model for a conversation
+- [ ] add support for multiple conversations
+- [ ] implement stopping a stream
\ No newline at end of file
diff --git a/adapters/outbound/ollama_llm_service.go b/adapters/outbound/ollama_llm_service.go
index 32cf2f3..50fb909 100644
--- a/adapters/outbound/ollama_llm_service.go
+++ b/adapters/outbound/ollama_llm_service.go
@@ -4,15 +4,17 @@ import (
     cont "context"
     "errors"
 
+    "github.com/ChristianSch/Theta/domain/models"
     "github.com/ChristianSch/Theta/domain/ports/outbound"
     "github.com/jmorganca/ollama/api"
     "github.com/tmc/langchaingo/llms"
     "github.com/tmc/langchaingo/llms/ollama"
+    "github.com/tmc/langchaingo/schema"
 )
 
 type OllamaLlmService struct {
     client *api.Client
-    llm    *ollama.LLM
+    llm    *ollama.Chat
     model  *string
     log    outbound.Log
 }
@@ -52,7 +54,7 @@ func (s *OllamaLlmService) ListModels() ([]string, error) {
 func (s *OllamaLlmService) SetModel(model string) error {
     s.model = &model
 
-    llm, err := ollama.New(ollama.WithModel(model))
+    llm, err := ollama.NewChat(ollama.WithLLMOptions(ollama.WithModel(model)))
     if err != nil {
         return err
     }
@@ -63,15 +65,29 @@ func (s *OllamaLlmService) SetModel(model string) error {
     return nil
 }
 
-func (s *OllamaLlmService) SendMessage(prompt string, context []string, resHandler outbound.ResponseHandler) error {
+func (s *OllamaLlmService) SendMessage(prompt string, context []models.Message, resHandler outbound.ResponseHandler) error {
     if s.llm == nil {
         return errors.New(ModelNotSetError)
     }
 
-    ctx := cont.Background()
+    var messages []schema.ChatMessage
+
+    for _, msg := range context {
+        if msg.Type == models.UserMessage {
+            messages = append(messages, schema.HumanChatMessage{
+                Content: msg.Text,
+            })
+        } else {
+            messages = append(messages, schema.AIChatMessage{
+                Content: msg.Text,
+            })
+        }
+    }
+
+    messages = append(messages, schema.HumanChatMessage{
+        Content: prompt,
+    })
 
-    _, err := s.llm.Call(ctx, prompt,
-        llms.WithStreamingFunc(resHandler),
-    )
+    _, err := s.llm.Call(cont.Background(), messages, llms.WithStreamingFunc(resHandler))
 
     return err
 }
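The switch from `ollama.LLM` to `ollama.Chat` means every call now replays the whole history as typed chat messages instead of sending a bare prompt string. A minimal, self-contained sketch of that call shape, using the same langchaingo API as the hunk above (the model name and prompts are placeholders, and a running local Ollama server is assumed):

```go
package main

import (
	"context"
	"fmt"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	// Same constructor chain as SetModel above; "llama2" is a placeholder.
	llm, err := ollama.NewChat(ollama.WithLLMOptions(ollama.WithModel("llama2")))
	if err != nil {
		panic(err)
	}

	// Prior turns become typed chat messages, and the new prompt is appended
	// as a final human message; this mirrors what SendMessage builds.
	messages := []schema.ChatMessage{
		schema.HumanChatMessage{Content: "What is the capital of France?"},
		schema.AIChatMessage{Content: "Paris."},
		schema.HumanChatMessage{Content: "And of Germany?"},
	}

	_, err = llm.Call(context.Background(), messages,
		llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
			fmt.Print(string(chunk)) // chunks arrive as they are generated
			return nil
		}))
	if err != nil {
		panic(err)
	}
}
```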
diff --git a/adapters/outbound/repo/in_memory_conversation_repo.go b/adapters/outbound/repo/in_memory_conversation_repo.go
new file mode 100644
index 0000000..601a06d
--- /dev/null
+++ b/adapters/outbound/repo/in_memory_conversation_repo.go
@@ -0,0 +1,68 @@
+package repo
+
+import (
+    "errors"
+    "sync"
+
+    "github.com/ChristianSch/Theta/domain/models"
+    "github.com/google/uuid"
+)
+
+const (
+    ErrConversationNotFound = "conversation not found"
+)
+
+type InMemoryConversationRepo struct {
+    conversations map[string]models.Conversation
+    mu            sync.Mutex
+}
+
+func NewInMemoryConversationRepo() *InMemoryConversationRepo {
+    return &InMemoryConversationRepo{
+        conversations: make(map[string]models.Conversation),
+    }
+}
+
+func nextId() string {
+    return uuid.New().String()
+}
+
+func (r *InMemoryConversationRepo) CreateConversation(model string) (models.Conversation, error) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    conv := models.Conversation{
+        Id:    nextId(),
+        Model: model,
+    }
+    r.conversations[conv.Id] = conv
+
+    return conv, nil
+}
+
+func (r *InMemoryConversationRepo) GetConversation(id string) (models.Conversation, error) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    conv, ok := r.conversations[id]
+    if !ok {
+        return models.Conversation{}, errors.New(ErrConversationNotFound)
+    }
+
+    return conv, nil
+}
+
+func (r *InMemoryConversationRepo) AddMessage(id string, message models.Message) (models.Conversation, error) {
+    r.mu.Lock()
+    defer r.mu.Unlock()
+
+    conv, ok := r.conversations[id]
+    if !ok {
+        return models.Conversation{}, errors.New(ErrConversationNotFound)
+    }
+
+    conv.Messages = append(conv.Messages, message)
+    r.conversations[id] = conv
+
+    return conv, nil
+}
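The repo stores `models.Conversation` values, not pointers, so a caller only sees new messages through the conversation returned by `AddMessage` (or a fresh `GetConversation`); previously fetched copies stay stale. A quick usage sketch ("llama2" is a placeholder model name):

```go
package main

import (
	"fmt"

	"github.com/ChristianSch/Theta/adapters/outbound/repo"
	"github.com/ChristianSch/Theta/domain/models"
)

func main() {
	r := repo.NewInMemoryConversationRepo()

	conv, _ := r.CreateConversation("llama2")
	stale := conv // value copy; will not see later messages

	conv, err := r.AddMessage(conv.Id, models.Message{
		Text: "hi there",
		Type: models.UserMessage,
	})
	if err != nil {
		panic(err)
	}

	fmt.Println(len(conv.Messages), len(stale.Messages)) // 1 0
}
```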
diff --git a/domain/models/conversation.go b/domain/models/conversation.go
new file mode 100644
index 0000000..9a094f1
--- /dev/null
+++ b/domain/models/conversation.go
@@ -0,0 +1,12 @@
+package models
+
+import "time"
+
+type Conversation struct {
+    Id                string
+    ConversationStart time.Time
+    Model             string
+    Messages          []Message
+    // Active means that this conversation can be continued (needs the model to be available in the main context)
+    Active bool
+}
diff --git a/domain/ports/outbound/llm_service.go b/domain/ports/outbound/llm_service.go
index 4f67dea..56ada71 100644
--- a/domain/ports/outbound/llm_service.go
+++ b/domain/ports/outbound/llm_service.go
@@ -1,6 +1,10 @@
 package outbound
 
-import "context"
+import (
+    "context"
+
+    "github.com/ChristianSch/Theta/domain/models"
+)
 
 // ResponseHandler is a function that is called for every data chunk that is received. EOF is indicated by an empty chunk.
 type ResponseHandler func(ctx context.Context, chunk []byte) error
@@ -8,5 +12,5 @@ type ResponseHandler func(ctx context.Context, chunk []byte) error
 type LlmService interface {
     ListModels() ([]string, error)
     SetModel(model string) error
-    SendMessage(prompt string, context []string, resHandler ResponseHandler) error
+    SendMessage(prompt string, context []models.Message, resHandler ResponseHandler) error
 }
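The `ResponseHandler` contract (empty chunk = EOF) is what lets callers block until a streamed answer is complete. A self-contained sketch of a conforming handler, with the stream simulated in place of a real `SendMessage` call:

```go
package main

import (
	"context"
	"fmt"
)

// ResponseHandler mirrors the port: called per chunk, empty chunk means EOF.
type ResponseHandler func(ctx context.Context, chunk []byte) error

func main() {
	var answer []byte
	done := make(chan struct{})

	var fn ResponseHandler = func(ctx context.Context, chunk []byte) error {
		if len(chunk) == 0 {
			close(done) // empty chunk signals end of stream
			return nil
		}
		answer = append(answer, chunk...)
		return nil
	}

	// Stand-in for an LlmService streaming an answer chunk by chunk.
	go func() {
		ctx := context.Background()
		for _, c := range [][]byte{[]byte("Hel"), []byte("lo"), {}} {
			_ = fn(ctx, c)
		}
	}()

	<-done
	fmt.Println(string(answer)) // Hello
}
```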
diff --git a/domain/ports/outbound/repo/conversation_repo.go b/domain/ports/outbound/repo/conversation_repo.go
new file mode 100644
index 0000000..910d4c0
--- /dev/null
+++ b/domain/ports/outbound/repo/conversation_repo.go
@@ -0,0 +1,11 @@
+package repo
+
+import "github.com/ChristianSch/Theta/domain/models"
+
+type ConversationRepo interface {
+    // CreateConversation creates a new conversation with the given model
+    CreateConversation(model string) (models.Conversation, error)
+    // GetConversation returns the conversation with the given id
+    GetConversation(id string) (models.Conversation, error)
+    AddMessage(id string, message models.Message) (models.Conversation, error)
+}
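With the port and the in-memory adapter living in two different packages that are both called `repo`, a compile-time assertion catches any drift between them; a small sketch (the import aliases are mine):

```go
package main

import (
	adapterrepo "github.com/ChristianSch/Theta/adapters/outbound/repo"
	portrepo "github.com/ChristianSch/Theta/domain/ports/outbound/repo"
)

// Compilation fails here if the adapter stops satisfying the port.
var _ portrepo.ConversationRepo = (*adapterrepo.InMemoryConversationRepo)(nil)

func main() {}
```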
diff --git a/domain/usecases/chat/handle_incoming_message.go b/domain/usecases/chat/handle_incoming_message.go
index 3534c38..69f55d6 100644
--- a/domain/usecases/chat/handle_incoming_message.go
+++ b/domain/usecases/chat/handle_incoming_message.go
@@ -8,16 +8,18 @@ import (
 
     "github.com/ChristianSch/Theta/domain/models"
     "github.com/ChristianSch/Theta/domain/ports/outbound"
+    "github.com/ChristianSch/Theta/domain/ports/outbound/repo"
     "github.com/gofiber/fiber/v2/log"
     "github.com/google/uuid"
 )
 
 type IncomingMessageHandlerConfig struct {
     // dependencies
-    Sender         outbound.SendMessageService
-    Formatter      outbound.MessageFormatter
-    Llm            outbound.LlmService
-    PostProcessors []outbound.PostProcessor
+    Sender           outbound.SendMessageService
+    Formatter        outbound.MessageFormatter
+    Llm              outbound.LlmService
+    PostProcessors   []outbound.PostProcessor
+    ConversationRepo repo.ConversationRepo
 }
 
 type IncomingMessageHandler struct {
@@ -43,7 +45,7 @@ func NewIncomingMessageHandler(cfg IncomingMessageHandlerConfig) *IncomingMessag
     }
 }
 
-func (h *IncomingMessageHandler) Handle(message models.Message, connection interface{}) error {
+func (h *IncomingMessageHandler) Handle(message models.Message, conversation models.Conversation, connection interface{}) error {
     msgId := fmt.Sprintf("msg-%s", strings.Split(uuid.New().String(), "-")[0])
     log.Debug("starting processing of message", outbound.LogField{Key: "messageId", Value: msgId})
 
@@ -119,7 +121,7 @@ func (h *IncomingMessageHandler) Handle(message models.Message, connection inter
     // send message to llm via a goroutine so we can wait for the answer
     go func() {
-        err = h.cfg.Llm.SendMessage(message.Text, []string{}, fn)
+        err = h.cfg.Llm.SendMessage(message.Text, conversation.Messages, fn)
         if err != nil {
             done <- true
         }
     }()
@@ -128,6 +130,13 @@ func (h *IncomingMessageHandler) Handle(message models.Message, connection inter
     // wait for answer to be finished
     <-done
 
+    // add message and answer to conversation
+    h.cfg.ConversationRepo.AddMessage(conversation.Id, message)
+    h.cfg.ConversationRepo.AddMessage(conversation.Id, models.Message{
+        Text: string(chunks),
+        Type: models.GptMessage,
+    })
+
     if err != nil {
         log.Error("error while receiving answer",
             outbound.LogField{Key: "component", Value: "handle_incoming_message"},
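One caveat in the last hunk: both `AddMessage` calls drop their error (and updated conversation) return values, so a failed write dies silently. A hedged sketch of what checking them could look like, reusing the names from the hunk:

```go
// Persist the user message and the streamed answer; AddMessage returns the
// updated conversation and an error, both discarded in the diff above.
if _, err := h.cfg.ConversationRepo.AddMessage(conversation.Id, message); err != nil {
	log.Error("error while storing message", outbound.LogField{Key: "error", Value: err})
}

if _, err := h.cfg.ConversationRepo.AddMessage(conversation.Id, models.Message{
	Text: string(chunks),
	Type: models.GptMessage,
}); err != nil {
	log.Error("error while storing answer", outbound.LogField{Key: "error", Value: err})
}
```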
diff --git a/infrastructure/views/chat.gohtml b/infrastructure/views/chat.gohtml
index a006dc3..17be40d 100644
--- a/infrastructure/views/chat.gohtml
+++ b/infrastructure/views/chat.gohtml
@@ -1,25 +1,30 @@
+      {{ range .Messages }}
+      {{ . }}
+      {{ end }}
-
+
+ >{{.UserMessage}}
       Don't blindly trust the LLM. Use at your own risk. We don't secure you
@@ -31,8 +36,6 @@
-
+
+
diff --git a/infrastructure/views/new_chat.gohtml b/infrastructure/views/new_chat.gohtml
new file mode 100644
index 0000000..0e62bad
--- /dev/null
+++ b/infrastructure/views/new_chat.gohtml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+      Don't blindly trust the LLM. Use at your own risk. We don't secure you
+      against XSS or other attacks on purpose, as that would mean censored
+      output.
+
+
+
+
\ No newline at end of file
diff --git a/main.go b/main.go
index 9bd059a..8f18a59 100644
--- a/main.go
+++ b/main.go
@@ -3,10 +3,12 @@ package main
 import (
     "encoding/json"
     "errors"
+    "html/template"
     "time"
 
     "github.com/ChristianSch/Theta/adapters/inbound"
     "github.com/ChristianSch/Theta/adapters/outbound"
+    "github.com/ChristianSch/Theta/adapters/outbound/repo"
     "github.com/ChristianSch/Theta/domain/models"
     outboundPorts "github.com/ChristianSch/Theta/domain/ports/outbound"
     "github.com/ChristianSch/Theta/domain/usecases/chat"
@@ -36,6 +38,9 @@ func main() {
     log.Debug("available ollama models", outboundPorts.LogField{Key: "models", Value: ollamaModels})
 
+    // all models
+    llmModels := ollamaModels
+
     // TODO: init openai
     if len(ollamaModels) == 0 {
         panic(errors.New("no models available"))
     }
@@ -45,7 +50,7 @@ func main() {
     ollama.SetModel(ollamaModels[0])
 
     web := inbound.NewFiberWebServer(inbound.FiberWebServerConfig{
-        Port:                8080,
+        Port:                5467,
         TemplatesPath:       "./infrastructure/views",
         TemplatesExtension:  ".gohtml",
         StaticResourcesPath: "./infrastructure/static",
@@ -54,6 +59,9 @@
     // markdown 2 html post processor
     mdToHtmlPostProcessor := outbound.NewMdToHtmlLlmPostProcessor()
 
+    // conversation repo
+    convRepo := repo.NewInMemoryConversationRepo()
+
     msgSender := outbound.NewSendFiberWebsocketMessage(outbound.SendFiberWebsocketMessageConfig{Log: log})
     msgFormatter := outbound.NewFiberMessageFormatter(outbound.FiberMessageFormatterConfig{
         MessageTemplatePath: "./infrastructure/views/components/message.gohtml",
@@ -69,20 +77,105 @@ func main() {
             Name: mdToHtmlPostProcessor.GetName(),
         },
     },
+        ConversationRepo: convRepo,
     })
 
     web.AddRoute("GET", "/", func(ctx interface{}) error {
-        log.Debug("handling request", outboundPorts.LogField{Key: "path", Value: "/"})
         fiberCtx := ctx.(*fiber.Ctx)
+        return fiberCtx.Render("new_chat", fiber.Map{
+            "Title":  "",
+            "Models": llmModels,
+        }, "layouts/main")
+    })
+
+    web.AddRoute("GET", "/chat", func(ctx interface{}) error {
+        fiberCtx := ctx.(*fiber.Ctx)
+        return fiberCtx.Redirect("/")
+    })
+
+    // create new conversation
+    web.AddRoute("POST", "/chat", func(ctx interface{}) error {
+        fiberCtx := ctx.(*fiber.Ctx)
+
+        // get model from form
+        model := fiberCtx.FormValue("model")
+        if model == "" {
+            log.Error("no model specified", outboundPorts.LogField{Key: "error", Value: "no model specified"})
+            return fiberCtx.Redirect("/")
+        }
+
+        // get message from form
+        message := fiberCtx.FormValue("message")
+        if message == "" {
+            log.Error("no message specified", outboundPorts.LogField{Key: "error", Value: "no message specified"})
+            return fiberCtx.Redirect("/")
+        }
+
+        // create a new conversation for the chosen model
+        conv, err := convRepo.CreateConversation(model)
+        if err != nil {
+            log.Error("error while creating conversation", outboundPorts.LogField{Key: "error", Value: err})
+            return err
+        }
+
+        fiberCtx.Append("HX-Replace-Url", "/chat/"+conv.Id)
+
         return fiberCtx.Render("chat", fiber.Map{
-            "Title": "",
-            "Model": ollamaModels[0],
+            "Title":          "",
+            "Models":         llmModels,
+            "ConversationId": conv.Id,
+            "UserMessage":    message,
+        }, "layouts/empty")
+    })
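The `HX-Replace-Url` response header asks htmx to swap the browser URL to `/chat/<id>` after the form POST (assuming the form is submitted via htmx, as the templates suggest), so a later reload lands on the `GET /chat/:id` route below and replays the stored history.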
+    // open existing conversation
+    web.AddRoute("GET", "/chat/:id", func(ctx interface{}) error {
+        fiberCtx := ctx.(*fiber.Ctx)
+        convId := fiberCtx.Params("id")
+
+        // get conversation
+        conv, err := convRepo.GetConversation(convId)
+        if err != nil {
+            log.Error("error while getting conversation", outboundPorts.LogField{Key: "error", Value: err})
+            return fiberCtx.Redirect("/")
+        }
+
+        var renderedMessages []template.HTML
+
+        for _, msg := range conv.Messages {
+            renderedMsg, err := msgFormatter.Format(msg)
+            if err != nil {
+                log.Error("error while formatting message", outboundPorts.LogField{Key: "error", Value: err})
+                return err
+            }
+
+            // note that you normally shouldn't do this under any circumstances, as it circumvents the XSS protection
+            renderedMessages = append(renderedMessages, template.HTML(renderedMsg))
+        }
+
+        return fiberCtx.Render("chat", fiber.Map{
+            "Title":          "",
+            "Model":          conv.Model,
+            "ConversationId": conv.Id,
+            "Messages":       renderedMessages,
         }, "layouts/main")
     })
 
-    web.AddWebsocketRoute("/ws/chat", func(conn interface{}) error {
-        log.Debug("handling websocket request", outboundPorts.LogField{Key: "path", Value: "/ws/chat"})
+    web.AddWebsocketRoute("/ws/chat/:id", func(conn interface{}) error {
         fiberConn := conn.(*websocket.Conn)
+        convId := fiberConn.Params("id")
+        log.Debug("handling websocket request",
+            outboundPorts.LogField{Key: "path", Value: "/ws/chat/:id"},
+            outboundPorts.LogField{Key: "id", Value: convId})
+
+        // get conversation
+        conv, err := convRepo.GetConversation(convId)
+        if err != nil {
+            log.Error("error while getting conversation", outboundPorts.LogField{Key: "error", Value: err})
+            return err
+        }
+
+        log.Debug("conversation received message", outboundPorts.LogField{Key: "conversation", Value: conv})
 
         for {
             messageType, message, err := fiberConn.ReadMessage()
@@ -103,15 +196,18 @@ func main() {
                 outboundPorts.LogField{Key: "messageType", Value: messageType},
             )
 
-            msg := models.Message{
-                Text:      wsMsg.Message,
-                Timestamp: time.Now(),
-                Type:      models.UserMessage,
-            }
-
-            if err := msgHandler.Handle(msg, fiberConn); err != nil {
-                log.Error("error while writing message", outboundPorts.LogField{Key: "error", Value: err})
-                break
+            if len(wsMsg.Message) > 0 {
+                msg := models.Message{
+                    Text:      wsMsg.Message,
+                    Timestamp: time.Now(),
+                    Type:      models.UserMessage,
+                }
+
+                // add message to conversation!
+                if err := msgHandler.Handle(msg, conv, fiberConn); err != nil {
+                    log.Error("error while writing message", outboundPorts.LogField{Key: "error", Value: err})
+                    break
+                }
             }
         }
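To poke at the new per-conversation endpoint by hand, a small client sketch; it assumes gorilla/websocket, a conversation id previously obtained from `POST /chat`, and that the JSON key for `wsMsg.Message` is `message` (the struct's tags are not part of this diff):

```go
package main

import (
	"fmt"

	"github.com/gorilla/websocket"
)

func main() {
	// Placeholder id; create a conversation via POST /chat first.
	url := "ws://localhost:5467/ws/chat/<conversation-id>"

	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Assumed payload shape for the server's wsMsg struct.
	if err := conn.WriteJSON(map[string]string{"message": "hello"}); err != nil {
		panic(err)
	}

	// Print streamed answer fragments until the server closes the socket.
	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			break
		}
		fmt.Println(string(data))
	}
}
```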