From 5f7b2e866c80fc87848c7706ba61b2a747397ee0 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 6 Jan 2025 11:05:38 +0800 Subject: [PATCH 01/11] Refactor Neo API assistant functionality and enhance model capabilities - Removed obsolete file upload/download functions and related constants from the Assistant struct, streamlining the codebase. - Updated the connector type in load tests to use "gpt-3_5-turbo" for improved compatibility with current models. - Introduced VisionCapableModels to support a range of models with vision capabilities, enhancing AI functionality. - Enhanced the File struct to include additional fields for description, URL, and document IDs, improving data management for file handling. These changes improve the overall structure and maintainability of the Neo API, paving the way for future enhancements in assistant capabilities. --- neo/assistant/api.go | 109 ---------- neo/assistant/attachment.go | 312 ++++++++++++++++++++++++++++ neo/assistant/attachment_test.go | 338 +++++++++++++++++++++++++++++++ neo/assistant/load_test.go | 4 +- neo/assistant/types.go | 38 +++- 5 files changed, 685 insertions(+), 116 deletions(-) create mode 100644 neo/assistant/attachment.go create mode 100644 neo/assistant/attachment_test.go diff --git a/neo/assistant/api.go b/neo/assistant/api.go index 8cad389043..8481ed9d69 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -2,14 +2,9 @@ package assistant import ( "context" - "crypto/sha256" "encoding/base64" "fmt" - "io" - "mime/multipart" - "path/filepath" "strings" - "time" "github.com/yaoapp/gou/fs" chatMessage "github.com/yaoapp/yao/neo/message" @@ -45,22 +40,6 @@ func GetByConnector(connector string, name string) (*Assistant, error) { return assistant, nil } -// AllowedFileTypes the allowed file types -var AllowedFileTypes = map[string]string{ - "application/json": "json", - "application/pdf": "pdf", - "application/msword": "doc", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", - "application/vnd.oasis.opendocument.text": "odt", - "application/vnd.ms-excel": "xls", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx", - "application/vnd.ms-powerpoint": "ppt", - "application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx", -} - -// MaxSize 20M max file size -var MaxSize int64 = 20 * 1024 * 1024 - // Chat implements the chat functionality func (ast *Assistant) Chat(ctx context.Context, messages []map[string]interface{}, option map[string]interface{}, cb func(data []byte) int) error { if ast.openai == nil { @@ -175,94 +154,6 @@ func (ast *Assistant) withAttachments(ctx context.Context, msg *chatMessage.Mess return contents, nil } -// Upload implements file upload functionality -func (ast *Assistant) Upload(ctx context.Context, file *multipart.FileHeader, reader io.Reader, option map[string]interface{}) (*File, error) { - // check file size - if file.Size > MaxSize { - return nil, fmt.Errorf("file size %d exceeds the maximum size of %d", file.Size, MaxSize) - } - - contentType := file.Header.Get("Content-Type") - if !ast.allowed(contentType) { - return nil, fmt.Errorf("file type %s not allowed", contentType) - } - - data, err := fs.Get("data") - if err != nil { - return nil, err - } - - ext := filepath.Ext(file.Filename) - id, err := ast.id(file.Filename, ext) - if err != nil { - return nil, err - } - - filename := id - _, err = data.Write(filename, reader, 0644) - if err != nil { - return nil, err - } - - return &File{ - ID: filename, - Filename: filename, - ContentType: contentType, - Bytes: int(file.Size), - CreatedAt: int(time.Now().Unix()), - }, nil -} - -func (ast *Assistant) allowed(contentType string) bool { - if _, ok := AllowedFileTypes[contentType]; ok { - return true - } - if strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") || - strings.HasPrefix(contentType, "audio/") || strings.HasPrefix(contentType, "video/") { - return true - } - return false -} - -func (ast *Assistant) id(temp string, ext string) (string, error) { - date := time.Now().Format("20060102") - hash := fmt.Sprintf("%x", sha256.Sum256([]byte(temp)))[:8] - return fmt.Sprintf("/__assistants/%s/%s/%s%s", ast.ID, date, hash, ext), nil -} - -// Download implements file download functionality -func (ast *Assistant) Download(ctx context.Context, fileID string) (*FileResponse, error) { - data, err := fs.Get("data") - if err != nil { - return nil, fmt.Errorf("get filesystem error: %s", err.Error()) - } - - exists, err := data.Exists(fileID) - if err != nil { - return nil, fmt.Errorf("check file error: %s", err.Error()) - } - if !exists { - return nil, fmt.Errorf("file %s not found", fileID) - } - - reader, err := data.ReadCloser(fileID) - if err != nil { - return nil, err - } - - ext := filepath.Ext(fileID) - contentType := "application/octet-stream" - if v, err := data.MimeType(fileID); err == nil { - contentType = v - } - - return &FileResponse{ - Reader: reader, - ContentType: contentType, - Extension: ext, - }, nil -} - // ReadBase64 implements base64 file reading functionality func (ast *Assistant) ReadBase64(ctx context.Context, fileID string) (string, error) { data, err := fs.Get("data") diff --git a/neo/assistant/attachment.go b/neo/assistant/attachment.go new file mode 100644 index 0000000000..7cd1f270e1 --- /dev/null +++ b/neo/assistant/attachment.go @@ -0,0 +1,312 @@ +package assistant + +import ( + "bytes" + "context" + "crypto/sha256" + "fmt" + "io" + "mime/multipart" + "path/filepath" + "strings" + "time" + + jsoniter "github.com/json-iterator/go" + "github.com/yaoapp/gou/fs" + "github.com/yaoapp/gou/rag/driver" +) + +// AllowedFileTypes the allowed file types +var AllowedFileTypes = map[string]string{ + "application/json": "json", + "application/pdf": "pdf", + "application/msword": "doc", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx", + "application/vnd.oasis.opendocument.text": "odt", + "application/vnd.ms-excel": "xls", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx", + "application/vnd.ms-powerpoint": "ppt", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx", +} + +// MaxSize 20M max file size +var MaxSize int64 = 20 * 1024 * 1024 + +// Upload implements file upload functionality +func (ast *Assistant) Upload(ctx context.Context, file *multipart.FileHeader, reader io.Reader, option map[string]interface{}) (*File, error) { + // check file size + if file.Size > MaxSize { + return nil, fmt.Errorf("file size %d exceeds the maximum size of %d", file.Size, MaxSize) + } + + contentType := file.Header.Get("Content-Type") + if !ast.allowed(contentType) { + return nil, fmt.Errorf("file type %s not allowed", contentType) + } + + // Get chat ID and session ID from options + chatID := "" + sid := "" + if v, ok := option["chat_id"].(string); ok { + chatID = v + } + if v, ok := option["sid"].(string); ok { + sid = v + } + + // Generate file ID with namespace + fileID, err := ast.generateFileID(file.Filename, sid, chatID) + if err != nil { + return nil, err + } + + // Upload file to storage + data, err := fs.Get("data") + if err != nil { + return nil, err + } + + _, err = data.Write(fileID, reader, 0644) + if err != nil { + return nil, err + } + + // Create file response + fileResp := &File{ + ID: fileID, + Filename: fileID, + ContentType: contentType, + Bytes: int(file.Size), + CreatedAt: int(time.Now().Unix()), + } + + // Handle RAG if available + if err := ast.handleRAG(ctx, fileResp, reader); err != nil { + return nil, fmt.Errorf("RAG handling error: %s", err.Error()) + } + + // Handle Vision if available + if err := ast.handleVision(ctx, fileResp, option); err != nil { + return nil, fmt.Errorf("Vision handling error: %s", err.Error()) + } + + return fileResp, nil +} + +// generateFileID generates a file ID with proper namespace +func (ast *Assistant) generateFileID(filename string, sid string, chatID string) (string, error) { + ext := filepath.Ext(filename) + hash := fmt.Sprintf("%x", sha256.Sum256([]byte(filename)))[:8] + date := time.Now().Format("20060102") + + // Build namespace + namespace := fmt.Sprintf("__assistants/%s", ast.ID) + if sid != "" { + namespace = fmt.Sprintf("%s/%s", namespace, sid) + if chatID != "" { + namespace = fmt.Sprintf("%s/%s", namespace, chatID) + } + } + + return fmt.Sprintf("%s/%s/%s%s", namespace, date, hash, ext), nil +} + +// handleRAG handles the file with RAG if available +func (ast *Assistant) handleRAG(ctx context.Context, file *File, reader io.Reader) error { + if rag == nil { + return nil + } + + // Only handle text-based files + if !strings.HasPrefix(file.ContentType, "text/") { + return nil + } + + // Reset reader to beginning + if seeker, ok := reader.(io.Seeker); ok { + if _, err := seeker.Seek(0, io.SeekStart); err != nil { + return err + } + } + + // Extract sid and chat_id from file path + parts := strings.Split(file.ID, "/") + indexName := fmt.Sprintf("%s%s", rag.Setting.IndexPrefix, ast.ID) // Default: prefix-assistant + + if len(parts) >= 4 { // Has sid + sid := parts[2] + indexName = fmt.Sprintf("%s%s-%s", rag.Setting.IndexPrefix, ast.ID, sid) // prefix-assistant-user + + if len(parts) >= 5 { // Has chat_id + chatID := parts[3] + indexName = fmt.Sprintf("%s%s-%s-%s", rag.Setting.IndexPrefix, ast.ID, sid, chatID) // prefix-assistant-user-chat + } + } + + // Check if index exists + exists, err := rag.Engine.HasIndex(ctx, indexName) + if err != nil { + return fmt.Errorf("check index error: %s", err.Error()) + } + + // Create index if not exists + if !exists { + err = rag.Engine.CreateIndex(ctx, driver.IndexConfig{Name: indexName}) + if err != nil { + return fmt.Errorf("create index error: %s", err.Error()) + } + } + + // Reset reader again after checking index + if seeker, ok := reader.(io.Seeker); ok { + if _, err := seeker.Seek(0, io.SeekStart); err != nil { + return err + } + } + + // Upload and index the file + result, err := rag.Uploader.Upload(ctx, reader, driver.FileUploadOptions{ + Async: false, + ChunkSize: 1024, // Default chunk size + ChunkOverlap: 256, // Default overlap + IndexName: indexName, + }) + + if err != nil { + return fmt.Errorf("upload error: %s", err.Error()) + } + + if len(result.Documents) == 0 { + return fmt.Errorf("no documents indexed") + } + + // Store the document IDs + docIDs := make([]string, len(result.Documents)) + for i, doc := range result.Documents { + docIDs[i] = doc.DocID + } + file.DocIDs = docIDs + + return nil +} + +// handleVision handles the file with Vision if available +func (ast *Assistant) handleVision(ctx context.Context, file *File, option map[string]interface{}) error { + if vision == nil { + return nil + } + + // Check if file is an image + if !strings.HasPrefix(file.ContentType, "image/") { + return nil + } + + // Get model from options + model := "" + if v, ok := option["model"].(string); ok { + model = v + } + + // Reset reader for vision service + data, err := fs.Get("data") + if err != nil { + return fmt.Errorf("get filesystem error: %s", err.Error()) + } + + exists, err := data.Exists(file.ID) + if err != nil { + return fmt.Errorf("check file error: %s", err.Error()) + } + if !exists { + return fmt.Errorf("file %s not found", file.ID) + } + + // Read file content into memory + imgData, err := data.ReadFile(file.ID) + if err != nil { + return fmt.Errorf("read file error: %s", err.Error()) + } + + if VisionCapableModels[model] { + // For vision-capable models, upload to vision service to get URL + resp, err := vision.Upload(ctx, file.Filename, bytes.NewReader(imgData), file.ContentType) + if err != nil { + return fmt.Errorf("vision upload error: %s", err.Error()) + } + file.URL = resp.URL // Store the URL for vision-capable models to use + } else { + // For non-vision models, get image description + prompt := "Describe this image in detail." + if v, ok := option["vision_prompt"].(string); ok { + prompt = v + } + + // Upload to vision service first Compress image + resp, err := vision.Upload(ctx, file.Filename, bytes.NewReader(imgData), file.ContentType) + if err != nil { + return fmt.Errorf("vision upload error: %s", err.Error()) + } + + // Analyze using base64 data + result, err := vision.Analyze(ctx, resp.FileID, prompt) + if err != nil { + return fmt.Errorf("vision analyze error: %s", err.Error()) + } + + // Extract description text from response + if desc, ok := result.Description["text"].(string); ok { + file.Description = desc + } else { + // Convert the entire description to JSON string as fallback + bytes, err := jsoniter.Marshal(result.Description) + if err == nil { + file.Description = string(bytes) + } + } + } + return nil +} + +// Download implements file download functionality +func (ast *Assistant) Download(ctx context.Context, fileID string) (*FileResponse, error) { + data, err := fs.Get("data") + if err != nil { + return nil, fmt.Errorf("get filesystem error: %s", err.Error()) + } + + exists, err := data.Exists(fileID) + if err != nil { + return nil, fmt.Errorf("check file error: %s", err.Error()) + } + if !exists { + return nil, fmt.Errorf("file %s not found", fileID) + } + + reader, err := data.ReadCloser(fileID) + if err != nil { + return nil, err + } + + ext := filepath.Ext(fileID) + contentType := "application/octet-stream" + if v, err := data.MimeType(fileID); err == nil { + contentType = v + } + + return &FileResponse{ + Reader: reader, + ContentType: contentType, + Extension: ext, + }, nil +} + +func (ast *Assistant) allowed(contentType string) bool { + if _, ok := AllowedFileTypes[contentType]; ok { + return true + } + if strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") || + strings.HasPrefix(contentType, "audio/") || strings.HasPrefix(contentType, "video/") { + return true + } + return false +} diff --git a/neo/assistant/attachment_test.go b/neo/assistant/attachment_test.go new file mode 100644 index 0000000000..54a60e9e0c --- /dev/null +++ b/neo/assistant/attachment_test.go @@ -0,0 +1,338 @@ +package assistant + +import ( + "bytes" + "context" + "encoding/base64" + "io" + "mime/multipart" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/yaoapp/gou/fs" + gourag "github.com/yaoapp/gou/rag" + "github.com/yaoapp/gou/rag/driver" + "github.com/yaoapp/yao/config" + neovision "github.com/yaoapp/yao/neo/vision" + vdriver "github.com/yaoapp/yao/neo/vision/driver" + "github.com/yaoapp/yao/test" +) + +var ( + // 1x1 transparent PNG for testing + testImageBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" +) + +func TestUpload(t *testing.T) { + test.Prepare(t, config.Conf) + defer test.Clean() + + ast := setupTestAssistant(t) + ctx := context.Background() + + t.Run("Basic File Upload", func(t *testing.T) { + content := []byte("test content") + file := &multipart.FileHeader{ + Filename: "test.txt", + Size: int64(len(content)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "text/plain") + + reader := bytes.NewReader(content) + fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ + "sid": "test-user", + "chat_id": "test-chat", + }) + + assert.NoError(t, err) + assert.NotNil(t, fileResp) + assert.Contains(t, fileResp.ID, "test-assistant/test-user/test-chat") + assert.Equal(t, len(content), fileResp.Bytes) + assert.Equal(t, "text/plain", fileResp.ContentType) + }) + + t.Run("File Size Limit", func(t *testing.T) { + content := make([]byte, MaxSize+1) + file := &multipart.FileHeader{ + Filename: "large.txt", + Size: int64(len(content)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "text/plain") + + reader := bytes.NewReader(content) + _, err := ast.Upload(ctx, file, reader, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds the maximum size") + }) + + t.Run("Invalid Content Type", func(t *testing.T) { + content := []byte("test") + file := &multipart.FileHeader{ + Filename: "test.invalid", + Size: int64(len(content)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "invalid/type") + + reader := bytes.NewReader(content) + _, err := ast.Upload(ctx, file, reader, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not allowed") + }) +} + +func TestUploadWithRAG(t *testing.T) { + test.Prepare(t, config.Conf) + defer test.Clean() + + ast := setupTestAssistant(t) + ragEngine, ragUploader, ragVectorizer := setupTestRAG(t) + SetRAG(ragEngine, ragUploader, ragVectorizer, RAGSetting{IndexPrefix: "test_"}) + defer func() { + rag = nil // Completely reset the global rag variable + }() + ctx := context.Background() + + t.Run("Text File with RAG", func(t *testing.T) { + content := []byte("This is a test document for RAG indexing") + file := &multipart.FileHeader{ + Filename: "test.txt", + Size: int64(len(content)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "text/plain") + + reader := bytes.NewReader(content) + fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ + "sid": "test-user", + "chat_id": "test-chat", + }) + + assert.NoError(t, err) + assert.NotNil(t, fileResp) + assert.NotEmpty(t, fileResp.DocIDs, "Document IDs should not be empty") + + // Wait for indexing to complete + time.Sleep(500 * time.Millisecond) + + // Verify the file was indexed by checking if it exists in RAG + exists, err := ragEngine.HasDocument(ctx, "test_test-assistant-test-user-test-chat", fileResp.DocIDs[0]) + assert.NoError(t, err) + assert.True(t, exists, "Document should exist in RAG index") + }) + + t.Run("Non-Text File with RAG", func(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) + file := &multipart.FileHeader{ + Filename: "test.png", + Size: int64(len(imgData)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "image/png") + + reader := bytes.NewReader(imgData) + fileResp, err := ast.Upload(ctx, file, reader, nil) + assert.NoError(t, err) + assert.NotNil(t, fileResp) + // Verify the file was not indexed + exists, err := ragEngine.HasDocument(ctx, "test_test-assistant", fileResp.ID) + assert.NoError(t, err) + assert.False(t, exists) + }) +} + +func setupTestAssistant(t *testing.T) *Assistant { + ast := &Assistant{ + ID: "test-assistant", + Name: "Test Assistant", + Connector: "test-connector", + } + return ast +} + +func setupTestRAG(t *testing.T) (driver.Engine, driver.FileUpload, driver.Vectorizer) { + // Get test config + openaiKey := os.Getenv("OPENAI_API_KEY") + if openaiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + + vectorizeConfig := driver.VectorizeConfig{ + Model: os.Getenv("VECTORIZER_MODEL"), + Options: map[string]string{ + "api_key": openaiKey, + }, + } + + // Qdrant config + host := os.Getenv("QDRANT_HOST") + if host == "" { + host = "localhost" + } + + port := os.Getenv("QDRANT_PORT") + if port == "" { + port = "6334" + } + + // Create vectorizer + vectorizer, err := gourag.NewVectorizer(gourag.DriverOpenAI, vectorizeConfig) + if err != nil { + t.Fatal(err) + } + + // Create engine + engine, err := gourag.NewEngine(gourag.DriverQdrant, driver.IndexConfig{ + Options: map[string]string{ + "host": host, + "port": port, + "api_key": "", + }, + }, vectorizer) + if err != nil { + t.Fatal(err) + } + + // Create file upload + fileUpload, err := gourag.NewFileUpload(gourag.DriverQdrant, engine, vectorizer) + if err != nil { + t.Fatal(err) + } + + return engine, fileUpload, vectorizer +} + +func setupTestVision(t *testing.T) *neovision.Vision { + // Create test data directory + data, err := fs.Get("data") + assert.NoError(t, err) + + // Write test image data + imgData, err := base64.StdEncoding.DecodeString(testImageBase64) + assert.NoError(t, err) + _, err = data.WriteFile("/test.png", imgData, 0644) + assert.NoError(t, err) + + cfg := &vdriver.Config{ + Storage: vdriver.StorageConfig{ + Driver: "local", + Options: map[string]interface{}{ + "path": "/__vision_test", + "compression": true, + }, + }, + Model: vdriver.ModelConfig{ + Driver: "openai", + Options: map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + }, + }, + } + + v, err := neovision.New(cfg) + if err != nil { + t.Fatal(err) + } + return v +} + +func TestUploadWithVision(t *testing.T) { + test.Prepare(t, config.Conf) + defer test.Clean() + + ast := setupTestAssistant(t) + vision := setupTestVision(t) + SetVision(vision) + defer func() { + vision = nil // Completely reset the global vision variable + }() + ctx := context.Background() + + t.Run("Image with Vision-Capable Model", func(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) + file := &multipart.FileHeader{ + Filename: "test.png", + Size: int64(len(imgData)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "image/png") + + reader := bytes.NewReader(imgData) + fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ + "model": "gpt-4-vision-preview", + }) + + assert.NoError(t, err) + assert.NotNil(t, fileResp) + assert.NotEmpty(t, fileResp.URL) + assert.Empty(t, fileResp.Description) + }) + + t.Run("Image with Non-Vision Model", func(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) + file := &multipart.FileHeader{ + Filename: "test.png", + Size: int64(len(imgData)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "image/png") + + reader := bytes.NewReader(imgData) + fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ + "model": "gpt-4", + "vision_prompt": "What's in this image?", + }) + + assert.NoError(t, err) + assert.NotNil(t, fileResp) + assert.Empty(t, fileResp.URL) + assert.NotEmpty(t, fileResp.Description) + }) +} + +func TestDownload(t *testing.T) { + test.Prepare(t, config.Conf) + defer test.Clean() + + ast := setupTestAssistant(t) + ctx := context.Background() + + t.Run("Download Existing File", func(t *testing.T) { + // First upload a file + content := []byte("test content") + file := &multipart.FileHeader{ + Filename: "test.txt", + Size: int64(len(content)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "text/plain") + + reader := bytes.NewReader(content) + fileResp, err := ast.Upload(ctx, file, reader, nil) + assert.NoError(t, err) + + // Then download it + downloadResp, err := ast.Download(ctx, fileResp.ID) + assert.NoError(t, err) + assert.NotNil(t, downloadResp) + assert.True(t, strings.HasPrefix(downloadResp.ContentType, "text/plain"), "Content-Type should start with text/plain") + assert.Equal(t, ".txt", downloadResp.Extension) + + // Verify content + downloaded, err := io.ReadAll(downloadResp.Reader) + assert.NoError(t, err) + assert.Equal(t, content, downloaded) + }) + + t.Run("Download Non-Existent File", func(t *testing.T) { + _, err := ast.Download(ctx, "non-existent-file") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} diff --git a/neo/assistant/load_test.go b/neo/assistant/load_test.go index 5c6d21bab2..dd5088c8ba 100644 --- a/neo/assistant/load_test.go +++ b/neo/assistant/load_test.go @@ -53,7 +53,7 @@ func TestLoad_LoadStore(t *testing.T) { "assistant_id": "test-id", "name": "Test Assistant", "avatar": "test-avatar", - "connector": "test-connector", + "connector": "gpt-3_5-turbo", }, }, } @@ -67,7 +67,7 @@ func TestLoad_LoadStore(t *testing.T) { assert.Equal(t, "test-id", assistant.ID) assert.Equal(t, "Test Assistant", assistant.Name) assert.Equal(t, "test-avatar", assistant.Avatar) - assert.Equal(t, "test-connector", assistant.Connector) + assert.Equal(t, "gpt-3_5-turbo", assistant.Connector) // Test cache functionality assistant2, err := LoadStore("test-id") diff --git a/neo/assistant/types.go b/neo/assistant/types.go index 480f8146e0..2002ae24a0 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -70,13 +70,41 @@ type Assistant struct { openai *api.OpenAI // OpenAI API } +// VisionCapableModels list of LLM models that support vision capabilities +var VisionCapableModels = map[string]bool{ + // OpenAI Models + "gpt-4-vision-preview": true, + "gpt-4v": true, // Alias for gpt-4-vision-preview + + // Anthropic Models + "claude-3-opus": true, // Most capable Claude model + "claude-3-sonnet": true, // Balanced Claude model + "claude-3-haiku": true, // Fast and efficient Claude model + + // Google Models + "gemini-pro-vision": true, + + // Open Source Models + "llava-13b": true, + "cogvlm": true, + "qwen-vl": true, + "yi-vl": true, + + // Custom Models + "gpt-4o": true, // Custom OpenAI compatible model + "gpt-4o-mini": true, // Custom OpenAI compatible model - mini version +} + // File the file type File struct { - ID string `json:"file_id"` - Bytes int `json:"bytes"` - CreatedAt int `json:"created_at"` - Filename string `json:"filename"` - ContentType string `json:"content_type"` + ID string `json:"file_id"` + Bytes int `json:"bytes"` + CreatedAt int `json:"created_at"` + Filename string `json:"filename"` + ContentType string `json:"content_type"` + Description string `json:"description,omitempty"` // Vision analysis result or other description + URL string `json:"url,omitempty"` // Vision URL for vision-capable models + DocIDs []string `json:"doc_ids,omitempty"` // RAG document IDs } // FileResponse represents a file download response From d4d824a22c313c9776260b9e5693174040de5ba3 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 6 Jan 2025 11:27:44 +0800 Subject: [PATCH 02/11] Refactor Neo API upload functionality and enhance environment variable handling - Updated the Upload method in the DSL to default to the assistant in context, improving error handling for missing assistant IDs. - Introduced parseEnvValue and convertOptions functions in the Vision module to support parsing environment variables in configuration options, enhancing flexibility. - Refactored storage and model initialization in the Vision service to utilize converted options, ensuring environment variables are correctly applied. These changes improve the robustness and maintainability of the Neo API, paving the way for better configuration management and error handling in assistant functionalities. --- neo/neo.go | 19 ++++++++++--------- neo/vision/vision.go | 39 ++++++++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/neo/neo.go b/neo/neo.go index b114b0d8eb..f9d56b355d 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -219,15 +219,16 @@ func (neo *DSL) Upload(ctx Context, c *gin.Context) (*assistant.File, error) { Option: option, } - res, err := neo.HookCreate(ctx, []map[string]interface{}{}, c) - if err != nil { - return nil, err - } - - // Select Assistant - ast, err := neo.Select(res.AssistantID) - if err != nil { - return nil, err + // Default use the assistant in context + ast := neo.Assistant + if ctx.ChatID == "" { + if ctx.AssistantID == "" { + return nil, fmt.Errorf("assistant_id is required") + } + ast, err = neo.Select(ctx.AssistantID) + if err != nil { + return nil, err + } } return ast.Upload(ctx, tmpfile, reader, option) diff --git a/neo/vision/vision.go b/neo/vision/vision.go index 79c5f3721f..0886a90017 100644 --- a/neo/vision/vision.go +++ b/neo/vision/vision.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "os" "strings" "time" @@ -13,6 +14,30 @@ import ( "github.com/yaoapp/yao/neo/vision/driver/s3" ) +// parseEnvValue parse environment variable if the value starts with $ENV. +func parseEnvValue(value string) string { + if strings.HasPrefix(value, "$ENV.") { + envKey := strings.TrimPrefix(value, "$ENV.") + if envVal := os.Getenv(envKey); envVal != "" { + return envVal + } + } + return value +} + +// convertOptions convert interface{} options map to string map and parse environment variables +func convertOptions(options map[string]interface{}) map[string]interface{} { + converted := make(map[string]interface{}) + for k, v := range options { + if str, ok := v.(string); ok { + converted[k] = parseEnvValue(str) + } else { + converted[k] = v + } + } + return converted +} + // Vision the vision service type Vision struct { storage driver.Storage @@ -22,20 +47,24 @@ type Vision struct { // New create a new vision service func New(cfg *driver.Config) (*Vision, error) { + // Parse environment variables in options + storageOptions := convertOptions(cfg.Storage.Options) + modelOptions := convertOptions(cfg.Model.Options) + // Create storage driver var storage driver.Storage var err error switch cfg.Storage.Driver { case "local": - storage, err = local.New(cfg.Storage.Options) + storage, err = local.New(storageOptions) case "s3": // Convert expiration string to duration if present - if exp, ok := cfg.Storage.Options["expiration"].(string); ok { + if exp, ok := storageOptions["expiration"].(string); ok { if duration, err := time.ParseDuration(exp); err == nil { - cfg.Storage.Options["expiration"] = duration + storageOptions["expiration"] = duration } } - storage, err = s3.New(cfg.Storage.Options) + storage, err = s3.New(storageOptions) default: return nil, fmt.Errorf("storage driver %s not supported", cfg.Storage.Driver) } @@ -47,7 +76,7 @@ func New(cfg *driver.Config) (*Vision, error) { var model driver.Model switch cfg.Model.Driver { case "openai": - model, err = openai.New(cfg.Model.Options) + model, err = openai.New(modelOptions) default: return nil, fmt.Errorf("model driver %s not supported", cfg.Model.Driver) } From 2becefd06e99a01f527c5ae298c56bd2b6854a4a Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 6 Jan 2025 11:43:31 +0800 Subject: [PATCH 03/11] Enhance Vision module with flexible prompt handling and comprehensive tests - Refactored the Analyze method in the Vision and OpenAI model to accept a variadic prompt parameter, allowing for optional custom prompts while defaulting to a predefined prompt if none is provided. - Added multiple test cases in vision_test.go and model_test.go to validate image analysis with default, custom, and empty prompts, ensuring robust functionality and error handling. - Updated the Model interface to reflect the new prompt handling, improving clarity and usability. These changes enhance the flexibility of the Vision module, paving the way for improved user experience and functionality in image analysis. --- neo/vision/driver/openai/model.go | 10 ++- neo/vision/driver/openai/model_test.go | 45 +++++++++++++ neo/vision/driver/types.go | 4 +- neo/vision/vision.go | 4 +- neo/vision/vision_test.go | 90 ++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 5 deletions(-) diff --git a/neo/vision/driver/openai/model.go b/neo/vision/driver/openai/model.go index a766c3804b..8a55abd180 100644 --- a/neo/vision/driver/openai/model.go +++ b/neo/vision/driver/openai/model.go @@ -52,11 +52,17 @@ func New(options map[string]interface{}) (*Model, error) { } // Analyze analyze image using OpenAI vision model -func (model *Model) Analyze(ctx context.Context, fileID string, prompt string) (map[string]interface{}, error) { +func (model *Model) Analyze(ctx context.Context, fileID string, prompt ...string) (map[string]interface{}, error) { if model.APIKey == "" { return nil, fmt.Errorf("api_key is required") } + // Use default prompt if none provided + userPrompt := model.Prompt + if len(prompt) > 0 && prompt[0] != "" { + userPrompt = prompt[0] + } + // Check if fileID is a URL or base64 data var imageURL string if strings.HasPrefix(fileID, "data:image/") { @@ -103,7 +109,7 @@ func (model *Model) Analyze(ctx context.Context, fileID string, prompt string) ( "content": []map[string]interface{}{ { "type": "text", - "text": prompt, + "text": userPrompt, }, { "type": "image_url", diff --git a/neo/vision/driver/openai/model_test.go b/neo/vision/driver/openai/model_test.go index e93d516e94..332d013f12 100644 --- a/neo/vision/driver/openai/model_test.go +++ b/neo/vision/driver/openai/model_test.go @@ -146,4 +146,49 @@ func TestOpenAIModel(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "OpenAI API error") }) + + t.Run("Analyze with Default Prompt", func(t *testing.T) { + model, err := New(map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + "prompt": "Default test prompt", + }) + assert.NoError(t, err) + + // Use base64 image data without providing a prompt + result, err := model.Analyze(context.Background(), "data:image/png;base64,"+testImageBase64) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result["description"]) + }) + + t.Run("Analyze with Custom Prompt Overriding Default", func(t *testing.T) { + model, err := New(map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + "prompt": "Default test prompt", + }) + assert.NoError(t, err) + + // Use base64 image data with custom prompt + result, err := model.Analyze(context.Background(), "data:image/png;base64,"+testImageBase64, "Custom test prompt") + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result["description"]) + }) + + t.Run("Analyze with Empty Custom Prompt", func(t *testing.T) { + model, err := New(map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + "prompt": "Default test prompt", + }) + assert.NoError(t, err) + + // Use base64 image data with empty prompt (should use default) + result, err := model.Analyze(context.Background(), "data:image/png;base64,"+testImageBase64, "") + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result["description"]) + }) } diff --git a/neo/vision/driver/types.go b/neo/vision/driver/types.go index ccc64ee6e7..9cc2006e92 100644 --- a/neo/vision/driver/types.go +++ b/neo/vision/driver/types.go @@ -32,7 +32,9 @@ type Storage interface { // Model the vision model interface type Model interface { - Analyze(ctx context.Context, fileID string, prompt string) (map[string]interface{}, error) + // Analyze analyzes an image file + // If prompt is empty, it will use the default prompt from model.options.prompt + Analyze(ctx context.Context, fileID string, prompt ...string) (map[string]interface{}, error) } // Response the vision response diff --git a/neo/vision/vision.go b/neo/vision/vision.go index 0886a90017..792be41a5a 100644 --- a/neo/vision/vision.go +++ b/neo/vision/vision.go @@ -104,7 +104,7 @@ func (v *Vision) Upload(ctx context.Context, filename string, reader io.Reader, } // Analyze analyze image using vision model -func (v *Vision) Analyze(ctx context.Context, fileID string, prompt string) (*driver.Response, error) { +func (v *Vision) Analyze(ctx context.Context, fileID string, prompt ...string) (*driver.Response, error) { if v.model == nil { return nil, fmt.Errorf("model is required") } @@ -121,7 +121,7 @@ func (v *Vision) Analyze(ctx context.Context, fileID string, prompt string) (*dr } } - result, err := v.model.Analyze(ctx, url, prompt) + result, err := v.model.Analyze(ctx, url, prompt...) if err != nil { return nil, err } diff --git a/neo/vision/vision_test.go b/neo/vision/vision_test.go index a0bd42ae0d..411a66a792 100644 --- a/neo/vision/vision_test.go +++ b/neo/vision/vision_test.go @@ -331,6 +331,96 @@ func TestVision(t *testing.T) { assert.LessOrEqual(t, bounds.Dx(), MaxImageSize) assert.LessOrEqual(t, bounds.Dy(), MaxImageSize) }) + + t.Run("Analyze Image with Default Prompt", func(t *testing.T) { + // Create vision service with default prompt + cfg := &driver.Config{ + Storage: driver.StorageConfig{ + Driver: "local", + Options: map[string]interface{}{ + "path": "/__vision_test", + "compression": true, + }, + }, + Model: driver.ModelConfig{ + Driver: "openai", + Options: map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + "prompt": "Default test prompt", + }, + }, + } + + vision, err := New(cfg) + assert.NoError(t, err) + + // Use base64 data without providing a prompt + result, err := vision.Analyze(context.Background(), "data:image/png;base64,"+testImageBase64) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result.Description) + }) + + t.Run("Analyze Image with Custom Prompt", func(t *testing.T) { + // Create vision service with default prompt + cfg := &driver.Config{ + Storage: driver.StorageConfig{ + Driver: "local", + Options: map[string]interface{}{ + "path": "/__vision_test", + "compression": true, + }, + }, + Model: driver.ModelConfig{ + Driver: "openai", + Options: map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + "prompt": "Default test prompt", + }, + }, + } + + vision, err := New(cfg) + assert.NoError(t, err) + + // Use base64 data with custom prompt + result, err := vision.Analyze(context.Background(), "data:image/png;base64,"+testImageBase64, "Custom test prompt") + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result.Description) + }) + + t.Run("Analyze Image with Empty Custom Prompt", func(t *testing.T) { + // Create vision service with default prompt + cfg := &driver.Config{ + Storage: driver.StorageConfig{ + Driver: "local", + Options: map[string]interface{}{ + "path": "/__vision_test", + "compression": true, + }, + }, + Model: driver.ModelConfig{ + Driver: "openai", + Options: map[string]interface{}{ + "api_key": os.Getenv("OPENAI_API_KEY"), + "model": os.Getenv("VISION_MODEL"), + "prompt": "Default test prompt", + }, + }, + } + + vision, err := New(cfg) + assert.NoError(t, err) + + // Use base64 data with empty prompt (should use default) + result, err := vision.Analyze(context.Background(), "data:image/png;base64,"+testImageBase64, "") + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result.Description) + }) } func createTestVision(baseURL string) (*Vision, error) { From ca78c40293667ce9bc07708812d6dfff8bb7effb Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 6 Jan 2025 14:04:12 +0800 Subject: [PATCH 04/11] Refactor attachment handling and enhance RAG/Vision integration in tests - Updated the TestUpload and TestUploadWithRAG functions to remove unnecessary parameters and improve clarity. - Enhanced RAG handling by adding options for enabling/disabling RAG processing during file uploads. - Improved test cases for RAG functionality, ensuring proper indexing behavior based on RAG settings. - Refactored vision-related tests to include options for enabling/disabling vision processing, ensuring accurate responses based on the model capabilities. - Streamlined the setup of test assistants to simplify the testing process. These changes improve the maintainability and clarity of the attachment handling code, paving the way for better functionality in file uploads and processing. --- neo/assistant/attachment.go | 14 +++++- neo/assistant/attachment_test.go | 83 +++++++++++++++++++++----------- 2 files changed, 66 insertions(+), 31 deletions(-) diff --git a/neo/assistant/attachment.go b/neo/assistant/attachment.go index 7cd1f270e1..61defd8846 100644 --- a/neo/assistant/attachment.go +++ b/neo/assistant/attachment.go @@ -81,7 +81,7 @@ func (ast *Assistant) Upload(ctx context.Context, file *multipart.FileHeader, re } // Handle RAG if available - if err := ast.handleRAG(ctx, fileResp, reader); err != nil { + if err := ast.handleRAG(ctx, fileResp, reader, option); err != nil { return nil, fmt.Errorf("RAG handling error: %s", err.Error()) } @@ -112,11 +112,16 @@ func (ast *Assistant) generateFileID(filename string, sid string, chatID string) } // handleRAG handles the file with RAG if available -func (ast *Assistant) handleRAG(ctx context.Context, file *File, reader io.Reader) error { +func (ast *Assistant) handleRAG(ctx context.Context, file *File, reader io.Reader, option map[string]interface{}) error { if rag == nil { return nil } + // Check if RAG processing is enabled + if option, ok := option["rag"].(bool); !ok || !option { + return nil + } + // Only handle text-based files if !strings.HasPrefix(file.ContentType, "text/") { return nil @@ -196,6 +201,11 @@ func (ast *Assistant) handleVision(ctx context.Context, file *File, option map[s return nil } + // Check if Vision processing is enabled + if option, ok := option["vision"].(bool); !ok || !option { + return nil + } + // Check if file is an image if !strings.HasPrefix(file.ContentType, "image/") { return nil diff --git a/neo/assistant/attachment_test.go b/neo/assistant/attachment_test.go index 54a60e9e0c..17a7d3b8f0 100644 --- a/neo/assistant/attachment_test.go +++ b/neo/assistant/attachment_test.go @@ -30,7 +30,7 @@ func TestUpload(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() - ast := setupTestAssistant(t) + ast := setupTestAssistant() ctx := context.Background() t.Run("Basic File Upload", func(t *testing.T) { @@ -90,15 +90,15 @@ func TestUploadWithRAG(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() - ast := setupTestAssistant(t) + ast := setupTestAssistant() ragEngine, ragUploader, ragVectorizer := setupTestRAG(t) SetRAG(ragEngine, ragUploader, ragVectorizer, RAGSetting{IndexPrefix: "test_"}) defer func() { - rag = nil // Completely reset the global rag variable + rag = nil }() ctx := context.Background() - t.Run("Text File with RAG", func(t *testing.T) { + t.Run("Text File with RAG Enabled", func(t *testing.T) { content := []byte("This is a test document for RAG indexing") file := &multipart.FileHeader{ Filename: "test.txt", @@ -111,6 +111,7 @@ func TestUploadWithRAG(t *testing.T) { fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ "sid": "test-user", "chat_id": "test-chat", + "rag": true, }) assert.NoError(t, err) @@ -120,33 +121,35 @@ func TestUploadWithRAG(t *testing.T) { // Wait for indexing to complete time.Sleep(500 * time.Millisecond) - // Verify the file was indexed by checking if it exists in RAG + // Verify the file was indexed exists, err := ragEngine.HasDocument(ctx, "test_test-assistant-test-user-test-chat", fileResp.DocIDs[0]) assert.NoError(t, err) assert.True(t, exists, "Document should exist in RAG index") }) - t.Run("Non-Text File with RAG", func(t *testing.T) { - imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) + t.Run("Text File with RAG Disabled", func(t *testing.T) { + content := []byte("This is a test document with RAG disabled") file := &multipart.FileHeader{ - Filename: "test.png", - Size: int64(len(imgData)), + Filename: "test.txt", + Size: int64(len(content)), } file.Header = make(map[string][]string) - file.Header.Set("Content-Type", "image/png") + file.Header.Set("Content-Type", "text/plain") + + reader := bytes.NewReader(content) + fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ + "sid": "test-user", + "chat_id": "test-chat", + "rag": false, + }) - reader := bytes.NewReader(imgData) - fileResp, err := ast.Upload(ctx, file, reader, nil) assert.NoError(t, err) assert.NotNil(t, fileResp) - // Verify the file was not indexed - exists, err := ragEngine.HasDocument(ctx, "test_test-assistant", fileResp.ID) - assert.NoError(t, err) - assert.False(t, exists) + assert.Empty(t, fileResp.DocIDs, "Document IDs should be empty when RAG is disabled") }) } -func setupTestAssistant(t *testing.T) *Assistant { +func setupTestAssistant() *Assistant { ast := &Assistant{ ID: "test-assistant", Name: "Test Assistant", @@ -246,15 +249,37 @@ func TestUploadWithVision(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() - ast := setupTestAssistant(t) + ast := setupTestAssistant() vision := setupTestVision(t) SetVision(vision) defer func() { - vision = nil // Completely reset the global vision variable + vision = nil }() ctx := context.Background() - t.Run("Image with Vision-Capable Model", func(t *testing.T) { + t.Run("Image File with Vision Enabled", func(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) + file := &multipart.FileHeader{ + Filename: "test.png", + Size: int64(len(imgData)), + } + file.Header = make(map[string][]string) + file.Header.Set("Content-Type", "image/png") + + reader := bytes.NewReader(imgData) + fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ + "vision": true, + "model": "gpt-4-vision-preview", + }) + + assert.NoError(t, err) + assert.NotNil(t, fileResp) + if fileResp.URL == "" && fileResp.Description == "" { + t.Error("Either URL or Description should be set when vision is enabled") + } + }) + + t.Run("Image File with Vision Disabled", func(t *testing.T) { imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) file := &multipart.FileHeader{ Filename: "test.png", @@ -265,16 +290,16 @@ func TestUploadWithVision(t *testing.T) { reader := bytes.NewReader(imgData) fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ - "model": "gpt-4-vision-preview", + "vision": false, }) assert.NoError(t, err) assert.NotNil(t, fileResp) - assert.NotEmpty(t, fileResp.URL) - assert.Empty(t, fileResp.Description) + assert.Empty(t, fileResp.URL, "Vision URL should be empty when vision is disabled") + assert.Empty(t, fileResp.Description, "Vision Description should be empty when vision is disabled") }) - t.Run("Image with Non-Vision Model", func(t *testing.T) { + t.Run("Image File with Non-Vision Model", func(t *testing.T) { imgData, _ := base64.StdEncoding.DecodeString(testImageBase64) file := &multipart.FileHeader{ Filename: "test.png", @@ -285,14 +310,14 @@ func TestUploadWithVision(t *testing.T) { reader := bytes.NewReader(imgData) fileResp, err := ast.Upload(ctx, file, reader, map[string]interface{}{ - "model": "gpt-4", - "vision_prompt": "What's in this image?", + "vision": true, + "model": "gpt-4", }) assert.NoError(t, err) assert.NotNil(t, fileResp) - assert.Empty(t, fileResp.URL) - assert.NotEmpty(t, fileResp.Description) + assert.Empty(t, fileResp.URL, "Vision URL should be empty for non-vision models") + assert.NotEmpty(t, fileResp.Description, "Vision Description should be set for non-vision models") }) } @@ -300,7 +325,7 @@ func TestDownload(t *testing.T) { test.Prepare(t, config.Conf) defer test.Clean() - ast := setupTestAssistant(t) + ast := setupTestAssistant() ctx := context.Background() t.Run("Download Existing File", func(t *testing.T) { From 90967f58c14b71aa96154799970ca37f5cc2d5a3 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 6 Jan 2025 14:20:39 +0800 Subject: [PATCH 05/11] Refactor Neo API context handling and improve context management - Replaced the custom context creation functions with a new context package, enhancing consistency and maintainability across the API. - Updated multiple methods in the DSL to utilize the new context package, ensuring a unified approach to context management. - Removed the obsolete context.go file, streamlining the codebase and reducing redundancy. These changes improve the overall structure and clarity of the Neo API, paving the way for future enhancements in context handling and assistant functionalities. --- neo/api.go | 15 ++++---- neo/assistant/hooks.go | 1 + neo/context.go | 49 ------------------------- neo/context/context.go | 82 ++++++++++++++++++++++++++++++++++++++++++ neo/hooks.go | 7 ++-- neo/neo.go | 27 +++++++------- neo/types.go | 18 ---------- 7 files changed, 109 insertions(+), 90 deletions(-) create mode 100644 neo/assistant/hooks.go delete mode 100644 neo/context.go create mode 100644 neo/context/context.go diff --git a/neo/api.go b/neo/api.go index aebe7bad3e..a67ad1514e 100644 --- a/neo/api.go +++ b/neo/api.go @@ -15,6 +15,7 @@ import ( "github.com/yaoapp/gou/connector" "github.com/yaoapp/gou/process" "github.com/yaoapp/yao/helper" + chatctx "github.com/yaoapp/yao/neo/context" "github.com/yaoapp/yao/neo/message" "github.com/yaoapp/yao/neo/store" ) @@ -173,7 +174,7 @@ func (neo *DSL) handleUpload(c *gin.Context) { } // Set the context - ctx, cancel := NewContextWithCancel(sid, c.Query("chat_id"), "") + ctx, cancel := chatctx.NewWithCancel(sid, c.Query("chat_id"), "") defer cancel() // Upload the file @@ -214,7 +215,7 @@ func (neo *DSL) handleChat(c *gin.Context) { } // Set the context with validated chat_id - ctx, cancel := NewContextWithCancel(sid, chatID, c.Query("context")) + ctx, cancel := chatctx.NewWithCancel(sid, chatID, c.Query("context")) defer cancel() neo.Answer(ctx, content, c) @@ -297,7 +298,7 @@ func (neo *DSL) handleDownload(c *gin.Context) { } // Set the context - ctx, cancel := NewContextWithCancel(sid, c.Query("chat_id"), "") + ctx, cancel := chatctx.NewWithCancel(sid, c.Query("chat_id"), "") defer cancel() // Download the file @@ -537,7 +538,7 @@ func (neo *DSL) handleChatUpdate(c *gin.Context) { // If content is not empty, Generate the chat title if body.Content != "" { - ctx, cancel := NewContextWithCancel(sid, c.Query("chat_id"), "") + ctx, cancel := chatctx.NewWithCancel(sid, c.Query("chat_id"), "") defer cancel() title, err := neo.GenerateChatTitle(ctx, body.Content, c, true) @@ -729,7 +730,7 @@ func (neo *DSL) handleGenerateTitle(c *gin.Context) { return } - ctx, cancel := NewContextWithCancel(resp.sid, c.Query("chat_id"), "") + ctx, cancel := chatctx.NewWithCancel(resp.sid, c.Query("chat_id"), "") defer cancel() // Use silent mode for regular HTTP requests, streaming for SSE @@ -780,7 +781,7 @@ func (neo *DSL) handleGeneratePrompts(c *gin.Context) { return } - ctx, cancel := NewContextWithCancel(resp.sid, c.Query("chat_id"), "") + ctx, cancel := chatctx.NewWithCancel(resp.sid, c.Query("chat_id"), "") defer cancel() // Use silent mode for regular HTTP requests, streaming for SSE @@ -831,7 +832,7 @@ func (neo *DSL) handleGenerateCustom(c *gin.Context) { return } - ctx, cancel := NewContextWithCancel(resp.sid, c.Query("chat_id"), "") + ctx, cancel := chatctx.NewWithCancel(resp.sid, c.Query("chat_id"), "") defer cancel() // Use silent mode for regular HTTP requests, streaming for SSE diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go new file mode 100644 index 0000000000..31df51cde2 --- /dev/null +++ b/neo/assistant/hooks.go @@ -0,0 +1 @@ +package assistant diff --git a/neo/context.go b/neo/context.go deleted file mode 100644 index ad848c95d0..0000000000 --- a/neo/context.go +++ /dev/null @@ -1,49 +0,0 @@ -package neo - -import ( - "context" - "time" - - jsoniter "github.com/json-iterator/go" - "github.com/yaoapp/kun/log" -) - -// NewContext create a new context -func NewContext(sid, cid, payload string) Context { - ctx := Context{Context: context.Background(), Sid: sid, ChatID: cid} - if payload == "" { - return ctx - } - - err := jsoniter.Unmarshal([]byte(payload), &ctx) - if err != nil { - log.Error("%s", err.Error()) - } - return ctx -} - -// NewContextWithCancel create a new context with cancel -func NewContextWithCancel(sid, cid, payload string) (Context, context.CancelFunc) { - ctx := NewContext(sid, cid, payload) - return ContextWithCancel(ctx) -} - -// NewContextWithTimeout create a new context with timeout -func NewContextWithTimeout(sid, cid, payload string, timeout time.Duration) (Context, context.CancelFunc) { - ctx := NewContext(sid, cid, payload) - return ContextWithTimeout(ctx, timeout) -} - -// ContextWithCancel create a new context -func ContextWithCancel(parent Context) (Context, context.CancelFunc) { - new, cancel := context.WithCancel(parent.Context) - parent.Context = new - return parent, cancel -} - -// ContextWithTimeout create a new context -func ContextWithTimeout(parent Context, timeout time.Duration) (Context, context.CancelFunc) { - new, cancel := context.WithTimeout(parent.Context, timeout) - parent.Context = new - return parent, cancel -} diff --git a/neo/context/context.go b/neo/context/context.go new file mode 100644 index 0000000000..79b9ff4bd7 --- /dev/null +++ b/neo/context/context.go @@ -0,0 +1,82 @@ +package context + +import ( + "context" + "time" + + jsoniter "github.com/json-iterator/go" + "github.com/yaoapp/kun/log" +) + +// Context the context +type Context struct { + context.Context + Sid string `json:"sid" yaml:"-"` // Session ID + ChatID string `json:"chat_id,omitempty"` // Chat ID, use to select chat + AssistantID string `json:"assistant_id,omitempty"` // Assistant ID, use to select assistant + Stack string `json:"stack,omitempty"` + Path string `json:"pathname,omitempty"` + FormData map[string]interface{} `json:"formdata,omitempty"` + Field *Field `json:"field,omitempty"` + Namespace string `json:"namespace,omitempty"` + Config map[string]interface{} `json:"config,omitempty"` + Signal interface{} `json:"signal,omitempty"` + Upload *FileUpload `json:"upload,omitempty"` +} + +// Field the context field +type Field struct { + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + Bind string `json:"bind,omitempty"` + Props map[string]interface{} `json:"props,omitempty"` + Children []interface{} `json:"children,omitempty"` +} + +// FileUpload the file upload +type FileUpload struct { + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + Size int64 `json:"size,omitempty"` + TempFile string `json:"temp_file,omitempty"` +} + +// New create a new context +func New(sid, cid, payload string) Context { + ctx := Context{Context: context.Background(), Sid: sid, ChatID: cid} + if payload == "" { + return ctx + } + + err := jsoniter.Unmarshal([]byte(payload), &ctx) + if err != nil { + log.Error("%s", err.Error()) + } + return ctx +} + +// NewWithCancel create a new context with cancel +func NewWithCancel(sid, cid, payload string) (Context, context.CancelFunc) { + ctx := New(sid, cid, payload) + return WithCancel(ctx) +} + +// NewWithTimeout create a new context with timeout +func NewWithTimeout(sid, cid, payload string, timeout time.Duration) (Context, context.CancelFunc) { + ctx := New(sid, cid, payload) + return WithTimeout(ctx, timeout) +} + +// WithCancel create a new context +func WithCancel(parent Context) (Context, context.CancelFunc) { + new, cancel := context.WithCancel(parent.Context) + parent.Context = new + return parent, cancel +} + +// WithTimeout create a new context +func WithTimeout(parent Context, timeout time.Duration) (Context, context.CancelFunc) { + new, cancel := context.WithTimeout(parent.Context, timeout) + parent.Context = new + return parent, cancel +} diff --git a/neo/hooks.go b/neo/hooks.go index 4ec7dabb6c..9a3b6e58b6 100644 --- a/neo/hooks.go +++ b/neo/hooks.go @@ -7,10 +7,11 @@ import ( "github.com/gin-gonic/gin" jsoniter "github.com/json-iterator/go" "github.com/yaoapp/gou/process" + chatctx "github.com/yaoapp/yao/neo/context" ) // HookCreate create the assistant -func (neo *DSL) HookCreate(ctx Context, messages []map[string]interface{}, c *gin.Context) (CreateResponse, error) { +func (neo *DSL) HookCreate(ctx chatctx.Context, messages []map[string]interface{}, c *gin.Context) (CreateResponse, error) { // Default assistant assistantID := neo.Use @@ -69,7 +70,7 @@ func (neo *DSL) HookCreate(ctx Context, messages []map[string]interface{}, c *gi } // HookPrepare executes the prepare hook before AI is called -func (neo *DSL) HookPrepare(ctx Context, messages []map[string]interface{}) ([]map[string]interface{}, error) { +func (neo *DSL) HookPrepare(ctx chatctx.Context, messages []map[string]interface{}) ([]map[string]interface{}, error) { if neo.Prepare == "" { return messages, nil } @@ -114,7 +115,7 @@ func (neo *DSL) HookPrepare(ctx Context, messages []map[string]interface{}) ([]m } // HookWrite executes the write hook when response is received from AI -func (neo *DSL) HookWrite(ctx Context, messages []map[string]interface{}, response map[string]interface{}, content string, writer *gin.ResponseWriter) ([]map[string]interface{}, error) { +func (neo *DSL) HookWrite(ctx chatctx.Context, messages []map[string]interface{}, response map[string]interface{}, content string, writer *gin.ResponseWriter) ([]map[string]interface{}, error) { if neo.Write == "" { return []map[string]interface{}{response}, nil } diff --git a/neo/neo.go b/neo/neo.go index f9d56b355d..f374bc1d32 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/yaoapp/kun/log" "github.com/yaoapp/yao/neo/assistant" + chatctx "github.com/yaoapp/yao/neo/context" "github.com/yaoapp/yao/neo/message" ) @@ -16,7 +17,7 @@ import ( var lock sync.Mutex = sync.Mutex{} // Answer reply the message -func (neo *DSL) Answer(ctx Context, question string, c *gin.Context) error { +func (neo *DSL) Answer(ctx chatctx.Context, question string, c *gin.Context) error { messages, err := neo.chatMessages(ctx, question) if err != nil { msg := message.New().Error(err).Done() @@ -50,7 +51,7 @@ func (neo *DSL) Select(id string) (assistant.API, error) { } // GeneratePrompts generate prompts for the AI assistant -func (neo *DSL) GeneratePrompts(ctx Context, input string, c *gin.Context, silent ...bool) (string, error) { +func (neo *DSL) GeneratePrompts(ctx chatctx.Context, input string, c *gin.Context, silent ...bool) (string, error) { prompts := ` Optimize the prompts for the AI assistant 1. Optimize prompts based on the user's input @@ -68,7 +69,7 @@ func (neo *DSL) GeneratePrompts(ctx Context, input string, c *gin.Context, silen } // GenerateChatTitle generate the chat title -func (neo *DSL) GenerateChatTitle(ctx Context, input string, c *gin.Context, silent ...bool) (string, error) { +func (neo *DSL) GenerateChatTitle(ctx chatctx.Context, input string, c *gin.Context, silent ...bool) (string, error) { prompts := ` Help me generate a title for the chat 1. The title should be a short and concise description of the chat. @@ -84,7 +85,7 @@ func (neo *DSL) GenerateChatTitle(ctx Context, input string, c *gin.Context, sil } // GenerateWithAI generate content with AI, type can be "title", "prompts", etc. -func (neo *DSL) GenerateWithAI(ctx Context, input string, messageType string, systemPrompt string, c *gin.Context, silent bool) (string, error) { +func (neo *DSL) GenerateWithAI(ctx chatctx.Context, input string, messageType string, systemPrompt string, c *gin.Context, silent bool) (string, error) { messages := []map[string]interface{}{ {"role": "system", "content": systemPrompt}, { @@ -187,7 +188,7 @@ func (neo *DSL) GenerateWithAI(ctx Context, input string, messageType string, sy } // Upload upload a file -func (neo *DSL) Upload(ctx Context, c *gin.Context) (*assistant.File, error) { +func (neo *DSL) Upload(ctx chatctx.Context, c *gin.Context) (*assistant.File, error) { // Get the file tmpfile, err := c.FormFile("file") if err != nil { @@ -212,11 +213,11 @@ func (neo *DSL) Upload(ctx Context, c *gin.Context) (*assistant.File, error) { } // Get file info - ctx.Upload = &FileUpload{ - Bytes: int(tmpfile.Size), - Name: tmpfile.Filename, - ContentType: tmpfile.Header.Get("Content-Type"), - Option: option, + ctx.Upload = &chatctx.FileUpload{ + Name: tmpfile.Filename, + Type: tmpfile.Header.Get("Content-Type"), + Size: tmpfile.Size, + TempFile: tmpfile.Filename, } // Default use the assistant in context @@ -235,7 +236,7 @@ func (neo *DSL) Upload(ctx Context, c *gin.Context) (*assistant.File, error) { } // Download downloads a file -func (neo *DSL) Download(ctx Context, c *gin.Context) (*assistant.FileResponse, error) { +func (neo *DSL) Download(ctx chatctx.Context, c *gin.Context) (*assistant.FileResponse, error) { // Get file_id from query string fileID := c.Query("file_id") if fileID == "" { @@ -259,7 +260,7 @@ func (neo *DSL) Download(ctx Context, c *gin.Context) (*assistant.FileResponse, } // chat chat with AI -func (neo *DSL) chat(ast assistant.API, ctx Context, messages []map[string]interface{}, c *gin.Context) error { +func (neo *DSL) chat(ast assistant.API, ctx chatctx.Context, messages []map[string]interface{}, c *gin.Context) error { if ast == nil { msg := message.New().Error("assistant is not initialized").Done() msg.Write(c.Writer) @@ -339,7 +340,7 @@ func (neo *DSL) chat(ast assistant.API, ctx Context, messages []map[string]inter } // chatMessages get the chat messages -func (neo *DSL) chatMessages(ctx Context, content ...string) ([]map[string]interface{}, error) { +func (neo *DSL) chatMessages(ctx chatctx.Context, content ...string) ([]map[string]interface{}, error) { history, err := neo.Store.GetHistory(ctx.Sid, ctx.ChatID) if err != nil { diff --git a/neo/types.go b/neo/types.go index e351116866..bbf614d2dd 100644 --- a/neo/types.go +++ b/neo/types.go @@ -1,8 +1,6 @@ package neo import ( - "context" - "github.com/gin-gonic/gin" "github.com/yaoapp/yao/neo/assistant" "github.com/yaoapp/yao/neo/rag" @@ -48,22 +46,6 @@ type Mention struct { Type string `json:"type,omitempty"` } -// Context the context -type Context struct { - Sid string `json:"sid" yaml:"-"` // Session ID - ChatID string `json:"chat_id,omitempty"` // Chat ID, use to select chat - AssistantID string `json:"assistant_id,omitempty"` // Assistant ID, use to select assistant - Stack string `json:"stack,omitempty"` - Path string `json:"pathname,omitempty"` - FormData map[string]interface{} `json:"formdata,omitempty"` - Field *Field `json:"field,omitempty"` - Namespace string `json:"namespace,omitempty"` - Config map[string]interface{} `json:"config,omitempty"` - Signal interface{} `json:"signal,omitempty"` - Upload *FileUpload `json:"upload,omitempty"` - context.Context `json:"-" yaml:"-"` -} - // Field the context field type Field struct { Name string `json:"name,omitempty"` From b01f27d70f0b75c66eb64764ae3a60c3fc50becf Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 13 Jan 2025 11:07:12 +0800 Subject: [PATCH 06/11] Refactor assistant initialization and enhance context handling in Neo API - Updated the Answer method to improve assistant initialization by directly calling the new HookInit method, streamlining the process of selecting and initializing assistants based on context. - Introduced ResHookInit struct to encapsulate the response from the assistant initialization hook, enhancing clarity and maintainability. - Enhanced the context package by adding a Map method to facilitate easier mapping of context data, improving the overall structure of context management. - Refactored the assistant interface to include the new HookInit method, ensuring a consistent approach to assistant interactions. These changes improve the robustness and maintainability of the Neo API, paving the way for future enhancements in assistant functionalities and context management. --- neo/assistant/hooks.go | 73 ++++++++++++++++++++++++++++++++++++++++++ neo/assistant/types.go | 3 ++ neo/context/context.go | 40 +++++++++++++++++++++++ neo/neo.go | 27 +++++++++++----- 4 files changed, 135 insertions(+), 8 deletions(-) diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go index 31df51cde2..b9c8fb7b40 100644 --- a/neo/assistant/hooks.go +++ b/neo/assistant/hooks.go @@ -1 +1,74 @@ package assistant + +import ( + "fmt" + + chatctx "github.com/yaoapp/yao/neo/context" + "github.com/yaoapp/yao/neo/message" +) + +const ( + // HookErrorMethodNotFound is the error message for method not found + HookErrorMethodNotFound = "method not found" +) + +// ResHookInit the response of the init hook +type ResHookInit struct { + AssistantID string `json:"assistant_id,omitempty"` + ChatID string `json:"chat_id,omitempty"` +} + +// HookInit initialize the assistant +func (ast *Assistant) HookInit(context chatctx.Context, messages []message.Message) (*ResHookInit, error) { + v, err := ast.call("Init", context, messages) + if err != nil { + if err.Error() == HookErrorMethodNotFound { + return nil, nil + } + return nil, err + } + + response := &ResHookInit{} + switch v := v.(type) { + case map[string]interface{}: + if res, ok := v["assistant_id"].(string); ok { + response.AssistantID = res + } + if res, ok := v["chat_id"].(string); ok { + response.ChatID = res + } + + case string: + response.AssistantID = v + response.ChatID = context.ChatID + + case nil: + response.AssistantID = ast.ID + response.ChatID = context.ChatID + } + + return response, nil +} + +// Call the script method +func (ast *Assistant) call(method string, context chatctx.Context, args ...any) (interface{}, error) { + + if ast.Script == nil { + return nil, nil + } + + ctx, err := ast.Script.NewContext(context.Sid, nil) + if err != nil { + return nil, err + } + defer ctx.Close() + + // Check if the method exists + if !ctx.Global().Has(method) { + return nil, fmt.Errorf(HookErrorMethodNotFound) + } + + // Call the method + args = append([]interface{}{context.Map()}, args...) + return ctx.Call(method, args...) +} diff --git a/neo/assistant/types.go b/neo/assistant/types.go index 2002ae24a0..71dddf1943 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -7,6 +7,8 @@ import ( "github.com/yaoapp/gou/rag/driver" v8 "github.com/yaoapp/gou/runtime/v8" + chatctx "github.com/yaoapp/yao/neo/context" + "github.com/yaoapp/yao/neo/message" api "github.com/yaoapp/yao/openai" ) @@ -16,6 +18,7 @@ type API interface { Upload(ctx context.Context, file *multipart.FileHeader, reader io.Reader, option map[string]interface{}) (*File, error) Download(ctx context.Context, fileID string) (*FileResponse, error) ReadBase64(ctx context.Context, fileID string) (string, error) + HookInit(ctx chatctx.Context, messages []message.Message) (*ResHookInit, error) } // RAG the RAG interface diff --git a/neo/context/context.go b/neo/context/context.go index 79b9ff4bd7..da0efd5141 100644 --- a/neo/context/context.go +++ b/neo/context/context.go @@ -80,3 +80,43 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, context.Cancel parent.Context = new return parent, cancel } + +// Map the context to a map +func (ctx *Context) Map() map[string]interface{} { + data := map[string]interface{}{ + "sid": ctx.Sid, + } + + if ctx.ChatID != "" { + data["chat_id"] = ctx.ChatID + } + if ctx.AssistantID != "" { + data["assistant_id"] = ctx.AssistantID + } + if ctx.Stack != "" { + data["stack"] = ctx.Stack + } + if ctx.Path != "" { + data["pathname"] = ctx.Path + } + if len(ctx.FormData) > 0 { + data["formdata"] = ctx.FormData + } + if ctx.Field != nil { + data["field"] = ctx.Field + } + if ctx.Namespace != "" { + data["namespace"] = ctx.Namespace + } + if len(ctx.Config) > 0 { + data["config"] = ctx.Config + } + if ctx.Signal != nil { + data["signal"] = ctx.Signal + } + if ctx.Upload != nil { + data["upload"] = ctx.Upload + } + + return data +} diff --git a/neo/neo.go b/neo/neo.go index f374bc1d32..884f4a68ff 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -25,19 +25,30 @@ func (neo *DSL) Answer(ctx chatctx.Context, question string, c *gin.Context) err return err } - // Get the assistant_id, chat_id - res, err := neo.HookCreate(ctx, messages, c) - if err != nil { - msg := message.New().Error(err).Done() - msg.Write(c.Writer) - return err + var res *assistant.ResHookInit = nil + var ast assistant.API = neo.Assistant + + if ctx.AssistantID != "" { + ast, err = neo.Select(ctx.AssistantID) + if err != nil { + return err + } } - // Select Assistant - ast, err := neo.Select(res.AssistantID) + // Init the assistant + res, err = ast.HookInit(ctx, []message.Message{{Text: question}}) if err != nil { return err } + + // Switch to the new assistant if necessary + if res.AssistantID != ctx.AssistantID { + ast, err = neo.Select(res.AssistantID) + if err != nil { + return err + } + } + // Chat with AI return neo.chat(ast, ctx, messages, c) } From ffb9ede647391c94e040a0748fa11b8c9ff9b26f Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 13 Jan 2025 14:47:17 +0800 Subject: [PATCH 07/11] Refactor Neo API message handling and enhance assistant interaction - Updated the Answer method to utilize the new withHistory function, improving message retrieval and context management. - Refactored the chat method to accept a slice of message.Message instead of a map, enhancing type safety and clarity. - Introduced a new withHistory function to streamline the process of retrieving chat history and user messages, improving maintainability. - Enhanced the HookInit method to accept a gin.Context, allowing for better context handling during assistant initialization. - Updated the assistant API interface to reflect changes in message handling, ensuring consistency across the codebase. These changes improve the robustness and maintainability of the Neo API, paving the way for future enhancements in assistant functionalities and message management. --- neo/assistant/api.go | 22 +++++++-------- neo/assistant/hooks.go | 49 +++++++++++++++++++++++++------- neo/assistant/types.go | 5 ++-- neo/message/message.go | 64 ++++++++++++++++++++++++++++++++++++++++++ neo/neo.go | 31 ++++++++++---------- 5 files changed, 133 insertions(+), 38 deletions(-) diff --git a/neo/assistant/api.go b/neo/assistant/api.go index 8481ed9d69..685e1bf2c8 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/yaoapp/gou/fs" + "github.com/yaoapp/yao/neo/message" chatMessage "github.com/yaoapp/yao/neo/message" ) @@ -41,7 +42,7 @@ func GetByConnector(connector string, name string) (*Assistant, error) { } // Chat implements the chat functionality -func (ast *Assistant) Chat(ctx context.Context, messages []map[string]interface{}, option map[string]interface{}, cb func(data []byte) int) error { +func (ast *Assistant) Chat(ctx context.Context, messages []message.Message, option map[string]interface{}, cb func(data []byte) int) error { if ast.openai == nil { return fmt.Errorf("openai is not initialized") } @@ -59,13 +60,12 @@ func (ast *Assistant) Chat(ctx context.Context, messages []map[string]interface{ return nil } -func (ast *Assistant) requestMessages(ctx context.Context, messages []map[string]interface{}) ([]map[string]interface{}, error) { +func (ast *Assistant) requestMessages(ctx context.Context, messages []message.Message) ([]map[string]interface{}, error) { newMessages := []map[string]interface{}{} - // With Prompts if ast.Prompts != nil { for _, prompt := range ast.Prompts { - message := map[string]interface{}{ + msg := map[string]interface{}{ "role": prompt.Role, "content": prompt.Content, } @@ -75,20 +75,20 @@ func (ast *Assistant) requestMessages(ctx context.Context, messages []map[string name = prompt.Name } - message["name"] = name - newMessages = append(newMessages, message) + msg["name"] = name + newMessages = append(newMessages, msg) } } length := len(messages) for index, message := range messages { - role, ok := message["role"].(string) - if !ok { + role := message.Role + if role == "" { return nil, fmt.Errorf("role must be string") } - content, ok := message["content"].(string) - if !ok { + content := message.Text + if content == "" { return nil, fmt.Errorf("content must be string") } @@ -97,7 +97,7 @@ func (ast *Assistant) requestMessages(ctx context.Context, messages []map[string "content": content, } - if name, ok := message["name"].(string); ok { + if name := message.Name; name != "" { newMessage["name"] = name } diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go index b9c8fb7b40..57f1329b31 100644 --- a/neo/assistant/hooks.go +++ b/neo/assistant/hooks.go @@ -1,8 +1,11 @@ package assistant import ( + "context" "fmt" + "time" + "github.com/gin-gonic/gin" chatctx "github.com/yaoapp/yao/neo/context" "github.com/yaoapp/yao/neo/message" ) @@ -19,8 +22,12 @@ type ResHookInit struct { } // HookInit initialize the assistant -func (ast *Assistant) HookInit(context chatctx.Context, messages []message.Message) (*ResHookInit, error) { - v, err := ast.call("Init", context, messages) +func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, messages []message.Message) (*ResHookInit, error) { + // Create timeout context + ctx, cancel := ast.createTimeoutContext(c) + defer cancel() + + v, err := ast.call(ctx, "Init", context, messages, c.Writer) if err != nil { if err.Error() == HookErrorMethodNotFound { return nil, nil @@ -50,25 +57,47 @@ func (ast *Assistant) HookInit(context chatctx.Context, messages []message.Messa return response, nil } -// Call the script method -func (ast *Assistant) call(method string, context chatctx.Context, args ...any) (interface{}, error) { +// createTimeoutContext creates a timeout context with 5 seconds timeout +func (ast *Assistant) createTimeoutContext(c *gin.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) + return ctx, cancel +} +// Call the script method +func (ast *Assistant) call(ctx context.Context, method string, context chatctx.Context, args ...any) (interface{}, error) { if ast.Script == nil { return nil, nil } - ctx, err := ast.Script.NewContext(context.Sid, nil) + scriptCtx, err := ast.Script.NewContext(context.Sid, nil) if err != nil { return nil, err } - defer ctx.Close() + defer scriptCtx.Close() // Check if the method exists - if !ctx.Global().Has(method) { + if !scriptCtx.Global().Has(method) { return nil, fmt.Errorf(HookErrorMethodNotFound) } - // Call the method - args = append([]interface{}{context.Map()}, args...) - return ctx.Call(method, args...) + // Create done channel for handling cancellation + done := make(chan struct{}) + var result interface{} + var callErr error + + go func() { + defer close(done) + // Call the method + args = append([]interface{}{context.Map()}, args...) + result, callErr = scriptCtx.Call(method, args...) + }() + + // Wait for either context cancellation or method completion + select { + case <-ctx.Done(): + scriptCtx.Close() // Force close the script context + return nil, ctx.Err() + case <-done: + return result, callErr + } } diff --git a/neo/assistant/types.go b/neo/assistant/types.go index 71dddf1943..54bc47f8e7 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -5,6 +5,7 @@ import ( "io" "mime/multipart" + "github.com/gin-gonic/gin" "github.com/yaoapp/gou/rag/driver" v8 "github.com/yaoapp/gou/runtime/v8" chatctx "github.com/yaoapp/yao/neo/context" @@ -14,11 +15,11 @@ import ( // API the assistant API interface type API interface { - Chat(ctx context.Context, messages []map[string]interface{}, option map[string]interface{}, cb func(data []byte) int) error + Chat(ctx context.Context, messages []message.Message, option map[string]interface{}, cb func(data []byte) int) error Upload(ctx context.Context, file *multipart.FileHeader, reader io.Reader, option map[string]interface{}) (*File, error) Download(ctx context.Context, fileID string) (*FileResponse, error) ReadBase64(ctx context.Context, fileID string) (string, error) - HookInit(ctx chatctx.Context, messages []message.Message) (*ResHookInit, error) + HookInit(c *gin.Context, ctx chatctx.Context, messages []message.Message) (*ResHookInit, error) } // RAG the RAG interface diff --git a/neo/message/message.go b/neo/message/message.go index 48c3a77056..cad14749e1 100644 --- a/neo/message/message.go +++ b/neo/message/message.go @@ -22,6 +22,8 @@ type Message struct { IsDone bool `json:"done,omitempty"` Actions []Action `json:"actions,omitempty"` // Conversation Actions for frontend Attachments []Attachment `json:"attachments,omitempty"` // File attachments + Role string `json:"role,omitempty"` // user, assistant, system ... + Name string `json:"name,omitempty"` // name for the message Data map[string]interface{} `json:"-"` } @@ -134,12 +136,74 @@ func (m *Message) Error(message interface{}) *Message { return m } +// SetContent set the content +func (m *Message) SetContent(content string) *Message { + if strings.HasPrefix(content, "{") && strings.HasSuffix(content, "}") { + var msg Message + if err := jsoniter.UnmarshalFromString(content, &msg); err != nil { + m.Text = err.Error() + "\n" + content + return m + } + *m = msg + } else { + m.Text = content + m.Type = "text" + } + return m +} + +// Content get the content +func (m *Message) Content() string { + content := map[string]interface{}{"text": m.Text} + if m.Attachments != nil { + content["attachments"] = m.Attachments + } + + if m.Type != "" { + content["type"] = m.Type + } + contentRaw, _ := jsoniter.MarshalToString(content) + return contentRaw +} + +// ToMap convert to map +func (m *Message) ToMap() map[string]interface{} { + return map[string]interface{}{ + "content": m.Content(), + "role": m.Role, + "name": m.Name, + } +} + // Map set from map func (m *Message) Map(msg map[string]interface{}) *Message { if msg == nil { return m } + // Content {"text": "xxxx", "attachments": ... } + if content, ok := msg["content"].(string); ok { + if strings.HasPrefix(content, "{") && strings.HasSuffix(content, "}") { + var msg Message + if err := jsoniter.UnmarshalFromString(content, &msg); err != nil { + m.Text = err.Error() + "\n" + content + return m + } + *m = msg + } else { + m.Text = content + m.Type = "text" + } + } + + if role, ok := msg["role"].(string); ok { + m.Role = role + } + + if name, ok := msg["name"].(string); ok { + m.Name = name + } + if text, ok := msg["text"].(string); ok { m.Text = text } diff --git a/neo/neo.go b/neo/neo.go index 884f4a68ff..83bd91b2b4 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -18,7 +18,7 @@ var lock sync.Mutex = sync.Mutex{} // Answer reply the message func (neo *DSL) Answer(ctx chatctx.Context, question string, c *gin.Context) error { - messages, err := neo.chatMessages(ctx, question) + messages, err := neo.withHistory(ctx, question) if err != nil { msg := message.New().Error(err).Done() msg.Write(c.Writer) @@ -36,7 +36,7 @@ func (neo *DSL) Answer(ctx chatctx.Context, question string, c *gin.Context) err } // Init the assistant - res, err = ast.HookInit(ctx, []message.Message{{Text: question}}) + res, err = ast.HookInit(c, ctx, messages) if err != nil { return err } @@ -131,7 +131,11 @@ func (neo *DSL) GenerateWithAI(ctx chatctx.Context, input string, messageType st // Chat with AI in background go func() { - err := ast.Chat(c.Request.Context(), messages, neo.Option, func(data []byte) int { + msgList := make([]message.Message, len(messages)) + for i, msg := range messages { + msgList[i] = *message.New().Map(msg) + } + err := ast.Chat(c.Request.Context(), msgList, neo.Option, func(data []byte) int { select { case <-clientBreak: return 0 // break @@ -271,7 +275,7 @@ func (neo *DSL) Download(ctx chatctx.Context, c *gin.Context) (*assistant.FileRe } // chat chat with AI -func (neo *DSL) chat(ast assistant.API, ctx chatctx.Context, messages []map[string]interface{}, c *gin.Context) error { +func (neo *DSL) chat(ast assistant.API, ctx chatctx.Context, messages []message.Message, c *gin.Context) error { if ast == nil { msg := message.New().Error("assistant is not initialized").Done() msg.Write(c.Writer) @@ -350,33 +354,30 @@ func (neo *DSL) chat(ast assistant.API, ctx chatctx.Context, messages []map[stri } } -// chatMessages get the chat messages -func (neo *DSL) chatMessages(ctx chatctx.Context, content ...string) ([]map[string]interface{}, error) { - +func (neo *DSL) withHistory(ctx chatctx.Context, question string) ([]message.Message, error) { history, err := neo.Store.GetHistory(ctx.Sid, ctx.ChatID) if err != nil { return nil, err } - messages := []map[string]interface{}{} - messages = append(messages, history...) - if len(content) == 0 { - return messages, nil + // Add history messages + messages := []message.Message{} + for _, h := range history { + messages = append(messages, *message.New().Map(h)) } // Add user message - messages = append(messages, map[string]interface{}{"role": "user", "content": content[0], "name": ctx.Sid}) + messages = append(messages, *message.New().Map(map[string]interface{}{"role": "user", "content": question, "name": ctx.Sid})) return messages, nil } // saveHistory save the history -func (neo *DSL) saveHistory(sid string, chatID string, content []byte, messages []map[string]interface{}) { - +func (neo *DSL) saveHistory(sid string, chatID string, content []byte, messages []message.Message) { if len(content) > 0 && sid != "" && len(messages) > 0 { err := neo.Store.SaveHistory( sid, []map[string]interface{}{ - {"role": "user", "content": messages[len(messages)-1]["content"], "name": sid}, + {"role": "user", "content": messages[len(messages)-1].Content(), "name": sid}, {"role": "assistant", "content": string(content), "name": sid}, }, chatID, From cb2cd0c3173f9f9d188d29f4d31df4683538f120 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 13 Jan 2025 16:21:13 +0800 Subject: [PATCH 08/11] Refactor Neo API assistant interaction and streamline message handling - Updated the Answer method to utilize the new Execute method, simplifying the assistant interaction process. - Introduced a new Execute method in the Assistant struct to encapsulate the chat execution logic, enhancing clarity and maintainability. - Refactored the HookInit method to accept input messages and options, improving the initialization process for assistants. - Enhanced the ResHookInit struct to include next action handling and input messages, providing better control over assistant responses. - Removed obsolete chat handling methods, streamlining the codebase and improving overall structure. These changes improve the robustness and maintainability of the Neo API, paving the way for future enhancements in assistant functionalities and message management. --- neo/assistant/api.go | 180 +++++++++++++++++++++++++++++++++++++++++ neo/assistant/hooks.go | 17 +++- neo/assistant/types.go | 3 +- neo/neo.go | 160 +----------------------------------- 4 files changed, 198 insertions(+), 162 deletions(-) diff --git a/neo/assistant/api.go b/neo/assistant/api.go index 685e1bf2c8..bc0840965e 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -6,7 +6,9 @@ import ( "fmt" "strings" + "github.com/gin-gonic/gin" "github.com/yaoapp/gou/fs" + chatctx "github.com/yaoapp/yao/neo/context" "github.com/yaoapp/yao/neo/message" chatMessage "github.com/yaoapp/yao/neo/message" ) @@ -41,6 +43,184 @@ func GetByConnector(connector string, name string) (*Assistant, error) { return assistant, nil } +// Execute implements the execute functionality +func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}) error { + messages, err := ast.withHistory(ctx, input) + if err != nil { + return err + } + + options = ast.withOptions(options) + + // Run init hook + res, err := ast.HookInit(c, ctx, messages, options) + if err != nil { + return err + } + + // Switch to the new assistant if necessary + if res.AssistantID != ctx.AssistantID { + newAst, err := Get(res.AssistantID) + if err != nil { + return err + } + *ast = *newAst + } + + // Handle next action + if res.Next != nil { + switch res.Next.Action { + case "exit": + return nil + // Add other actions here if needed + } + } + + // Update options if provided + if res.Options != nil { + options = res.Options + } + + // Only proceed with chat stream if no specific next action was handled + return ast.handleChatStream(c, ctx, messages, options) +} + +// handleChatStream manages the streaming chat interaction with the AI +func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{}) error { + clientBreak := make(chan bool, 1) + done := make(chan bool, 1) + content := []byte{} + + // Chat with AI in background + go func() { + err := ast.streamChat(c, messages, options, clientBreak, done, &content) + if err != nil { + chatMessage.New().Error(err).Done().Write(c.Writer) + } + + ast.saveChatHistory(ctx, messages, content) + done <- true + }() + + // Wait for completion or client disconnect + select { + case <-done: + return nil + case <-c.Writer.CloseNotify(): + clientBreak <- true + return nil + } +} + +// streamChat handles the streaming chat interaction +func (ast *Assistant) streamChat(c *gin.Context, messages []message.Message, options map[string]interface{}, + clientBreak chan bool, done chan bool, content *[]byte) error { + + return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int { + select { + case <-clientBreak: + return 0 // break + + default: + msg := chatMessage.NewOpenAI(data) + if msg == nil { + return 1 // continue + } + + // Handle error + if msg.Type == "error" { + value := msg.String() + chatMessage.New().Error(value).Done().Write(c.Writer) + return 0 // break + } + + // Append content and send message + *content = msg.Append(*content) + value := msg.String() + if value != "" { + chatMessage.New(). + Map(map[string]interface{}{ + "text": value, + "done": msg.IsDone, + }). + Write(c.Writer) + } + + // Complete the stream + if msg.IsDone { + if value == "" { + msg.Write(c.Writer) + } + done <- true + return 0 // break + } + + return 1 // continue + } + }) +} + +// saveChatHistory saves the chat history if storage is available +func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []message.Message, content []byte) { + if len(content) > 0 && ctx.Sid != "" && len(messages) > 0 { + storage.SaveHistory( + ctx.Sid, + []map[string]interface{}{ + {"role": "user", "content": messages[len(messages)-1].Content(), "name": ctx.Sid}, + {"role": "assistant", "content": string(content), "name": ctx.Sid}, + }, + ctx.ChatID, + nil, + ) + } +} + +func (ast *Assistant) withOptions(options map[string]interface{}) map[string]interface{} { + if options == nil { + options = map[string]interface{}{} + } + + if ast.Options != nil { + for key, value := range ast.Options { + options[key] = value + } + } + return options +} + +func (ast *Assistant) withPrompts(messages []message.Message) []message.Message { + if ast.Prompts != nil { + for _, prompt := range ast.Prompts { + name := ast.Name + if prompt.Name != "" { + name = prompt.Name + } + messages = append(messages, *message.New().Map(map[string]interface{}{"role": prompt.Role, "content": prompt.Content, "name": name})) + } + } + return messages +} + +func (ast *Assistant) withHistory(ctx chatctx.Context, input string) ([]message.Message, error) { + messages := []message.Message{} + messages = ast.withPrompts(messages) + if storage != nil { + history, err := storage.GetHistory(ctx.Sid, ctx.ChatID) + if err != nil { + return nil, err + } + + // Add history messages + for _, h := range history { + messages = append(messages, *message.New().Map(h)) + } + } + + // Add user message + messages = append(messages, *message.New().Map(map[string]interface{}{"role": "user", "content": input, "name": ctx.Sid})) + return messages, nil +} + // Chat implements the chat functionality func (ast *Assistant) Chat(ctx context.Context, messages []message.Message, option map[string]interface{}, cb func(data []byte) int) error { if ast.openai == nil { diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go index 57f1329b31..bbe8894fad 100644 --- a/neo/assistant/hooks.go +++ b/neo/assistant/hooks.go @@ -17,17 +17,26 @@ const ( // ResHookInit the response of the init hook type ResHookInit struct { - AssistantID string `json:"assistant_id,omitempty"` - ChatID string `json:"chat_id,omitempty"` + AssistantID string `json:"assistant_id,omitempty"` + ChatID string `json:"chat_id,omitempty"` + Next *NextAction `json:"next,omitempty"` + Input []message.Message `json:"input,omitempty"` + Options map[string]interface{} `json:"options,omitempty"` +} + +// NextAction the next action +type NextAction struct { + Action string `json:"action"` + Payload map[string]interface{} `json:"payload,omitempty"` } // HookInit initialize the assistant -func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, messages []message.Message) (*ResHookInit, error) { +func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) { // Create timeout context ctx, cancel := ast.createTimeoutContext(c) defer cancel() - v, err := ast.call(ctx, "Init", context, messages, c.Writer) + v, err := ast.call(ctx, "Init", context, input, c.Writer) if err != nil { if err.Error() == HookErrorMethodNotFound { return nil, nil diff --git a/neo/assistant/types.go b/neo/assistant/types.go index 54bc47f8e7..9877b11eb2 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -19,7 +19,8 @@ type API interface { Upload(ctx context.Context, file *multipart.FileHeader, reader io.Reader, option map[string]interface{}) (*File, error) Download(ctx context.Context, fileID string) (*FileResponse, error) ReadBase64(ctx context.Context, fileID string) (string, error) - HookInit(c *gin.Context, ctx chatctx.Context, messages []message.Message) (*ResHookInit, error) + Execute(c *gin.Context, ctx chatctx.Context, input string, options map[string]interface{}) error + HookInit(c *gin.Context, ctx chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) } // RAG the RAG interface diff --git a/neo/neo.go b/neo/neo.go index 83bd91b2b4..c9a45f98ba 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -4,7 +4,6 @@ import ( "fmt" "os" "strings" - "sync" "github.com/gin-gonic/gin" "github.com/yaoapp/kun/log" @@ -13,44 +12,17 @@ import ( "github.com/yaoapp/yao/neo/message" ) -// Lock the assistant list -var lock sync.Mutex = sync.Mutex{} - // Answer reply the message func (neo *DSL) Answer(ctx chatctx.Context, question string, c *gin.Context) error { - messages, err := neo.withHistory(ctx, question) - if err != nil { - msg := message.New().Error(err).Done() - msg.Write(c.Writer) - return err - } - - var res *assistant.ResHookInit = nil + var err error var ast assistant.API = neo.Assistant - if ctx.AssistantID != "" { ast, err = neo.Select(ctx.AssistantID) if err != nil { return err } } - - // Init the assistant - res, err = ast.HookInit(c, ctx, messages) - if err != nil { - return err - } - - // Switch to the new assistant if necessary - if res.AssistantID != ctx.AssistantID { - ast, err = neo.Select(res.AssistantID) - if err != nil { - return err - } - } - - // Chat with AI - return neo.chat(ast, ctx, messages, c) + return ast.Execute(c, ctx, question, nil) } // Select select an assistant @@ -87,6 +59,7 @@ func (neo *DSL) GenerateChatTitle(ctx chatctx.Context, input string, c *gin.Cont 2. The title should be a single sentence. 3. The title should be in same language as the chat. 4. The title should be no more than 50 characters. + 5. ANSWER ONLY THE TITLE CONTENT, FOR EXAMPLE: Chat with AI is a valid title, but "Chat with AI" is not a valid title. ` isSilent := false if len(silent) > 0 { @@ -273,130 +246,3 @@ func (neo *DSL) Download(ctx chatctx.Context, c *gin.Context) (*assistant.FileRe // Download file using the assistant return ast.Download(ctx.Context, fileID) } - -// chat chat with AI -func (neo *DSL) chat(ast assistant.API, ctx chatctx.Context, messages []message.Message, c *gin.Context) error { - if ast == nil { - msg := message.New().Error("assistant is not initialized").Done() - msg.Write(c.Writer) - return fmt.Errorf("assistant is not initialized") - } - - clientBreak := make(chan bool, 1) - done := make(chan bool, 1) - content := []byte{} - - // Chat with AI in background - go func() { - err := ast.Chat(c.Request.Context(), messages, neo.Option, func(data []byte) int { - select { - case <-clientBreak: - return 0 // break - - default: - msg := message.NewOpenAI(data) - if msg == nil { - return 1 // continue - } - - // Handle error - if msg.Type == "error" { - value := msg.String() - message.New().Error(value).Done().Write(c.Writer) - return 0 // break - } - - // Append content and send message - content = msg.Append(content) - value := msg.String() - if value != "" { - message.New(). - Map(map[string]interface{}{ - "text": value, - "done": msg.IsDone, - }). - Write(c.Writer) - } - - // Complete the stream - if msg.IsDone { - if value == "" { - msg.Write(c.Writer) - } - done <- true - return 0 // break - } - - return 1 // continue - } - }) - - if err != nil { - log.Error("Chat error: %s", err.Error()) - message.New().Error(err).Done().Write(c.Writer) - } - - // Save chat history - if len(content) > 0 { - neo.saveHistory(ctx.Sid, ctx.ChatID, content, messages) - } - - done <- true - }() - - // Wait for completion or client disconnect - select { - case <-done: - return nil - case <-c.Writer.CloseNotify(): - clientBreak <- true - return nil - } -} - -func (neo *DSL) withHistory(ctx chatctx.Context, question string) ([]message.Message, error) { - history, err := neo.Store.GetHistory(ctx.Sid, ctx.ChatID) - if err != nil { - return nil, err - } - - // Add history messages - messages := []message.Message{} - for _, h := range history { - messages = append(messages, *message.New().Map(h)) - } - - // Add user message - messages = append(messages, *message.New().Map(map[string]interface{}{"role": "user", "content": question, "name": ctx.Sid})) - return messages, nil -} - -// saveHistory save the history -func (neo *DSL) saveHistory(sid string, chatID string, content []byte, messages []message.Message) { - if len(content) > 0 && sid != "" && len(messages) > 0 { - err := neo.Store.SaveHistory( - sid, - []map[string]interface{}{ - {"role": "user", "content": messages[len(messages)-1].Content(), "name": sid}, - {"role": "assistant", "content": string(content), "name": sid}, - }, - chatID, - nil, - ) - - if err != nil { - log.Error("Save history error: %s", err.Error()) - } - } -} - -// sendMessage sends a message to the client -func (neo *DSL) sendMessage(w gin.ResponseWriter, data interface{}) error { - if msg, ok := data.(map[string]interface{}); ok { - if !message.New().Map(msg).Write(w) { - return fmt.Errorf("failed to write message to stream") - } - return nil - } - return fmt.Errorf("invalid message data type") -} From ca8993f4c8360892807759912f54bfb67e345245 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 13 Jan 2025 17:11:34 +0800 Subject: [PATCH 09/11] Enhance Neo API assistant interaction with improved message handling and streaming support - Added message handling in the Execute method to support input messages from the response. - Updated the handleChatStream method to pass the context to the streamChat method, improving context management during chat streaming. - Introduced the HookStream method to handle streaming responses, allowing for custom output and next action handling based on the assistant's response. - Enhanced the ResHookStream struct to include silent output control and next action management, providing better flexibility in assistant interactions. These changes improve the robustness and maintainability of the Neo API, paving the way for enhanced assistant functionalities and message management. --- neo/assistant/api.go | 44 ++++++++++++++++++++----- neo/assistant/hooks.go | 73 ++++++++++++++++++++++++++++++------------ neo/assistant/types.go | 27 ++++++++++++++++ 3 files changed, 116 insertions(+), 28 deletions(-) diff --git a/neo/assistant/api.go b/neo/assistant/api.go index bc0840965e..5b3052c84f 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -81,6 +81,11 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string, options = res.Options } + // messages + if res.Input != nil { + messages = res.Input + } + // Only proceed with chat stream if no specific next action was handled return ast.handleChatStream(c, ctx, messages, options) } @@ -93,7 +98,7 @@ func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, mess // Chat with AI in background go func() { - err := ast.streamChat(c, messages, options, clientBreak, done, &content) + err := ast.streamChat(c, ctx, messages, options, clientBreak, done, &content) if err != nil { chatMessage.New().Error(err).Done().Write(c.Writer) } @@ -113,7 +118,7 @@ func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, mess } // streamChat handles the streaming chat interaction -func (ast *Assistant) streamChat(c *gin.Context, messages []message.Message, options map[string]interface{}, +func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages []message.Message, options map[string]interface{}, clientBreak chan bool, done chan bool, content *[]byte) error { return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int { @@ -138,12 +143,35 @@ func (ast *Assistant) streamChat(c *gin.Context, messages []message.Message, opt *content = msg.Append(*content) value := msg.String() if value != "" { - chatMessage.New(). - Map(map[string]interface{}{ - "text": value, - "done": msg.IsDone, - }). - Write(c.Writer) + + // Handle stream + res, err := ast.HookStream(c, ctx, messages, value) + if err != nil { + return 0 // break + } + + // Custom output from hook + if res.Output != "" { + value = res.Output + } + + // Custom next action from hook + if res.Next != nil { + switch res.Next.Action { + case "exit": + done <- true + return 0 // break + } + } + + if !res.Silent { + chatMessage.New(). + Map(map[string]interface{}{ + "text": value, + "done": msg.IsDone, + }). + Write(c.Writer) + } } // Complete the stream diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go index bbe8894fad..c1425b58bb 100644 --- a/neo/assistant/hooks.go +++ b/neo/assistant/hooks.go @@ -10,26 +10,6 @@ import ( "github.com/yaoapp/yao/neo/message" ) -const ( - // HookErrorMethodNotFound is the error message for method not found - HookErrorMethodNotFound = "method not found" -) - -// ResHookInit the response of the init hook -type ResHookInit struct { - AssistantID string `json:"assistant_id,omitempty"` - ChatID string `json:"chat_id,omitempty"` - Next *NextAction `json:"next,omitempty"` - Input []message.Message `json:"input,omitempty"` - Options map[string]interface{} `json:"options,omitempty"` -} - -// NextAction the next action -type NextAction struct { - Action string `json:"action"` - Payload map[string]interface{} `json:"payload,omitempty"` -} - // HookInit initialize the assistant func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) { // Create timeout context @@ -54,6 +34,16 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input [] response.ChatID = res } + if res, ok := v["next"].(map[string]interface{}); ok { + response.Next = &NextAction{} + if name, ok := res["action"].(string); ok { + response.Next.Action = name + } + if payload, ok := res["payload"].(map[string]interface{}); ok { + response.Next.Payload = payload + } + } + case string: response.AssistantID = v response.ChatID = context.ChatID @@ -66,6 +56,49 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input [] return response, nil } +// HookStream Handle streaming response from LLM +func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookStream, error) { + + // Create timeout context + ctx, cancel := ast.createTimeoutContext(c) + defer cancel() + + v, err := ast.call(ctx, "Stream", context, input, output, c.Writer) + if err != nil { + if err.Error() == HookErrorMethodNotFound { + return nil, nil + } + return nil, err + } + + response := &ResHookStream{} + switch v := v.(type) { + case map[string]interface{}: + if res, ok := v["output"].(string); ok { + response.Output = res + } + if res, ok := v["next"].(map[string]interface{}); ok { + response.Next = &NextAction{} + if name, ok := res["action"].(string); ok { + response.Next.Action = name + } + if payload, ok := res["payload"].(map[string]interface{}); ok { + response.Next.Payload = payload + } + } + + // Custom silent from hook + if res, ok := v["silent"].(bool); ok { + response.Silent = res + } + + case string: + response.Output = v + } + + return response, nil +} + // createTimeoutContext creates a timeout context with 5 seconds timeout func (ast *Assistant) createTimeoutContext(c *gin.Context) (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) diff --git a/neo/assistant/types.go b/neo/assistant/types.go index 9877b11eb2..058e926b3f 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -13,6 +13,11 @@ import ( api "github.com/yaoapp/yao/openai" ) +const ( + // HookErrorMethodNotFound is the error message for method not found + HookErrorMethodNotFound = "method not found" +) + // API the assistant API interface type API interface { Chat(ctx context.Context, messages []message.Message, option map[string]interface{}, cb func(data []byte) int) error @@ -23,6 +28,28 @@ type API interface { HookInit(c *gin.Context, ctx chatctx.Context, input []message.Message, options map[string]interface{}) (*ResHookInit, error) } +// ResHookInit the response of the init hook +type ResHookInit struct { + AssistantID string `json:"assistant_id,omitempty"` + ChatID string `json:"chat_id,omitempty"` + Next *NextAction `json:"next,omitempty"` + Input []message.Message `json:"input,omitempty"` + Options map[string]interface{} `json:"options,omitempty"` +} + +// ResHookStream the response of the stream hook +type ResHookStream struct { + Silent bool `json:"silent,omitempty"` // Whether to suppress the output + Next *NextAction `json:"next,omitempty"` // The next action + Output string `json:"output,omitempty"` // The output +} + +// NextAction the next action +type NextAction struct { + Action string `json:"action"` + Payload map[string]interface{} `json:"payload,omitempty"` +} + // RAG the RAG interface type RAG struct { Engine driver.Engine From e440f1ff81c0c47d124c1099fffd2b204adde89b Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 13 Jan 2025 17:32:51 +0800 Subject: [PATCH 10/11] Enhance Neo API assistant response handling with new hook methods - Introduced HookDone and HookFail methods to manage completion and failure scenarios in assistant responses, improving error handling and output customization. - Updated streamChat method to utilize these hooks, allowing for more flexible response management based on the assistant's output and error states. - Enhanced ResHookDone and ResHookFail structs to include next action handling, input messages, and error information, providing better control over assistant interactions. These changes improve the robustness and maintainability of the Neo API, paving the way for enhanced assistant functionalities and message management. --- neo/assistant/api.go | 66 ++++++++++++++++++++++----------- neo/assistant/hooks.go | 84 ++++++++++++++++++++++++++++++++++++++++++ neo/assistant/types.go | 15 ++++++++ 3 files changed, 144 insertions(+), 21 deletions(-) diff --git a/neo/assistant/api.go b/neo/assistant/api.go index 5b3052c84f..ae09c75141 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -135,6 +135,13 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [ // Handle error if msg.Type == "error" { value := msg.String() + res, hookErr := ast.HookFail(c, ctx, messages, string(*content), fmt.Errorf("%s", value)) + if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") { + value = res.Output + if res.Error != "" { + value = res.Error + } + } chatMessage.New().Error(value).Done().Write(c.Writer) return 0 // break } @@ -143,42 +150,59 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [ *content = msg.Append(*content) value := msg.String() if value != "" { - // Handle stream - res, err := ast.HookStream(c, ctx, messages, value) - if err != nil { - return 0 // break + res, err := ast.HookStream(c, ctx, messages, string(*content)) + if err == nil && res != nil { + if res.Output != "" { + value = res.Output + } + if res.Next != nil && res.Next.Action == "exit" { + done <- true + return 0 // break + } + if res.Silent { + return 1 // continue + } } - // Custom output from hook - if res.Output != "" { - value = res.Output + chatMessage.New(). + Map(map[string]interface{}{ + "text": value, + "done": msg.IsDone, + }). + Write(c.Writer) + } + + // Complete the stream + if msg.IsDone { + if value == "" { + msg.Write(c.Writer) } - // Custom next action from hook - if res.Next != nil { - switch res.Next.Action { - case "exit": + // Call HookDone + res, hookErr := ast.HookDone(c, ctx, messages, string(*content)) + if hookErr == nil && res != nil { + if res.Output != "" { + chatMessage.New(). + Map(map[string]interface{}{ + "text": res.Output, + "done": true, + }). + Write(c.Writer) + } + if res.Next != nil && res.Next.Action == "exit" { done <- true return 0 // break } - } - - if !res.Silent { + } else if value != "" { chatMessage.New(). Map(map[string]interface{}{ "text": value, - "done": msg.IsDone, + "done": true, }). Write(c.Writer) } - } - // Complete the stream - if msg.IsDone { - if value == "" { - msg.Write(c.Writer) - } done <- true return 0 // break } diff --git a/neo/assistant/hooks.go b/neo/assistant/hooks.go index c1425b58bb..c6dd65b2c3 100644 --- a/neo/assistant/hooks.go +++ b/neo/assistant/hooks.go @@ -99,6 +99,90 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input return response, nil } +// HookDone Handle completion of assistant response +func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string) (*ResHookDone, error) { + // Create timeout context + ctx, cancel := ast.createTimeoutContext(c) + defer cancel() + + v, err := ast.call(ctx, "Done", context, input, output, c.Writer) + if err != nil { + if err.Error() == HookErrorMethodNotFound { + return nil, nil + } + return nil, err + } + + response := &ResHookDone{ + Input: input, + Output: output, + } + + switch v := v.(type) { + case map[string]interface{}: + if res, ok := v["output"].(string); ok { + response.Output = res + } + if res, ok := v["next"].(map[string]interface{}); ok { + response.Next = &NextAction{} + if name, ok := res["action"].(string); ok { + response.Next.Action = name + } + if payload, ok := res["payload"].(map[string]interface{}); ok { + response.Next.Payload = payload + } + } + case string: + response.Output = v + } + + return response, nil +} + +// HookFail Handle failure of assistant response +func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []message.Message, output string, err error) (*ResHookFail, error) { + // Create timeout context + ctx, cancel := ast.createTimeoutContext(c) + defer cancel() + + v, callErr := ast.call(ctx, "Fail", context, input, output, err.Error(), c.Writer) + if callErr != nil { + if callErr.Error() == HookErrorMethodNotFound { + return nil, nil + } + return nil, callErr + } + + response := &ResHookFail{ + Input: input, + Output: output, + Error: err.Error(), + } + + switch v := v.(type) { + case map[string]interface{}: + if res, ok := v["output"].(string); ok { + response.Output = res + } + if res, ok := v["error"].(string); ok { + response.Error = res + } + if res, ok := v["next"].(map[string]interface{}); ok { + response.Next = &NextAction{} + if name, ok := res["action"].(string); ok { + response.Next.Action = name + } + if payload, ok := res["payload"].(map[string]interface{}); ok { + response.Next.Payload = payload + } + } + case string: + response.Output = v + } + + return response, nil +} + // createTimeoutContext creates a timeout context with 5 seconds timeout func (ast *Assistant) createTimeoutContext(c *gin.Context) (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) diff --git a/neo/assistant/types.go b/neo/assistant/types.go index 058e926b3f..80862a01ce 100644 --- a/neo/assistant/types.go +++ b/neo/assistant/types.go @@ -44,6 +44,21 @@ type ResHookStream struct { Output string `json:"output,omitempty"` // The output } +// ResHookDone the response of the done hook +type ResHookDone struct { + Next *NextAction `json:"next,omitempty"` + Input []message.Message `json:"input,omitempty"` + Output string `json:"output,omitempty"` +} + +// ResHookFail the response of the fail hook +type ResHookFail struct { + Next *NextAction `json:"next,omitempty"` + Input []message.Message `json:"input,omitempty"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` +} + // NextAction the next action type NextAction struct { Action string `json:"action"` From 627328b2790eb529b44bc4a5d6e7ab69fbf1faf9 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 13 Jan 2025 18:13:40 +0800 Subject: [PATCH 11/11] Refactor streamChat method in Neo API assistant to improve message completion handling - Commented out the message writing logic in the streamChat method to prevent unintended output when the message is marked as done. - This change enhances the control over the response flow, allowing for better integration with the newly introduced hook methods for managing assistant interactions. These modifications contribute to the overall robustness and maintainability of the Neo API, aligning with recent enhancements in assistant functionalities. --- neo/assistant/api.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neo/assistant/api.go b/neo/assistant/api.go index ae09c75141..eda682bdbf 100644 --- a/neo/assistant/api.go +++ b/neo/assistant/api.go @@ -175,9 +175,9 @@ func (ast *Assistant) streamChat(c *gin.Context, ctx chatctx.Context, messages [ // Complete the stream if msg.IsDone { - if value == "" { - msg.Write(c.Writer) - } + // if value == "" { + // msg.Write(c.Writer) + // } // Call HookDone res, hookErr := ast.HookDone(c, ctx, messages, string(*content))