diff --git a/neo/assistant/api.go b/neo/assistant/api.go
index eda682bdbf..7a6ab2c9d5 100644
--- a/neo/assistant/api.go
+++ b/neo/assistant/api.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/yaoapp/gou/fs"
+	"github.com/yaoapp/gou/process"
 	chatctx "github.com/yaoapp/yao/neo/context"
 	"github.com/yaoapp/yao/neo/message"
 	chatMessage "github.com/yaoapp/yao/neo/message"
@@ -69,11 +70,7 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
 
 	// Handle next action
 	if res.Next != nil {
-		switch res.Next.Action {
-		case "exit":
-			return nil
-			// Add other actions here if needed
-		}
+		return res.Next.Execute(c, ctx)
 	}
 
 	// Update options if provided
@@ -90,15 +87,87 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
 	return ast.handleChatStream(c, ctx, messages, options)
 }
 
+// Execute the next action
+func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error {
+	switch next.Action {
+
+	case "process":
+		if next.Payload == nil {
+			return fmt.Errorf("payload is required")
+		}
+
+		name, ok := next.Payload["name"].(string)
+		if !ok {
+			return fmt.Errorf("process name should be string")
+		}
+
+		args := []interface{}{}
+		if v, ok := next.Payload["args"].([]interface{}); ok {
+			args = v
+		}
+
+		// Add context and writer to args
+		args = append(args, ctx, c.Writer)
+		p, err := process.Of(name, args...)
+		if err != nil {
+			return fmt.Errorf("get process error: %s", err.Error())
+		}
+
+		err = p.Execute()
+		if err != nil {
+			return fmt.Errorf("execute process error: %s", err.Error())
+		}
+		defer p.Release()
+
+		return nil
+
+	case "assistant":
+		if next.Payload == nil {
+			return fmt.Errorf("payload is required")
+		}
+
+		// Get assistant id
+		id, ok := next.Payload["assistant_id"].(string)
+		if !ok {
+			return fmt.Errorf("assistant id should be string")
+		}
+
+		// Get assistant
+		assistant, err := Get(id)
+		if err != nil {
+			return fmt.Errorf("get assistant error: %s", err.Error())
+		}
+
+		// Input
+		input, ok := next.Payload["input"].(string)
+		if !ok {
+			return fmt.Errorf("input should be string")
+		}
+
+		// Options
+		options := map[string]interface{}{}
+		if v, ok := next.Payload["options"].(map[string]interface{}); ok {
+			options = v
+		}
+		return assistant.Execute(c, ctx, input, options)
+
+	case "exit":
+		return nil
+
+	default:
+		return fmt.Errorf("unknown action: %s", next.Action)
+	}
+}
+
 // handleChatStream manages the streaming chat interaction with the AI
 func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{}) error {
 	clientBreak := make(chan bool, 1)
 	done := make(chan bool, 1)
-	content := []byte{}
+	content := message.NewContent("text")
 
 	// Chat with AI in background
 	go func() {
-		err := ast.streamChat(c, ctx, messages, options, clientBreak, done, &content)
+		err := ast.streamChat(c, ctx, messages, options, clientBreak, done, content)
 		if err != nil {
 			chatMessage.New().Error(err).Done().Write(c.Writer)
 		}
@@ -118,8 +187,14 @@ func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, mess
 }
 
 // streamChat handles the streaming chat interaction
-func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{},
-	clientBreak chan bool, done chan bool, content *[]byte) error {
+func (ast *Assistant) streamChat(
+	c *gin.Context,
+	ctx chatctx.Context,
+	messages []message.Message,
+	options map[string]interface{},
+	clientBreak chan bool,
+	done chan bool,
+	content *message.Content) error {
 
 	return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int {
 		select {
@@ -135,7 +210,7 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
 			// Handle error
 			if msg.Type == "error" {
 				value := msg.String()
-				res, hookErr := ast.HookFail(c, ctx, messages, string(*content), fmt.Errorf("%s", value))
+				res, hookErr := ast.HookFail(c, ctx, messages, content.String(), fmt.Errorf("%s", value))
 				if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") {
 					value = res.Output
 					if res.Error != "" {
@@ -146,20 +221,41 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
 				return 0 // break
 			}
 
+			// Handle tool call
+			if msg.Type == "tool_calls" {
+				content.SetType("function") // Set type to function
+				// Set id
+				if id, ok := msg.Props["id"].(string); ok && id != "" {
+					content.SetID(id)
+				}
+
+				// Set name
+				if name, ok := msg.Props["name"].(string); ok && name != "" {
+					content.SetName(name)
+				}
+			}
+
 			// Append content and send message
-			*content = msg.Append(*content)
 			value := msg.String()
+			content.Append(value)
 			if value != "" {
 
 				// Handle stream
-				res, err := ast.HookStream(c, ctx, messages, string(*content))
+				res, err := ast.HookStream(c, ctx, messages, content.String(), msg.Type == "tool_calls")
 				if err == nil && res != nil {
 					if res.Output != "" {
 						value = res.Output
 					}
-					if res.Next != nil && res.Next.Action == "exit" {
+
+					if res.Next != nil {
+						err = res.Next.Execute(c, ctx)
+						if err != nil {
+							chatMessage.New().Error(err.Error()).Done().Write(c.Writer)
+						}
+						done <- true
 						return 0 // break
 					}
+
 					if res.Silent {
 						return 1 // continue
 					}
@@ -180,7 +276,8 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
 			// }
 
 			// Call HookDone
-			res, hookErr := ast.HookDone(c, ctx, messages, string(*content))
+			content.SetStatus(message.ContentStatusDone)
+			res, hookErr := ast.HookDone(c, ctx, messages, content.String(), msg.Type == "tool_calls")
 			if hookErr == nil && res != nil {
 				if res.Output != "" {
 					chatMessage.New().
@@ -190,10 +287,16 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
 						}).
 						Write(c.Writer)
 				}
-				if res.Next != nil && res.Next.Action == "exit" {
+
+				if res.Next != nil {
+					err := res.Next.Execute(c, ctx)
+					if err != nil {
+						chatMessage.New().Error(err.Error()).Done().Write(c.Writer)
+					}
 					done <- true
 					return 0 // break
 				}
+
 			} else if value != "" {
 				chatMessage.New().
 					Map(map[string]interface{}{
@@ -213,13 +316,13 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
 }
 
 // saveChatHistory saves the chat history if storage is available
-func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content []byte) {
-	if len(content) > 0 && ctx.Sid != "" && len(messages) > 0 {
+func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content *message.Content) {
+	if len(content.Bytes) > 0 && ctx.Sid != "" && len(messages) > 0 {
 		storage.SaveHistory(
 			ctx.Sid,
 			[]map[string]interface{}{
 				{"role": "user", "content": messages[len(messages)-1].Content(), "name": ctx.Sid},
-				{"role": "assistant", "content": string(content), "name": ctx.Sid},
+				{"role": "assistant", "content": content.String(), "name": ctx.Sid},
 			},
 			ctx.ChatID,
 			nil,
@@ -237,6 +340,15 @@ func (ast *Assistant) withOptions(options map[string]interface{}) map[string]int
 			options[key] = value
 		}
 	}
+
+	// Add functions
+	if ast.Functions != nil {
+		options["tools"] = ast.Functions
+		if options["tool_choice"] == nil {
+			options["tool_choice"] = "auto"
+		}
+	}
+
 	return options
 }
 
diff --git a/neo/assistant/assistant.go b/neo/assistant/assistant.go
index f4bc9fd00e..00b36062d0 100644
--- a/neo/assistant/assistant.go
+++ b/neo/assistant/assistant.go
@@ -132,6 +132,7 @@ func (ast *Assistant) Map() map[string]interface{} {
 		"description": ast.Description,
 		"options":     ast.Options,
 		"prompts":     ast.Prompts,
+		"functions":   ast.Functions,
 		"tags":        ast.Tags,
 		"mentionable": ast.Mentionable,
 		"automated":   ast.Automated,
diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go
index c6dd65b2c3..9080d3d7de 100644
--- a/neo/assistant/hooks.go
+++ b/neo/assistant/hooks.go
@@ -57,13 +57,13 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []
 }
 
 // HookStream Handle streaming response from LLM
-func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookStream, error) {
+func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookStream, error) {
 
 	// Create timeout context
 	ctx, cancel := ast.createTimeoutContext(c)
 	defer cancel()
 
-	v, err := ast.call(ctx, "Stream", context, input, output, c.Writer)
+	v, err := ast.call(ctx, "Stream", context, input, output, toolcall, c.Writer)
 	if err != nil {
 		if err.Error() == HookErrorMethodNotFound {
 			return nil, nil
@@ -100,12 +100,12 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input
 }
 
 // HookDone Handle completion of assistant response
-func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookDone, error) {
+func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookDone, error) {
 	// Create timeout context
 	ctx, cancel := ast.createTimeoutContext(c)
 	defer cancel()
 
-	v, err := ast.call(ctx, "Done", context, input, output, c.Writer)
+	v, err := ast.call(ctx, "Done", context, input, output, toolcall, c.Writer)
 	if err != nil {
 		if err.Error() == HookErrorMethodNotFound {
 			return nil, nil
diff --git a/neo/assistant/load.go b/neo/assistant/load.go
index 99cfb2bedd..de95ef682d 100644
--- a/neo/assistant/load.go
+++ b/neo/assistant/load.go
@@ -246,6 +246,16 @@ func LoadPath(path string) (*Assistant, error) {
 	}
 
 	// load functions
+	functionsfile := filepath.Join(path, "functions.json")
+	if has, _ := app.Exists(functionsfile); has {
+		functions, ts, err := loadFunctions(functionsfile)
+		if err != nil {
+			return nil, err
+		}
+		data["functions"] = functions
+		updatedAt = max(updatedAt, ts)
+		data["updated_at"] = updatedAt
+	}
 
 	// load flow
 
@@ -340,6 +350,25 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
 		assistant.Prompts = prompts
 	}
 
+	// functions
+	if funcs, has := data["functions"]; has {
+		switch vv := funcs.(type) {
+		case []Function:
+			assistant.Functions = vv
+		default:
+			raw, err := jsoniter.Marshal(vv)
+			if err != nil {
+				return nil, err
+			}
+			var functions []Function
+			err = jsoniter.Unmarshal(raw, &functions)
+			if err != nil {
+				return nil, err
+			}
+			assistant.Functions = functions
+		}
+	}
+
 	// script
 	if data["script"] != nil {
 		switch v := data["script"].(type) {
@@ -382,6 +411,32 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
 	return assistant, nil
 }
 
+func loadFunctions(file string) ([]Function, int64, error) {
+
+	app, err := fs.Get("app")
+	if err != nil {
+		return nil, 0, err
+	}
+
+	ts, err := app.ModTime(file)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	raw, err := app.ReadFile(file)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	var functions []Function
+	err = jsoniter.Unmarshal(raw, &functions)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return functions, ts.UnixNano(), nil
+}
+
 func loadPrompts(file string, root string) (string, int64, error) {
 
 	app, err := fs.Get("app")
diff --git a/neo/assistant/types.go b/neo/assistant/types.go
index 80862a01ce..535bfdf124 100644
--- a/neo/assistant/types.go
+++ b/neo/assistant/types.go
@@ -85,6 +85,16 @@ type Prompt struct {
 	Name    string `json:"name,omitempty"`
 }
 
+// Function a function
+type Function struct {
+	Type     string `json:"type"`
+	Function struct {
+		Name        string                 `json:"name"`
+		Description string                 `json:"description"`
+		Parameters  map[string]interface{} `json:"parameters"`
+	} `json:"function"`
+}
+
 // QueryParam the assistant query param
 type QueryParam struct {
 	Limit uint `json:"limit"`
@@ -110,6 +120,7 @@ type Assistant struct {
 	Automated   bool                     `json:"automated,omitempty"` // Whether this assistant is automated
 	Options     map[string]interface{}   `json:"options,omitempty"`   // AI Options
 	Prompts     []Prompt                 `json:"prompts,omitempty"`   // AI Prompts
+	Functions   []Function               `json:"functions,omitempty"` // Assistant Functions
 	Flows       []map[string]interface{} `json:"flows,omitempty"`     // Assistant Flows
 	Script      *v8.Script               `json:"-" yaml:"-"`          // Assistant Script
 	CreatedAt   int64                    `json:"created_at"`          // Creation timestamp
diff --git a/neo/message/content.go b/neo/message/content.go
new file mode 100644
index 0000000000..972037f3be
--- /dev/null
+++ b/neo/message/content.go
@@ -0,0 +1,67 @@
+package message
+
+import "fmt"
+
+const (
+	// ContentStatusPending the content status pending
+	ContentStatusPending = iota
+	// ContentStatusDone the content status done
+	ContentStatusDone
+	// ContentStatusError the content status error
+	ContentStatusError
+)
+
+// Content the content
+type Content struct {
+	ID     string `json:"id"`
+	Name   string `json:"name"`
+	Bytes  []byte `json:"bytes"`
+	Type   string `json:"type"`   // text, function, error
+	Status uint8  `json:"status"` // 0: pending, 1: done
+}
+
+// NewContent create a new content
+func NewContent(typ string) *Content {
+	if typ == "" {
+		typ = "text"
+	}
+
+	return &Content{
+		Bytes:  []byte{},
+		Type:   typ,
+		Status: ContentStatusPending,
+	}
+}
+
+// String the content string
+func (c *Content) String() string {
+	if c.Type == "function" {
+		return fmt.Sprintf(`{"id":"%s","type": "function", "function": {"name": "%s", "arguments": "%s"}}`, c.ID, c.Name, c.Bytes)
+	}
+	return string(c.Bytes)
+}
+
+// SetID set the content id
+func (c *Content) SetID(id string) {
+	c.ID = id
+}
+
+// SetName set the content name
+func (c *Content) SetName(name string) {
+	c.Name = name
+}
+
+// SetType set the content type
+func (c *Content) SetType(typ string) {
+	c.Type = typ
+}
+
+// Append append the content
+func (c *Content) Append(data string) {
+	c.Bytes = append(c.Bytes, []byte(data)...)
+}
+
+// SetStatus set the content status
+func (c *Content) SetStatus(status uint8) {
+	c.Status = status
+}
diff --git a/neo/message/message.go b/neo/message/message.go
index cad14749e1..72b9c702c4 100644
--- a/neo/message/message.go
+++ b/neo/message/message.go
@@ -49,7 +49,7 @@ type Action struct {
 
 // New create a new message
 func New() *Message {
-	return &Message{Actions: []Action{}}
+	return &Message{Actions: []Action{}, Props: map[string]interface{}{}}
 }
 
 // NewString create a new message from string
@@ -75,6 +75,21 @@ func NewOpenAI(data []byte) *Message {
 	data = []byte(strings.TrimPrefix(text, "data: "))
 
 	switch {
+
+	case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"tool_calls"`):
+		var toolCalls openai.ToolCalls
+		if err := jsoniter.Unmarshal(data, &toolCalls); err != nil {
+			msg.Text = err.Error() + "\n" + string(data)
+			return msg
+		}
+
+		msg.Type = "tool_calls"
+		if len(toolCalls.Choices) > 0 && len(toolCalls.Choices[0].Delta.ToolCalls) > 0 {
+			msg.Props["id"] = toolCalls.Choices[0].Delta.ToolCalls[0].ID
+			msg.Props["name"] = toolCalls.Choices[0].Delta.ToolCalls[0].Function.Name
+			msg.Text = toolCalls.Choices[0].Delta.ToolCalls[0].Function.Arguments
+		}
+
 	case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"content":`):
 		var message openai.Message
 		if err := jsoniter.Unmarshal(data, &message); err != nil {
@@ -92,6 +107,9 @@ func NewOpenAI(data []byte) *Message {
 	case strings.Contains(text, `"finish_reason":"stop"`):
 		msg.IsDone = true
 
+	case strings.Contains(text, `"finish_reason":"tool_calls"`):
+		msg.IsDone = true
+
 	default:
 		str := strings.TrimPrefix(strings.Trim(string(data), "\""), "data: ")
 		msg.Type = "error"
diff --git a/openai/types.go b/openai/types.go
index 12e2bdbd50..3c99ce7a35 100644
--- a/openai/types.go
+++ b/openai/types.go
@@ -16,6 +16,28 @@ type Message struct {
 	} `json:"choices,omitempty"`
 }
 
+// ToolCalls is the response from OpenAI
+type ToolCalls struct {
+	ID      string `json:"id,omitempty"`
+	Object  string `json:"object,omitempty"`
+	Created int64  `json:"created,omitempty"`
+	Model   string `json:"model,omitempty"`
+	Choices []struct {
+		Delta struct {
+			ToolCalls []struct {
+				ID       string `json:"id,omitempty"`
+				Type     string `json:"type,omitempty"`
+				Function struct {
+					Name      string `json:"name,omitempty"`
+					Arguments string `json:"arguments,omitempty"`
+				} `json:"function,omitempty"`
+			} `json:"tool_calls,omitempty"`
+		} `json:"delta,omitempty"`
+		Index        int    `json:"index,omitempty"`
+		FinishReason string `json:"finish_reason,omitempty"`
+	} `json:"choices,omitempty"`
+}
+
 // ErrorMessage is the error response from OpenAI
 type ErrorMessage struct {
 	Error Error `json:"error,omitempty"`
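For context on the new `neo/message/content.go` accumulator: below is a minimal sketch of how `streamChat` now folds a streamed tool call into a `*message.Content` instead of a raw `[]byte`. The call id, function name, and argument chunks are invented for illustration; only the `Content` API itself comes from this diff.

```go
package main

import (
	"fmt"

	"github.com/yaoapp/yao/neo/message"
)

func main() {
	// streamChat starts with a plain text accumulator.
	content := message.NewContent("text")

	// On the first "tool_calls" delta the accumulator is switched to a
	// function call and the call id / function name are recorded.
	content.SetType("function")
	content.SetID("call_0")        // hypothetical call id
	content.SetName("get_weather") // hypothetical function name

	// Argument fragments are appended chunk by chunk as deltas arrive.
	content.Append(`{"location":`)
	content.Append(`"Paris"}`)

	// HookDone now receives the accumulated content marked as done.
	content.SetStatus(message.ContentStatusDone)

	// For "function" content, String() wraps the id, name, and raw argument
	// bytes in a JSON envelope; plain text content is returned as-is.
	fmt.Println(content.String())
}
```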
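And a sketch of the payload shapes the new `NextAction.Execute` switch accepts, as a hook might return them. The field names `Action` and `Payload` are inferred from their usage in this diff, `Payload` is assumed to be a plain `map[string]interface{}`, and the process name, assistant id, and input values are hypothetical.

```go
package examples

import "github.com/yaoapp/yao/neo/assistant"

func nextActionExamples() []*assistant.NextAction {
	return []*assistant.NextAction{
		// "process": Execute appends the chat context and the response
		// writer to args, then runs the named process via process.Of.
		{
			Action: "process",
			Payload: map[string]interface{}{
				"name": "scripts.weather.Fetch", // hypothetical process name
				"args": []interface{}{"Paris"},  // hypothetical arguments
			},
		},
		// "assistant": Execute resolves the assistant by id and re-enters
		// Execute with the given input and options.
		{
			Action: "assistant",
			Payload: map[string]interface{}{
				"assistant_id": "weather-bot", // hypothetical assistant id
				"input":        "What is the weather in Paris?",
				"options":      map[string]interface{}{},
			},
		},
		// "exit": stop the stream without any further work.
		{Action: "exit"},
	}
}
```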