diff --git a/neo/api.go b/neo/api.go index f047053f9e..d4d34913f0 100644 --- a/neo/api.go +++ b/neo/api.go @@ -471,18 +471,27 @@ func (neo *DSL) handleChatLatest(c *gin.Context) { // Create a new chat if len(chats.Groups) == 0 || len(chats.Groups[0].Chats) == 0 { - ast := neo.Assistant - assistantID := c.Query("assistant_id") - if assistantID != "" { - ast, err = assistant.Get(assistantID) - if err != nil { - c.JSON(500, gin.H{"message": err.Error(), "code": 500}) - c.Done() - return - } + assistantID := neo.Use + queryAssistantID := c.Query("assistant_id") + if queryAssistantID != "" { + assistantID = queryAssistantID + } + + // Get the assistant info + ast, err := assistant.Get(assistantID) + if err != nil { + c.JSON(500, gin.H{"message": err.Error(), "code": 500}) + c.Done() + return } - c.JSON(200, map[string]interface{}{"data": map[string]interface{}{"placeholder": ast.GetPlaceholder()}}) + c.JSON(200, map[string]interface{}{"data": map[string]interface{}{ + "placeholder": ast.GetPlaceholder(), + "assistant_id": ast.ID, + "assistant_name": ast.Name, + "assistant_avatar": ast.Avatar, + "assistant_deleteable": neo.Use != ast.ID, + }}) c.Done() return } @@ -502,6 +511,22 @@ func (neo *DSL) handleChatLatest(c *gin.Context) { return } + // assistant_id is nil return the default assistant + if chat.Chat["assistant_id"] == nil { + chat.Chat["assistant_id"] = neo.Use + + // Get the assistant info + ast, err := assistant.Get(neo.Use) + if err != nil { + c.JSON(500, gin.H{"message": err.Error(), "code": 500}) + c.Done() + return + } + chat.Chat["assistant_name"] = ast.Name + chat.Chat["assistant_avatar"] = ast.Avatar + } + + chat.Chat["assistant_deleteable"] = neo.Use != chat.Chat["assistant_id"] c.JSON(200, map[string]interface{}{"data": chat}) c.Done() } @@ -529,6 +554,22 @@ func (neo *DSL) handleChatDetail(c *gin.Context) { return } + // assistant_id is nil return the default assistant + if chat.Chat["assistant_id"] == nil { + chat.Chat["assistant_id"] = neo.Use + + // Get the assistant info + ast, err := assistant.Get(neo.Use) + if err != nil { + c.JSON(500, gin.H{"message": err.Error(), "code": 500}) + c.Done() + return + } + chat.Chat["assistant_name"] = ast.Name + chat.Chat["assistant_avatar"] = ast.Avatar + } + + chat.Chat["assistant_deleteable"] = neo.Use != chat.Chat["assistant_id"] c.JSON(200, map[string]interface{}{"data": chat}) c.Done() } diff --git a/neo/assistant/api.go b/neo/assistant/api.go index ca66df6ea0..a2883121a5 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -56,12 +56,12 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, } // Execute implements the execute functionality -func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents) error { +func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatMessage.Message, userOptions map[string]interface{}, contents *chatMessage.Contents) error { if contents == nil { contents = chatMessage.NewContents() } - options = ast.withOptions(options) + options := ast.withOptions(userOptions) // Add RAG and Version support ctx.RAG = rag != nil @@ -78,6 +78,22 @@ func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatM return err } + // Update options if provided + if res != nil && res.Options != nil { + options = res.Options + } + + // messages + if res != nil && res.Input != nil { + input = res.Input + } + + // Handle next action + // It's not used, return the new assistant_id and chat_id + // if res != nil && res.Next != nil { + // return res.Next.Execute(c, ctx, contents) + // } + // Switch to the new assistant if necessary if res != nil && res.AssistantID != ctx.AssistantID { newAst, err := Get(res.AssistantID) @@ -89,22 +105,25 @@ func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatM Write(c.Writer) return err } - *ast = *newAst - } - // Handle next action - if res != nil && res.Next != nil { - return res.Next.Execute(c, ctx, contents) - } + // Reset Message Contents + last := input[len(input)-1] + input, err = newAst.withHistory(ctx, last) + if err != nil { + return err + } - // Update options if provided - if res != nil && res.Options != nil { - options = res.Options - } + // Reset options + options = newAst.withOptions(userOptions) - // messages - if res != nil && res.Input != nil { - input = res.Input + // Update options if provided + if res.Options != nil { + options = res.Options + } + + // Update assistant id + ctx.AssistantID = res.AssistantID + return newAst.handleChatStream(c, ctx, input, options, contents) } // Only proceed with chat stream if no specific next action was handled diff --git a/neo/store/xun.go b/neo/store/xun.go index ae8ac824a5..6f3b4c97c5 100644 --- a/neo/store/xun.go +++ b/neo/store/xun.go @@ -185,6 +185,7 @@ func (conv *Xun) initChatTable() error { table.ID("id") table.String("chat_id", 200).Unique().Index() table.String("title", 200).Null() + table.String("assistant_id", 200).Null().Index() table.String("sid", 255).Index() table.TimestampTz("created_at").SetDefaultRaw("NOW()").Index() table.TimestampTz("updated_at").Null().Index() @@ -202,7 +203,7 @@ func (conv *Xun) initChatTable() error { return err } - fields := []string{"id", "chat_id", "title", "sid", "created_at", "updated_at"} + fields := []string{"id", "chat_id", "title", "assistant_id", "sid", "created_at", "updated_at"} for _, field := range fields { if !tab.HasColumn(field) { return fmt.Errorf("%s is required", field) @@ -336,7 +337,7 @@ func (conv *Xun) GetChats(sid string, filter ChatFilter) (*ChatGroupResponse, er // Build base query qb := conv.newQueryChat(). - Select("chat_id", "title", "created_at", "updated_at"). + Select("chat_id", "title", "assistant_id", "created_at", "updated_at"). Where("sid", userID). Where("chat_id", "!=", "") @@ -384,6 +385,36 @@ func (conv *Xun) GetChats(sid string, filter ChatFilter) (*ChatGroupResponse, er "Even Earlier": {}, } + // Get assistant details for all chats + assistantIDs := []interface{}{} + assistantMap := make(map[string]map[string]interface{}) + + for _, row := range rows { + if assistantID := row.Get("assistant_id"); assistantID != nil && assistantID != "" { + assistantIDs = append(assistantIDs, assistantID) + } + } + + if len(assistantIDs) > 0 { + assistants, err := conv.query.New(). + Table(conv.getAssistantTable()). + Select("assistant_id", "name", "avatar"). + WhereIn("assistant_id", assistantIDs). + Get() + if err != nil { + return nil, err + } + + for _, assistant := range assistants { + if id := assistant.Get("assistant_id"); id != nil { + assistantMap[fmt.Sprintf("%v", id)] = map[string]interface{}{ + "name": assistant.Get("name"), + "avatar": assistant.Get("avatar"), + } + } + } + } + for _, row := range rows { chatID := row.Get("chat_id") if chatID == nil || chatID == "" { @@ -391,8 +422,17 @@ func (conv *Xun) GetChats(sid string, filter ChatFilter) (*ChatGroupResponse, er } chat := map[string]interface{}{ - "chat_id": chatID, - "title": row.Get("title"), + "chat_id": chatID, + "title": row.Get("title"), + "assistant_id": row.Get("assistant_id"), + } + + // Add assistant details if available + if assistantID := row.Get("assistant_id"); assistantID != nil && assistantID != "" { + if assistant, ok := assistantMap[fmt.Sprintf("%v", assistantID)]; ok { + chat["assistant_name"] = assistant["name"] + chat["assistant_avatar"] = assistant["avatar"] + } } var dbDatetime = row.Get("updated_at") @@ -514,6 +554,14 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid return err } + // Get assistant_id from context + var assistantID interface{} = nil + if context != nil { + if id, ok := context["assistant_id"].(string); ok && id != "" { + assistantID = id + } + } + // First ensure chat record exists exists, err := conv.newQueryChat(). Where("chat_id", cid). @@ -528,14 +576,28 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid // Create new chat record err = conv.newQueryChat(). Insert(map[string]interface{}{ - "chat_id": cid, - "sid": userID, - "created_at": time.Now(), + "chat_id": cid, + "sid": userID, + "assistant_id": assistantID, + "created_at": time.Now(), }) if err != nil { return err } + } else { + // Update assistant_id if it exists + if assistantID != nil { + _, err = conv.newQueryChat(). + Where("chat_id", cid). + Where("sid", userID). + Update(map[string]interface{}{ + "assistant_id": assistantID, + }) + if err != nil { + return err + } + } } // Save message history @@ -637,7 +699,7 @@ func (conv *Xun) GetChat(sid string, cid string) (*ChatInfo, error) { // Get chat info qb := conv.newQueryChat(). - Select("chat_id", "title"). + Select("chat_id", "title", "assistant_id"). Where("sid", userID). Where("chat_id", cid) @@ -652,8 +714,26 @@ func (conv *Xun) GetChat(sid string, cid string) (*ChatInfo, error) { } chat := map[string]interface{}{ - "chat_id": row.Get("chat_id"), - "title": row.Get("title"), + "chat_id": row.Get("chat_id"), + "title": row.Get("title"), + "assistant_id": row.Get("assistant_id"), + } + + // Get assistant details if assistant_id exists + if assistantID := row.Get("assistant_id"); assistantID != nil && assistantID != "" { + assistant, err := conv.query.New(). + Table(conv.getAssistantTable()). + Select("name", "avatar"). + Where("assistant_id", assistantID). + First() + if err != nil { + return nil, err + } + + if assistant != nil { + chat["assistant_name"] = assistant.Get("name") + chat["assistant_avatar"] = assistant.Get("avatar") + } } // Get chat history @@ -715,24 +795,6 @@ func (conv *Xun) DeleteAllChats(sid string) error { return err } -// processJSONField processes a field that should be stored as JSON string -func (conv *Xun) processJSONField(field interface{}) (interface{}, error) { - if field == nil { - return nil, nil - } - - switch v := field.(type) { - case string: - return v, nil - default: - jsonStr, err := jsoniter.MarshalToString(v) - if err != nil { - return nil, fmt.Errorf("failed to marshal %v to JSON: %v", field, err) - } - return jsonStr, nil - } -} - // parseJSONFields parses JSON string fields into their corresponding Go types func (conv *Xun) parseJSONFields(data map[string]interface{}, fields []string) { for _, field := range fields { diff --git a/neo/store/xun_test.go b/neo/store/xun_test.go index 0eef1794b0..e1a4d7d774 100644 --- a/neo/store/xun_test.go +++ b/neo/store/xun_test.go @@ -245,11 +245,15 @@ func TestXunSaveAndGetHistoryWithCID(t *testing.T) { // save the history with specific cid sid := "123456" cid := "789012" + assistantID := "test-assistant-1" messages := []map[string]interface{}{ {"role": "user", "name": "user1", "content": "hello"}, {"role": "assistant", "name": "assistant1", "content": "Hi! How can I help you?"}, } - err = store.SaveHistory(sid, messages, cid, nil) + context := map[string]interface{}{ + "assistant_id": assistantID, + } + err = store.SaveHistory(sid, messages, cid, context) assert.Nil(t, err) // get the history for specific cid @@ -259,14 +263,29 @@ func TestXunSaveAndGetHistoryWithCID(t *testing.T) { } assert.Equal(t, 2, len(data)) - // save another message with different cid + // Verify assistant_id is saved in chat + chat, err := store.GetChat(sid, cid) + assert.Nil(t, err) + assert.Equal(t, assistantID, chat.Chat["assistant_id"]) + + // save another message with different cid and assistant anotherCID := "345678" + anotherAssistantID := "test-assistant-2" moreMessages := []map[string]interface{}{ {"role": "user", "name": "user1", "content": "another message"}, + {"role": "assistant", "name": "assistant2", "content": "Hello!"}, + } + anotherContext := map[string]interface{}{ + "assistant_id": anotherAssistantID, } - err = store.SaveHistory(sid, moreMessages, anotherCID, nil) + err = store.SaveHistory(sid, moreMessages, anotherCID, anotherContext) assert.Nil(t, err) + // Verify second chat's assistant_id + chat2, err := store.GetChat(sid, anotherCID) + assert.Nil(t, err) + assert.Equal(t, anotherAssistantID, chat2.Chat["assistant_id"]) + // get history for the first cid - should still be 2 messages data, err = store.GetHistory(sid, cid) if err != nil { @@ -274,12 +293,12 @@ func TestXunSaveAndGetHistoryWithCID(t *testing.T) { } assert.Equal(t, 2, len(data)) - // get history for the second cid - should be 1 message + // get history for the second cid - should be 2 messages data, err = store.GetHistory(sid, anotherCID) if err != nil { t.Fatal(err) } - assert.Equal(t, 1, len(data)) + assert.Equal(t, 2, len(data)) // get all history for the sid without specifying cid allData, err := store.GetHistory(sid, cid) @@ -294,8 +313,9 @@ func TestXunGetChats(t *testing.T) { defer test.Clean() defer capsule.Schema().DropTableIfExists("__unit_test_conversation_history") defer capsule.Schema().DropTableIfExists("__unit_test_conversation_chat") + defer capsule.Schema().DropTableIfExists("__unit_test_conversation_assistant") - // Drop both tables before test + // Drop tables before test err := capsule.Schema().DropTableIfExists("__unit_test_conversation_history") if err != nil { t.Fatal(err) @@ -304,6 +324,10 @@ func TestXunGetChats(t *testing.T) { if err != nil { t.Fatal(err) } + err = capsule.Schema().DropTableIfExists("__unit_test_conversation_assistant") + if err != nil { + t.Fatal(err) + } store, err := NewXun(Setting{ Connector: "default", @@ -313,45 +337,111 @@ func TestXunGetChats(t *testing.T) { t.Fatal(err) } + // Create test assistants first + assistant1 := map[string]interface{}{ + "assistant_id": "test-assistant-1", + "name": "Test Assistant 1", + "avatar": "avatar1.png", + "type": "assistant", + "connector": "test", + } + assistant2 := map[string]interface{}{ + "assistant_id": "test-assistant-2", + "name": "Test Assistant 2", + "avatar": "avatar2.png", + "type": "assistant", + "connector": "test", + } + _, err = store.SaveAssistant(assistant1) + assert.Nil(t, err) + _, err = store.SaveAssistant(assistant2) + assert.Nil(t, err) + // Save some test chats sid := "test_user" messages := []map[string]interface{}{ {"role": "user", "content": "test message"}, } - // Create chats with different dates + // Create chats with different dates and assistants for i := 0; i < 5; i++ { chatID := fmt.Sprintf("chat_%d", i) title := fmt.Sprintf("Test Chat %d", i) + var context map[string]interface{} + + // Alternate between having assistant and no assistant + if i%2 == 0 { + context = map[string]interface{}{ + "assistant_id": "test-assistant-1", + } + } else if i%3 == 0 { + context = map[string]interface{}{ + "assistant_id": "test-assistant-2", + } + } // Save history first to create the chat - err = store.SaveHistory(sid, messages, chatID, nil) + err = store.SaveHistory(sid, messages, chatID, context) assert.Nil(t, err) // Update the chat title err = store.UpdateChatTitle(sid, chatID, title) assert.Nil(t, err) + + // Verify chat was created with correct assistant info + chat, err := store.GetChat(sid, chatID) + assert.Nil(t, err) + assert.NotNil(t, chat) + assert.Equal(t, chatID, chat.Chat["chat_id"]) + assert.Equal(t, title, chat.Chat["title"]) + + if i%2 == 0 { + assert.Equal(t, "test-assistant-1", chat.Chat["assistant_id"]) + assert.Equal(t, "Test Assistant 1", chat.Chat["assistant_name"]) + assert.Equal(t, "avatar1.png", chat.Chat["assistant_avatar"]) + } else if i%3 == 0 { + assert.Equal(t, "test-assistant-2", chat.Chat["assistant_id"]) + assert.Equal(t, "Test Assistant 2", chat.Chat["assistant_name"]) + assert.Equal(t, "avatar2.png", chat.Chat["assistant_avatar"]) + } else { + assert.Nil(t, chat.Chat["assistant_id"]) + assert.Nil(t, chat.Chat["assistant_name"]) + assert.Nil(t, chat.Chat["assistant_avatar"]) + } } - // Test getting chats with default filter + // Test GetChats filter := ChatFilter{ PageSize: 10, Order: "desc", } groups, err := store.GetChats(sid, filter) - if err != nil { - t.Fatal(err) - } - + assert.Nil(t, err) + assert.NotNil(t, groups) assert.Greater(t, len(groups.Groups), 0) + // Verify assistant information in chat list + for _, group := range groups.Groups { + for _, chat := range group.Chats { + if assistantID, ok := chat["assistant_id"].(string); ok && assistantID != "" { + if assistantID == "test-assistant-1" { + assert.Equal(t, "Test Assistant 1", chat["assistant_name"]) + assert.Equal(t, "avatar1.png", chat["assistant_avatar"]) + } else if assistantID == "test-assistant-2" { + assert.Equal(t, "Test Assistant 2", chat["assistant_name"]) + assert.Equal(t, "avatar2.png", chat["assistant_avatar"]) + } + } else { + assert.Nil(t, chat["assistant_name"]) + assert.Nil(t, chat["assistant_avatar"]) + } + } + } + // Test with keywords filter.Keywords = "test" groups, err = store.GetChats(sid, filter) - if err != nil { - t.Fatal(err) - } - + assert.Nil(t, err) assert.Greater(t, len(groups.Groups), 0) }