Skip to content

Commit

Permalink
Merge pull request #817 from trheyi/main
Browse files Browse the repository at this point in the history
Added function support and action execution capabilities to Neo API assistant, improving its flexibility and robustness. Key updates include function management, enhanced message handling, and improved error validation.
  • Loading branch information
trheyi authored Jan 15, 2025
2 parents 9c7b9cf + e7185aa commit cad8aa2
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 23 deletions.
148 changes: 130 additions & 18 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/yaoapp/gou/fs"
"github.com/yaoapp/gou/process"
chatctx "github.com/yaoapp/yao/neo/context"
"github.com/yaoapp/yao/neo/message"
chatMessage "github.com/yaoapp/yao/neo/message"
Expand Down Expand Up @@ -69,11 +70,7 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,

// Handle next action
if res.Next != nil {
switch res.Next.Action {
case "exit":
return nil
// Add other actions here if needed
}
return res.Next.Execute(c, ctx)
}

// Update options if provided
Expand All @@ -90,15 +87,87 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
return ast.handleChatStream(c, ctx, messages, options)
}

// Execute the next action
func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error {
switch next.Action {

case "process":
if next.Payload == nil {
return fmt.Errorf("payload is required")
}

name, ok := next.Payload["name"].(string)
if !ok {
return fmt.Errorf("process name should be string")
}

args := []interface{}{}
if v, ok := next.Payload["args"].([]interface{}); ok {
args = v
}

// Add context and writer to args
args = append(args, ctx, c.Writer)
p, err := process.Of(name, args...)
if err != nil {
return fmt.Errorf("get process error: %s", err.Error())
}

err = p.Execute()
if err != nil {
return fmt.Errorf("execute process error: %s", err.Error())
}
defer p.Release()

return nil

case "assistant":
if next.Payload == nil {
return fmt.Errorf("payload is required")
}

// Get assistant id
id, ok := next.Payload["assistant_id"].(string)
if !ok {
return fmt.Errorf("assistant id should be string")
}

// Get assistant
assistant, err := Get(id)
if err != nil {
return fmt.Errorf("get assistant error: %s", err.Error())
}

// Input
input, ok := next.Payload["input"].(string)
if !ok {
return fmt.Errorf("input should be string")
}

// Options
options := map[string]interface{}{}
if v, ok := next.Payload["options"].(map[string]interface{}); ok {
options = v
}
return assistant.Execute(c, ctx, input, options)

case "exit":
return nil

default:
return fmt.Errorf("unknown action: %s", next.Action)
}
}

// handleChatStream manages the streaming chat interaction with the AI
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{}) error {
clientBreak := make(chan bool, 1)
done := make(chan bool, 1)
content := []byte{}
content := message.NewContent("text")

// Chat with AI in background
go func() {
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, &content)
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, content)
if err != nil {
chatMessage.New().Error(err).Done().Write(c.Writer)
}
Expand All @@ -118,8 +187,14 @@ func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, mess
}

// streamChat handles the streaming chat interaction
func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{},
clientBreak chan bool, done chan bool, content *[]byte) error {
func (ast *Assistant) streamChat(
c *gin.Context,
ctx chatctx.Context,
messages []message.Message,
options map[string]interface{},
clientBreak chan bool,
done chan bool,
content *message.Content) error {

return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int {
select {
Expand All @@ -135,7 +210,7 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
// Handle error
if msg.Type == "error" {
value := msg.String()
res, hookErr := ast.HookFail(c, ctx, messages, string(*content), fmt.Errorf("%s", value))
res, hookErr := ast.HookFail(c, ctx, messages, content.String(), fmt.Errorf("%s", value))
if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") {
value = res.Output
if res.Error != "" {
Expand All @@ -146,20 +221,41 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
return 0 // break
}

// Handle tool call
if msg.Type == "tool_calls" {
content.SetType("function") // Set type to function
// Set id
if id, ok := msg.Props["id"].(string); ok && id != "" {
content.SetID(id)
}

// Set name
if name, ok := msg.Props["name"].(string); ok && name != "" {
content.SetName(name)
}
}

// Append content and send message
*content = msg.Append(*content)
value := msg.String()
content.Append(value)
if value != "" {
// Handle stream
res, err := ast.HookStream(c, ctx, messages, string(*content))
res, err := ast.HookStream(c, ctx, messages, content.String(), msg.Type == "tool_calls")
if err == nil && res != nil {
if res.Output != "" {
value = res.Output
}
if res.Next != nil && res.Next.Action == "exit" {

if res.Next != nil {
err = res.Next.Execute(c, ctx)
if err != nil {
chatMessage.New().Error(err.Error()).Done().Write(c.Writer)
}

done <- true
return 0 // break
}

if res.Silent {
return 1 // continue
}
Expand All @@ -180,7 +276,8 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
// }

// Call HookDone
res, hookErr := ast.HookDone(c, ctx, messages, string(*content))
content.SetStatus(message.ContentStatusDone)
res, hookErr := ast.HookDone(c, ctx, messages, content.String(), msg.Type == "tool_calls")
if hookErr == nil && res != nil {
if res.Output != "" {
chatMessage.New().
Expand All @@ -190,10 +287,16 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
}).
Write(c.Writer)
}
if res.Next != nil && res.Next.Action == "exit" {

if res.Next != nil {
err := res.Next.Execute(c, ctx)
if err != nil {
chatMessage.New().Error(err.Error()).Done().Write(c.Writer)
}
done <- true
return 0 // break
}

} else if value != "" {
chatMessage.New().
Map(map[string]interface{}{
Expand All @@ -213,13 +316,13 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
}

// saveChatHistory saves the chat history if storage is available
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content []byte) {
if len(content) > 0 && ctx.Sid != "" && len(messages) > 0 {
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content *message.Content) {
if len(content.Bytes) > 0 && ctx.Sid != "" && len(messages) > 0 {
storage.SaveHistory(
ctx.Sid,
[]map[string]interface{}{
{"role": "user", "content": messages[len(messages)-1].Content(), "name": ctx.Sid},
{"role": "assistant", "content": string(content), "name": ctx.Sid},
{"role": "assistant", "content": content.String(), "name": ctx.Sid},
},
ctx.ChatID,
nil,
Expand All @@ -237,6 +340,15 @@ func (ast *Assistant) withOptions(options map[string]interface{}) map[string]int
options[key] = value
}
}

// Add functions
if ast.Functions != nil {
options["tools"] = ast.Functions
if options["tool_choice"] == nil {
options["tool_choice"] = "auto"
}
}

return options
}

Expand Down
1 change: 1 addition & 0 deletions neo/assistant/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func (ast *Assistant) Map() map[string]interface{} {
"description": ast.Description,
"options": ast.Options,
"prompts": ast.Prompts,
"functions": ast.Functions,
"tags": ast.Tags,
"mentionable": ast.Mentionable,
"automated": ast.Automated,
Expand Down
8 changes: 4 additions & 4 deletions neo/assistant/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []
}

// HookStream Handle streaming response from LLM
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookStream, error) {
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookStream, error) {

// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Stream", context, input, output, c.Writer)
v, err := ast.call(ctx, "Stream", context, input, output, toolcall, c.Writer)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand Down Expand Up @@ -100,12 +100,12 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input
}

// HookDone Handle completion of assistant response
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookDone, error) {
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookDone, error) {
// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Done", context, input, output, c.Writer)
v, err := ast.call(ctx, "Done", context, input, output, toolcall, c.Writer)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand Down
55 changes: 55 additions & 0 deletions neo/assistant/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ func LoadPath(path string) (*Assistant, error) {
}

// load functions
functionsfile := filepath.Join(path, "functions.json")
if has, _ := app.Exists(functionsfile); has {
functions, ts, err := loadFunctions(functionsfile)
if err != nil {
return nil, err
}
data["functions"] = functions
updatedAt = max(updatedAt, ts)
data["updated_at"] = updatedAt
}

// load flow

Expand Down Expand Up @@ -340,6 +350,25 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
assistant.Prompts = prompts
}

// functions
if funcs, has := data["functions"]; has {
switch vv := funcs.(type) {
case []Function:
assistant.Functions = vv
default:
raw, err := jsoniter.Marshal(vv)
if err != nil {
return nil, err
}
var functions []Function
err = jsoniter.Unmarshal(raw, &functions)
if err != nil {
return nil, err
}
assistant.Functions = functions
}
}

// script
if data["script"] != nil {
switch v := data["script"].(type) {
Expand Down Expand Up @@ -382,6 +411,32 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
return assistant, nil
}

func loadFunctions(file string) ([]Function, int64, error) {

app, err := fs.Get("app")
if err != nil {
return nil, 0, err
}

ts, err := app.ModTime(file)
if err != nil {
return nil, 0, err
}

raw, err := app.ReadFile(file)
if err != nil {
return nil, 0, err
}

var functions []Function
err = jsoniter.Unmarshal(raw, &functions)
if err != nil {
return nil, 0, err
}

return functions, ts.UnixNano(), nil
}

func loadPrompts(file string, root string) (string, int64, error) {

app, err := fs.Get("app")
Expand Down
11 changes: 11 additions & 0 deletions neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ type Prompt struct {
Name string `json:"name,omitempty"`
}

// Function a function
type Function struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
} `json:"function"`
}

// QueryParam the assistant query param
type QueryParam struct {
Limit uint `json:"limit"`
Expand All @@ -110,6 +120,7 @@ type Assistant struct {
Automated bool `json:"automated,omitempty"` // Whether this assistant is automated
Options map[string]interface{} `json:"options,omitempty"` // AI Options
Prompts []Prompt `json:"prompts,omitempty"` // AI Prompts
Functions []Function `json:"functions,omitempty"` // Assistant Functions
Flows []map[string]interface{} `json:"flows,omitempty"` // Assistant Flows
Script *v8.Script `json:"-" yaml:"-"` // Assistant Script
CreatedAt int64 `json:"created_at"` // Creation timestamp
Expand Down
Loading

0 comments on commit cad8aa2

Please sign in to comment.