Skip to content

Commit

Permalink
Merge pull request #859 from trheyi/main
Browse files Browse the repository at this point in the history
Add assistant details to chat API responses
  • Loading branch information
trheyi authored Feb 9, 2025
2 parents 1556f3c + 4ba2739 commit e4375ea
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 70 deletions.
61 changes: 51 additions & 10 deletions neo/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,18 +471,27 @@ func (neo *DSL) handleChatLatest(c *gin.Context) {
// Create a new chat
if len(chats.Groups) == 0 || len(chats.Groups[0].Chats) == 0 {

ast := neo.Assistant
assistantID := c.Query("assistant_id")
if assistantID != "" {
ast, err = assistant.Get(assistantID)
if err != nil {
c.JSON(500, gin.H{"message": err.Error(), "code": 500})
c.Done()
return
}
assistantID := neo.Use
queryAssistantID := c.Query("assistant_id")
if queryAssistantID != "" {
assistantID = queryAssistantID
}

// Get the assistant info
ast, err := assistant.Get(assistantID)
if err != nil {
c.JSON(500, gin.H{"message": err.Error(), "code": 500})
c.Done()
return
}

c.JSON(200, map[string]interface{}{"data": map[string]interface{}{"placeholder": ast.GetPlaceholder()}})
c.JSON(200, map[string]interface{}{"data": map[string]interface{}{
"placeholder": ast.GetPlaceholder(),
"assistant_id": ast.ID,
"assistant_name": ast.Name,
"assistant_avatar": ast.Avatar,
"assistant_deleteable": neo.Use != ast.ID,
}})
c.Done()
return
}
Expand All @@ -502,6 +511,22 @@ func (neo *DSL) handleChatLatest(c *gin.Context) {
return
}

// assistant_id is nil return the default assistant
if chat.Chat["assistant_id"] == nil {
chat.Chat["assistant_id"] = neo.Use

// Get the assistant info
ast, err := assistant.Get(neo.Use)
if err != nil {
c.JSON(500, gin.H{"message": err.Error(), "code": 500})
c.Done()
return
}
chat.Chat["assistant_name"] = ast.Name
chat.Chat["assistant_avatar"] = ast.Avatar
}

chat.Chat["assistant_deleteable"] = neo.Use != chat.Chat["assistant_id"]
c.JSON(200, map[string]interface{}{"data": chat})
c.Done()
}
Expand Down Expand Up @@ -529,6 +554,22 @@ func (neo *DSL) handleChatDetail(c *gin.Context) {
return
}

// assistant_id is nil return the default assistant
if chat.Chat["assistant_id"] == nil {
chat.Chat["assistant_id"] = neo.Use

// Get the assistant info
ast, err := assistant.Get(neo.Use)
if err != nil {
c.JSON(500, gin.H{"message": err.Error(), "code": 500})
c.Done()
return
}
chat.Chat["assistant_name"] = ast.Name
chat.Chat["assistant_avatar"] = ast.Avatar
}

chat.Chat["assistant_deleteable"] = neo.Use != chat.Chat["assistant_id"]
c.JSON(200, map[string]interface{}{"data": chat})
c.Done()
}
Expand Down
49 changes: 34 additions & 15 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
}

// Execute implements the execute functionality
func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatMessage.Message, options map[string]interface{}, contents *chatMessage.Contents) error {
func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatMessage.Message, userOptions map[string]interface{}, contents *chatMessage.Contents) error {

if contents == nil {
contents = chatMessage.NewContents()
}
options = ast.withOptions(options)
options := ast.withOptions(userOptions)

// Add RAG and Version support
ctx.RAG = rag != nil
Expand All @@ -78,6 +78,22 @@ func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatM
return err
}

// Update options if provided
if res != nil && res.Options != nil {
options = res.Options
}

// messages
if res != nil && res.Input != nil {
input = res.Input
}

// Handle next action
// It's not used, return the new assistant_id and chat_id
// if res != nil && res.Next != nil {
// return res.Next.Execute(c, ctx, contents)
// }

// Switch to the new assistant if necessary
if res != nil && res.AssistantID != ctx.AssistantID {
newAst, err := Get(res.AssistantID)
Expand All @@ -89,22 +105,25 @@ func (ast *Assistant) execute(c *gin.Context, ctx chatctx.Context, input []chatM
Write(c.Writer)
return err
}
*ast = *newAst
}

// Handle next action
if res != nil && res.Next != nil {
return res.Next.Execute(c, ctx, contents)
}
// Reset Message Contents
last := input[len(input)-1]
input, err = newAst.withHistory(ctx, last)
if err != nil {
return err
}

// Update options if provided
if res != nil && res.Options != nil {
options = res.Options
}
// Reset options
options = newAst.withOptions(userOptions)

// messages
if res != nil && res.Input != nil {
input = res.Input
// Update options if provided
if res.Options != nil {
options = res.Options
}

// Update assistant id
ctx.AssistantID = res.AssistantID
return newAst.handleChatStream(c, ctx, input, options, contents)
}

// Only proceed with chat stream if no specific next action was handled
Expand Down
118 changes: 90 additions & 28 deletions neo/store/xun.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ func (conv *Xun) initChatTable() error {
table.ID("id")
table.String("chat_id", 200).Unique().Index()
table.String("title", 200).Null()
table.String("assistant_id", 200).Null().Index()
table.String("sid", 255).Index()
table.TimestampTz("created_at").SetDefaultRaw("NOW()").Index()
table.TimestampTz("updated_at").Null().Index()
Expand All @@ -202,7 +203,7 @@ func (conv *Xun) initChatTable() error {
return err
}

fields := []string{"id", "chat_id", "title", "sid", "created_at", "updated_at"}
fields := []string{"id", "chat_id", "title", "assistant_id", "sid", "created_at", "updated_at"}
for _, field := range fields {
if !tab.HasColumn(field) {
return fmt.Errorf("%s is required", field)
Expand Down Expand Up @@ -336,7 +337,7 @@ func (conv *Xun) GetChats(sid string, filter ChatFilter) (*ChatGroupResponse, er

// Build base query
qb := conv.newQueryChat().
Select("chat_id", "title", "created_at", "updated_at").
Select("chat_id", "title", "assistant_id", "created_at", "updated_at").
Where("sid", userID).
Where("chat_id", "!=", "")

Expand Down Expand Up @@ -384,15 +385,54 @@ func (conv *Xun) GetChats(sid string, filter ChatFilter) (*ChatGroupResponse, er
"Even Earlier": {},
}

// Get assistant details for all chats
assistantIDs := []interface{}{}
assistantMap := make(map[string]map[string]interface{})

for _, row := range rows {
if assistantID := row.Get("assistant_id"); assistantID != nil && assistantID != "" {
assistantIDs = append(assistantIDs, assistantID)
}
}

if len(assistantIDs) > 0 {
assistants, err := conv.query.New().
Table(conv.getAssistantTable()).
Select("assistant_id", "name", "avatar").
WhereIn("assistant_id", assistantIDs).
Get()
if err != nil {
return nil, err
}

for _, assistant := range assistants {
if id := assistant.Get("assistant_id"); id != nil {
assistantMap[fmt.Sprintf("%v", id)] = map[string]interface{}{
"name": assistant.Get("name"),
"avatar": assistant.Get("avatar"),
}
}
}
}

for _, row := range rows {
chatID := row.Get("chat_id")
if chatID == nil || chatID == "" {
continue
}

chat := map[string]interface{}{
"chat_id": chatID,
"title": row.Get("title"),
"chat_id": chatID,
"title": row.Get("title"),
"assistant_id": row.Get("assistant_id"),
}

// Add assistant details if available
if assistantID := row.Get("assistant_id"); assistantID != nil && assistantID != "" {
if assistant, ok := assistantMap[fmt.Sprintf("%v", assistantID)]; ok {
chat["assistant_name"] = assistant["name"]
chat["assistant_avatar"] = assistant["avatar"]
}
}

var dbDatetime = row.Get("updated_at")
Expand Down Expand Up @@ -514,6 +554,14 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid
return err
}

// Get assistant_id from context
var assistantID interface{} = nil
if context != nil {
if id, ok := context["assistant_id"].(string); ok && id != "" {
assistantID = id
}
}

// First ensure chat record exists
exists, err := conv.newQueryChat().
Where("chat_id", cid).
Expand All @@ -528,14 +576,28 @@ func (conv *Xun) SaveHistory(sid string, messages []map[string]interface{}, cid
// Create new chat record
err = conv.newQueryChat().
Insert(map[string]interface{}{
"chat_id": cid,
"sid": userID,
"created_at": time.Now(),
"chat_id": cid,
"sid": userID,
"assistant_id": assistantID,
"created_at": time.Now(),
})

if err != nil {
return err
}
} else {
// Update assistant_id if it exists
if assistantID != nil {
_, err = conv.newQueryChat().
Where("chat_id", cid).
Where("sid", userID).
Update(map[string]interface{}{
"assistant_id": assistantID,
})
if err != nil {
return err
}
}
}

// Save message history
Expand Down Expand Up @@ -637,7 +699,7 @@ func (conv *Xun) GetChat(sid string, cid string) (*ChatInfo, error) {

// Get chat info
qb := conv.newQueryChat().
Select("chat_id", "title").
Select("chat_id", "title", "assistant_id").
Where("sid", userID).
Where("chat_id", cid)

Expand All @@ -652,8 +714,26 @@ func (conv *Xun) GetChat(sid string, cid string) (*ChatInfo, error) {
}

chat := map[string]interface{}{
"chat_id": row.Get("chat_id"),
"title": row.Get("title"),
"chat_id": row.Get("chat_id"),
"title": row.Get("title"),
"assistant_id": row.Get("assistant_id"),
}

// Get assistant details if assistant_id exists
if assistantID := row.Get("assistant_id"); assistantID != nil && assistantID != "" {
assistant, err := conv.query.New().
Table(conv.getAssistantTable()).
Select("name", "avatar").
Where("assistant_id", assistantID).
First()
if err != nil {
return nil, err
}

if assistant != nil {
chat["assistant_name"] = assistant.Get("name")
chat["assistant_avatar"] = assistant.Get("avatar")
}
}

// Get chat history
Expand Down Expand Up @@ -715,24 +795,6 @@ func (conv *Xun) DeleteAllChats(sid string) error {
return err
}

// processJSONField processes a field that should be stored as JSON string
func (conv *Xun) processJSONField(field interface{}) (interface{}, error) {
if field == nil {
return nil, nil
}

switch v := field.(type) {
case string:
return v, nil
default:
jsonStr, err := jsoniter.MarshalToString(v)
if err != nil {
return nil, fmt.Errorf("failed to marshal %v to JSON: %v", field, err)
}
return jsonStr, nil
}
}

// parseJSONFields parses JSON string fields into their corresponding Go types
func (conv *Xun) parseJSONFields(data map[string]interface{}, fields []string) {
for _, field := range fields {
Expand Down
Loading

0 comments on commit e4375ea

Please sign in to comment.