Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added function support and action execution capabilities to Neo API assistant, improving its flexibility and robustness. Key updates include function management, enhanced message handling, and improved error validation. #817

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 130 additions & 18 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/yaoapp/gou/fs"
"github.com/yaoapp/gou/process"
chatctx "github.com/yaoapp/yao/neo/context"
"github.com/yaoapp/yao/neo/message"
chatMessage "github.com/yaoapp/yao/neo/message"
Expand Down Expand Up @@ -69,11 +70,7 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,

// Handle next action
if res.Next != nil {
switch res.Next.Action {
case "exit":
return nil
// Add other actions here if needed
}
return res.Next.Execute(c, ctx)
}

// Update options if provided
Expand All @@ -90,15 +87,87 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
return ast.handleChatStream(c, ctx, messages, options)
}

// Execute the next action
func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error {
switch next.Action {

case "process":
if next.Payload == nil {
return fmt.Errorf("payload is required")
}

name, ok := next.Payload["name"].(string)
if !ok {
return fmt.Errorf("process name should be string")
}

args := []interface{}{}
if v, ok := next.Payload["args"].([]interface{}); ok {
args = v
}

// Add context and writer to args
args = append(args, ctx, c.Writer)
p, err := process.Of(name, args...)
if err != nil {
return fmt.Errorf("get process error: %s", err.Error())
}

err = p.Execute()
if err != nil {
return fmt.Errorf("execute process error: %s", err.Error())
}
defer p.Release()

return nil

case "assistant":
if next.Payload == nil {
return fmt.Errorf("payload is required")
}

// Get assistant id
id, ok := next.Payload["assistant_id"].(string)
if !ok {
return fmt.Errorf("assistant id should be string")
}

// Get assistant
assistant, err := Get(id)
if err != nil {
return fmt.Errorf("get assistant error: %s", err.Error())
}

// Input
input, ok := next.Payload["input"].(string)
if !ok {
return fmt.Errorf("input should be string")
}

// Options
options := map[string]interface{}{}
if v, ok := next.Payload["options"].(map[string]interface{}); ok {
options = v
}
return assistant.Execute(c, ctx, input, options)

case "exit":
return nil

default:
return fmt.Errorf("unknown action: %s", next.Action)
}
}

// handleChatStream manages the streaming chat interaction with the AI
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{}) error {
clientBreak := make(chan bool, 1)
done := make(chan bool, 1)
content := []byte{}
content := message.NewContent("text")

// Chat with AI in background
go func() {
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, &content)
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, content)
if err != nil {
chatMessage.New().Error(err).Done().Write(c.Writer)
}
Expand All @@ -118,8 +187,14 @@ func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, mess
}

// streamChat handles the streaming chat interaction
func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{},
clientBreak chan bool, done chan bool, content *[]byte) error {
func (ast *Assistant) streamChat(
c *gin.Context,
ctx chatctx.Context,
messages []message.Message,
options map[string]interface{},
clientBreak chan bool,
done chan bool,
content *message.Content) error {

return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int {
select {
Expand All @@ -135,7 +210,7 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
// Handle error
if msg.Type == "error" {
value := msg.String()
res, hookErr := ast.HookFail(c, ctx, messages, string(*content), fmt.Errorf("%s", value))
res, hookErr := ast.HookFail(c, ctx, messages, content.String(), fmt.Errorf("%s", value))
if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") {
value = res.Output
if res.Error != "" {
Expand All @@ -146,20 +221,41 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
return 0 // break
}

// Handle tool call
if msg.Type == "tool_calls" {
content.SetType("function") // Set type to function
// Set id
if id, ok := msg.Props["id"].(string); ok && id != "" {
content.SetID(id)
}

// Set name
if name, ok := msg.Props["name"].(string); ok && name != "" {
content.SetName(name)
}
}

// Append content and send message
*content = msg.Append(*content)
value := msg.String()
content.Append(value)
if value != "" {
// Handle stream
res, err := ast.HookStream(c, ctx, messages, string(*content))
res, err := ast.HookStream(c, ctx, messages, content.String(), msg.Type == "tool_calls")
if err == nil && res != nil {
if res.Output != "" {
value = res.Output
}
if res.Next != nil && res.Next.Action == "exit" {

if res.Next != nil {
err = res.Next.Execute(c, ctx)
if err != nil {
chatMessage.New().Error(err.Error()).Done().Write(c.Writer)
}

done <- true
return 0 // break
}

if res.Silent {
return 1 // continue
}
Expand All @@ -180,7 +276,8 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
// }

// Call HookDone
res, hookErr := ast.HookDone(c, ctx, messages, string(*content))
content.SetStatus(message.ContentStatusDone)
res, hookErr := ast.HookDone(c, ctx, messages, content.String(), msg.Type == "tool_calls")
if hookErr == nil && res != nil {
if res.Output != "" {
chatMessage.New().
Expand All @@ -190,10 +287,16 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
}).
Write(c.Writer)
}
if res.Next != nil && res.Next.Action == "exit" {

if res.Next != nil {
err := res.Next.Execute(c, ctx)
if err != nil {
chatMessage.New().Error(err.Error()).Done().Write(c.Writer)
}
done <- true
return 0 // break
}

} else if value != "" {
chatMessage.New().
Map(map[string]interface{}{
Expand All @@ -213,13 +316,13 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [
}

// saveChatHistory saves the chat history if storage is available
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content []byte) {
if len(content) > 0 && ctx.Sid != "" && len(messages) > 0 {
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content *message.Content) {
if len(content.Bytes) > 0 && ctx.Sid != "" && len(messages) > 0 {
storage.SaveHistory(
ctx.Sid,
[]map[string]interface{}{
{"role": "user", "content": messages[len(messages)-1].Content(), "name": ctx.Sid},
{"role": "assistant", "content": string(content), "name": ctx.Sid},
{"role": "assistant", "content": content.String(), "name": ctx.Sid},
},
ctx.ChatID,
nil,
Expand All @@ -237,6 +340,15 @@ func (ast *Assistant) withOptions(options map[string]interface{}) map[string]int
options[key] = value
}
}

// Add functions
if ast.Functions != nil {
options["tools"] = ast.Functions
if options["tool_choice"] == nil {
options["tool_choice"] = "auto"
}
}

return options
}

Expand Down
1 change: 1 addition & 0 deletions neo/assistant/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func (ast *Assistant) Map() map[string]interface{} {
"description": ast.Description,
"options": ast.Options,
"prompts": ast.Prompts,
"functions": ast.Functions,
"tags": ast.Tags,
"mentionable": ast.Mentionable,
"automated": ast.Automated,
Expand Down
8 changes: 4 additions & 4 deletions neo/assistant/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []
}

// HookStream Handle streaming response from LLM
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookStream, error) {
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookStream, error) {

// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Stream", context, input, output, c.Writer)
v, err := ast.call(ctx, "Stream", context, input, output, toolcall, c.Writer)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand Down Expand Up @@ -100,12 +100,12 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input
}

// HookDone Handle completion of assistant response
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookDone, error) {
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookDone, error) {
// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Done", context, input, output, c.Writer)
v, err := ast.call(ctx, "Done", context, input, output, toolcall, c.Writer)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand Down
55 changes: 55 additions & 0 deletions neo/assistant/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ func LoadPath(path string) (*Assistant, error) {
}

// load functions
functionsfile := filepath.Join(path, "functions.json")
if has, _ := app.Exists(functionsfile); has {
functions, ts, err := loadFunctions(functionsfile)
if err != nil {
return nil, err
}
data["functions"] = functions
updatedAt = max(updatedAt, ts)
data["updated_at"] = updatedAt
}

// load flow

Expand Down Expand Up @@ -340,6 +350,25 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
assistant.Prompts = prompts
}

// functions
if funcs, has := data["functions"]; has {
switch vv := funcs.(type) {
case []Function:
assistant.Functions = vv
default:
raw, err := jsoniter.Marshal(vv)
if err != nil {
return nil, err
}
var functions []Function
err = jsoniter.Unmarshal(raw, &functions)
if err != nil {
return nil, err
}
assistant.Functions = functions
}
}

// script
if data["script"] != nil {
switch v := data["script"].(type) {
Expand Down Expand Up @@ -382,6 +411,32 @@ func loadMap(data map[string]interface{}) (*Assistant, error) {
return assistant, nil
}

func loadFunctions(file string) ([]Function, int64, error) {

app, err := fs.Get("app")
if err != nil {
return nil, 0, err
}

ts, err := app.ModTime(file)
if err != nil {
return nil, 0, err
}

raw, err := app.ReadFile(file)
if err != nil {
return nil, 0, err
}

var functions []Function
err = jsoniter.Unmarshal(raw, &functions)
if err != nil {
return nil, 0, err
}

return functions, ts.UnixNano(), nil
}

func loadPrompts(file string, root string) (string, int64, error) {

app, err := fs.Get("app")
Expand Down
11 changes: 11 additions & 0 deletions neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ type Prompt struct {
Name string `json:"name,omitempty"`
}

// Function a function
type Function struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
} `json:"function"`
}

// QueryParam the assistant query param
type QueryParam struct {
Limit uint `json:"limit"`
Expand All @@ -110,6 +120,7 @@ type Assistant struct {
Automated bool `json:"automated,omitempty"` // Whether this assistant is automated
Options map[string]interface{} `json:"options,omitempty"` // AI Options
Prompts []Prompt `json:"prompts,omitempty"` // AI Prompts
Functions []Function `json:"functions,omitempty"` // Assistant Functions
Flows []map[string]interface{} `json:"flows,omitempty"` // Assistant Flows
Script *v8.Script `json:"-" yaml:"-"` // Assistant Script
CreatedAt int64 `json:"created_at"` // Creation timestamp
Expand Down
Loading
Loading