Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

# Neo API: Enhance Message Handling and Vision Support #819

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 82 additions & 49 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"strings"

"github.com/gin-gonic/gin"
jsoniter "github.com/json-iterator/go"
"github.com/yaoapp/gou/fs"
"github.com/yaoapp/gou/process"
"github.com/yaoapp/kun/utils"
chatctx "github.com/yaoapp/yao/neo/context"
"github.com/yaoapp/yao/neo/message"
chatMessage "github.com/yaoapp/yao/neo/message"
)

Expand Down Expand Up @@ -160,10 +161,10 @@ func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error {
}

// handleChatStream manages the streaming chat interaction with the AI
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{}) error {
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}) error {
clientBreak := make(chan bool, 1)
done := make(chan bool, 1)
content := message.NewContent("text")
content := chatMessage.NewContent("text")

// Chat with AI in background
go func() {
Expand All @@ -190,11 +191,11 @@ func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, mess
func (ast *Assistant) streamChat(
c *gin.Context,
ctx chatctx.Context,
messages []message.Message,
messages []chatMessage.Message,
options map[string]interface{},
clientBreak chan bool,
done chan bool,
content *message.Content) error {
content *chatMessage.Content) error {

return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int {
select {
Expand Down Expand Up @@ -240,7 +241,7 @@ func (ast *Assistant) streamChat(
content.Append(value)
if value != "" {
// Handle stream
res, err := ast.HookStream(c, ctx, messages, content.String(), msg.Type == "tool_calls")
res, err := ast.HookStream(c, ctx, messages, content.String(), content.Type == "function")
if err == nil && res != nil {
if res.Output != "" {
value = res.Output
Expand Down Expand Up @@ -276,8 +277,8 @@ func (ast *Assistant) streamChat(
// }

// Call HookDone
content.SetStatus(message.ContentStatusDone)
res, hookErr := ast.HookDone(c, ctx, messages, content.String(), msg.Type == "tool_calls")
content.SetStatus(chatMessage.ContentStatusDone)
res, hookErr := ast.HookDone(c, ctx, messages, content.String(), content.Type == "function")
if hookErr == nil && res != nil {
if res.Output != "" {
chatMessage.New().
Expand Down Expand Up @@ -316,7 +317,7 @@ func (ast *Assistant) streamChat(
}

// saveChatHistory saves the chat history if storage is available
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content *message.Content) {
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []chatMessage.Message, content *chatMessage.Content) {
if len(content.Bytes) > 0 && ctx.Sid != "" && len(messages) > 0 {
storage.SaveHistory(
ctx.Sid,
Expand Down Expand Up @@ -352,21 +353,21 @@ func (ast *Assistant) withOptions(options map[string]interface{}) map[string]int
return options
}

func (ast *Assistant) withPrompts(messages []message.Message) []message.Message {
func (ast *Assistant) withPrompts(messages []chatMessage.Message) []chatMessage.Message {
if ast.Prompts != nil {
for _, prompt := range ast.Prompts {
name := ast.Name
if prompt.Name != "" {
name = prompt.Name
}
messages = append(messages, *message.New().Map(map[string]interface{}{"role": prompt.Role, "content": prompt.Content, "name": name}))
messages = append(messages, *chatMessage.New().Map(map[string]interface{}{"role": prompt.Role, "content": prompt.Content, "name": name}))
}
}
return messages
}

func (ast *Assistant) withHistory(ctx chatctx.Context, input string) ([]message.Message, error) {
messages := []message.Message{}
func (ast *Assistant) withHistory(ctx chatctx.Context, input string) ([]chatMessage.Message, error) {
messages := []chatMessage.Message{}
messages = ast.withPrompts(messages)
if storage != nil {
history, err := storage.GetHistory(ctx.Sid, ctx.ChatID)
Expand All @@ -376,17 +377,17 @@ func (ast *Assistant) withHistory(ctx chatctx.Context, input string) ([]message.

// Add history messages
for _, h := range history {
messages = append(messages, *message.New().Map(h))
messages = append(messages, *chatMessage.New().Map(h))
}
}

// Add user message
messages = append(messages, *message.New().Map(map[string]interface{}{"role": "user", "content": input, "name": ctx.Sid}))
messages = append(messages, *chatMessage.New().Map(map[string]interface{}{"role": "user", "content": input, "name": ctx.Sid}))
return messages, nil
}

// Chat implements the chat functionality
func (ast *Assistant) Chat(ctx context.Context, messages []message.Message, option map[string]interface{}, cb func(data []byte) int) error {
func (ast *Assistant) Chat(ctx context.Context, messages []chatMessage.Message, option map[string]interface{}, cb func(data []byte) int) error {
if ast.openai == nil {
return fmt.Errorf("openai is not initialized")
}
Expand All @@ -404,27 +405,10 @@ func (ast *Assistant) Chat(ctx context.Context, messages []message.Message, opti
return nil
}

func (ast *Assistant) requestMessages(ctx context.Context, messages []message.Message) ([]map[string]interface{}, error) {
func (ast *Assistant) requestMessages(ctx context.Context, messages []chatMessage.Message) ([]map[string]interface{}, error) {
newMessages := []map[string]interface{}{}
// With Prompts
if ast.Prompts != nil {
for _, prompt := range ast.Prompts {
msg := map[string]interface{}{
"role": prompt.Role,
"content": prompt.Content,
}

name := ast.Name
if prompt.Name != "" {
name = prompt.Name
}

msg["name"] = name
newMessages = append(newMessages, msg)
}
}

length := len(messages)

for index, message := range messages {
role := message.Role
if role == "" {
Expand Down Expand Up @@ -454,12 +438,24 @@ func (ast *Assistant) requestMessages(ctx context.Context, messages []message.Me
}

newMessage["content"] = msg.Text
if msg.Attachments != nil {
content, err := ast.withAttachments(ctx, msg)
if message.Attachments != nil {
contents, err := ast.withAttachments(ctx, &message)
if err != nil {
return nil, fmt.Errorf("with attachments error: %s", err.Error())
}
newMessage["content"] = content

// if current assistant is vision capable, add the contents directly
if ast.vision {
newMessage["content"] = contents
continue
}

// If current assistant is not vision capable, add the description of the image
if contents != nil {
for _, content := range contents {
newMessages = append(newMessages, content)
}
}
}
}

Expand All @@ -470,31 +466,68 @@ func (ast *Assistant) requestMessages(ctx context.Context, messages []message.Me

func (ast *Assistant) withAttachments(ctx context.Context, msg *chatMessage.Message) ([]map[string]interface{}, error) {
contents := []map[string]interface{}{{"type": "text", "text": msg.Text}}
if !ast.vision {
contents = []map[string]interface{}{{"role": "user", "content": msg.Text}}
}

images := []string{}
for _, attachment := range msg.Attachments {
if strings.HasPrefix(attachment.ContentType, "image/") {
images = append(images, attachment.FileID)
if ast.vision {
images = append(images, attachment.URL)
continue
}

// If the current assistant is not vision capable, add the description of the image
raw, err := jsoniter.MarshalToString(attachment)
if err != nil {
return nil, fmt.Errorf("marshal attachment error: %s", err.Error())
}
contents = append(contents, map[string]interface{}{
"role": "system",
"content": raw,
})
}
}

if len(images) == 0 {
return contents, nil
}

for _, image := range images {
bytes64, err := ast.ReadBase64(ctx, image)
if err != nil {
return nil, fmt.Errorf("read base64 error: %s", err.Error())
// If the current assistant is vision capable, add the image to the contents directly
if ast.vision {
for _, url := range images {

// If the image is already a URL, add it directly
if strings.HasPrefix(url, "http") {
contents = append(contents, map[string]interface{}{
"type": "image_url",
"image_url": map[string]string{
"url": url,
},
})
continue
}

// Read base64
bytes64, err := ast.ReadBase64(ctx, url)
if err != nil {
return nil, fmt.Errorf("read base64 error: %s", err.Error())
}
contents = append(contents, map[string]interface{}{
"type": "image_url",
"image_url": map[string]string{
"url": fmt.Sprintf("data:image/jpeg;base64,%s", bytes64),
},
})
}

contents = append(contents, map[string]interface{}{
"type": "image_url",
"image_url": map[string]string{
"url": fmt.Sprintf("data:image/jpeg;base64,%s", bytes64),
},
})
utils.Dump(contents)
return contents, nil
}

// If the current assistant is not vision capable, add the description of the image

return contents, nil
}

Expand Down
77 changes: 43 additions & 34 deletions neo/assistant/attachment.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,22 @@ func (ast *Assistant) handleRAG(ctx context.Context, file *File, reader io.Reade

// handleVision handles the file with Vision if available
func (ast *Assistant) handleVision(ctx context.Context, file *File, option map[string]interface{}) error {

if vision == nil {
return nil
}

// Check if Vision processing is enabled
if option, ok := option["vision"].(bool); !ok || !option {
handleVision := false
if vv, has := option["vision"]; has {
switch v := vv.(type) {
case bool:
handleVision = v
case string:
handleVision = v == "true" || v == "1" || v == "yes" || v == "on" || v == "enable"
}
}

if !handleVision {
return nil
}

Expand All @@ -211,12 +221,6 @@ func (ast *Assistant) handleVision(ctx context.Context, file *File, option map[s
return nil
}

// Get model from options
model := ""
if v, ok := option["model"].(string); ok {
model = v
}

// Reset reader for vision service
data, err := fs.Get("data")
if err != nil {
Expand All @@ -237,43 +241,48 @@ func (ast *Assistant) handleVision(ctx context.Context, file *File, option map[s
return fmt.Errorf("read file error: %s", err.Error())
}

if VisionCapableModels[model] {
// The model is vision capable
if ast.vision {
// For vision-capable models, upload to vision service to get URL
resp, err := vision.Upload(ctx, file.Filename, bytes.NewReader(imgData), file.ContentType)
if err != nil {
return fmt.Errorf("vision upload error: %s", err.Error())
}
file.URL = resp.URL // Store the URL for vision-capable models to use
} else {
// For non-vision models, get image description
prompt := "Describe this image in detail."
if v, ok := option["vision_prompt"].(string); ok {
prompt = v
}
return nil
}

// Upload to vision service first Compress image
resp, err := vision.Upload(ctx, file.Filename, bytes.NewReader(imgData), file.ContentType)
if err != nil {
return fmt.Errorf("vision upload error: %s", err.Error())
}
// For non-vision models, get image description
prompt := "Describe this image in detail."
if v, ok := option["vision_prompt"].(string); ok {
prompt = v
}

// Analyze using base64 data
result, err := vision.Analyze(ctx, resp.FileID, prompt)
if err != nil {
return fmt.Errorf("vision analyze error: %s", err.Error())
}
// Upload to vision service first Compress image
resp, err := vision.Upload(ctx, file.Filename, bytes.NewReader(imgData), file.ContentType)
if err != nil {
return fmt.Errorf("vision upload error: %s", err.Error())
}

// Analyze using base64 data
result, err := vision.Analyze(ctx, resp.FileID, prompt)
if err != nil {
return fmt.Errorf("vision analyze error: %s", err.Error())
}

// Extract description text from response
if desc, ok := result.Description["text"].(string); ok {
file.Description = desc
} else {
// Convert the entire description to JSON string as fallback
bytes, err := jsoniter.Marshal(result.Description)
if err == nil {
file.Description = string(bytes)
}
// Extract description text from response
if desc, ok := result.Description["description"].(string); ok {
file.Description = desc
} else if desc, ok := result.Description["text"].(string); ok {
file.Description = desc
} else {
// Convert the entire description to JSON string as fallback
bytes, err := jsoniter.Marshal(result.Description)
if err == nil {
file.Description = string(bytes)
}
}

return nil
}

Expand Down
20 changes: 20 additions & 0 deletions neo/assistant/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,5 +517,25 @@ func (ast *Assistant) initialize() error {
return err
}
ast.openai = api

// Check if the assistant supports vision
model := api.Model()
if v, ok := ast.Options["model"].(string); ok {
model = strings.TrimLeft(v, "moapi:")
}
if _, ok := VisionCapableModels[model]; ok {
ast.vision = true
}

// Check if the assistant has an init hook
if ast.Script != nil {
scriptCtx, err := ast.Script.NewContext("", nil)
if err != nil {
return err
}
defer scriptCtx.Close()
ast.initHook = scriptCtx.Global().Has("init")
}

return nil
}
Loading
Loading