diff --git a/applogic/ernie_function.go b/applogic/ernie_function.go new file mode 100644 index 0000000..6d3d967 --- /dev/null +++ b/applogic/ernie_function.go @@ -0,0 +1,317 @@ +package applogic + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + + "github.com/hoshinonyaruko/gensokyo-llm/config" + "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/structs" + "github.com/hoshinonyaruko/gensokyo-llm/utils" +) + +//var mutexErnie sync.Mutex + +func (app *App) ChatHandlerErnieFunction(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + var msg structs.WXRequestMessageF + err := json.NewDecoder(r.Body).Decode(&msg) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + msg.Role = "user" + //颠倒用户输入 + if config.GetReverseUserPrompt() { + msg.Text = utils.ReverseString(msg.Text) + } + + if msg.ConversationID == "" { + msg.ConversationID = utils.GenerateUUID() + app.createConversation(msg.ConversationID) + } + + //转换一下 + tempmsg := structs.Message{ + ConversationID: msg.ConversationID, + ParentMessageID: msg.ParentMessageID, + Text: msg.Text, + Role: msg.Role, + CreatedAt: msg.CreatedAt, + } + + userMessageID, err := app.addMessage(tempmsg) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 构建请求负载 + var payload structs.WXRequestPayloadF + + // 添加当前用户消息 + payload.Messages = append(payload.Messages, structs.WXMessage{ + Content: msg.Text, + Role: "user", + }) + + TopP := config.GetWenxinTopp() + PenaltyScore := config.GetWnxinPenaltyScore() + MaxOutputTokens := config.GetWenxinMaxOutputTokens() + + // 设置其他可选参数 + payload.TopP = TopP + payload.PenaltyScore = PenaltyScore + payload.MaxOutputTokens = MaxOutputTokens + // 增加function + payload.Functions = append(payload.Functions, msg.WXFunction) + 
//payload.ResponseFormat = "json_object" + payload.ToolChoice.Type = "function" + payload.ToolChoice.Function.Name = "predict_followup_questions" + + // 是否sse + if config.GetuseSse() { + payload.Stream = true + } + + // 获取系统提示词,并设置system字段,如果它不为空 + systemPromptContent := config.SystemPrompt() // 确保函数名正确 + if systemPromptContent != "0" { + payload.System = systemPromptContent // 直接在请求负载中设置system字段 + } + + // 获取访问凭证和API路径 + accessToken := config.GetWenxinAccessToken() + apiPath := config.GetWenxinApiPath() + + // 构建请求URL + url := fmtf.Sprintf("%s?access_token=%s", apiPath, accessToken) + fmtf.Printf("%v\n", url) + + // 序列化请求负载 + jsonData, err := json.Marshal(payload) + if err != nil { + log.Fatalf("Error occurred during marshaling. Error: %s", err.Error()) + } + + fmtf.Printf("%v\n", string(jsonData)) + + // 创建并发送POST请求 + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + log.Fatalf("Error occurred during request creation. Error: %s", err.Error()) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + log.Fatalf("Error occurred during sending the request. Error: %s", err.Error()) + } + defer resp.Body.Close() + + // 读取响应头中的速率限制信息 + rateLimitRequests := resp.Header.Get("X-Ratelimit-Limit-Requests") + rateLimitTokens := resp.Header.Get("X-Ratelimit-Limit-Tokens") + remainingRequests := resp.Header.Get("X-Ratelimit-Remaining-Requests") + remainingTokens := resp.Header.Get("X-Ratelimit-Remaining-Tokens") + + fmtf.Printf("RateLimit: Requests %s, Tokens %s, Remaining Requests %s, Remaining Tokens %s\n", + rateLimitRequests, rateLimitTokens, remainingRequests, remainingTokens) + + // 检查是否不使用SSE + if !config.GetuseSse() { + // 读取整个响应体到内存中 + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("Error occurred during response body reading. 
Error: %s", err) + } + + // 首先尝试解析为简单的map来查看响应概览 + var response map[string]interface{} + if err := json.Unmarshal(bodyBytes, &response); err != nil { + log.Fatalf("Error occurred during response decoding to map. Error: %s", err) + } + fmtf.Printf("%v\n", response) + + // 然后尝试解析为具体的结构体以获取详细信息 + var responseStruct struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + SentenceID int `json:"sentence_id,omitempty"` + IsEnd bool `json:"is_end,omitempty"` + IsTruncated bool `json:"is_truncated"` + Result string `json:"result"` + NeedClearHistory bool `json:"need_clear_history"` + BanRound int `json:"ban_round"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + + if err := json.Unmarshal(bodyBytes, &responseStruct); err != nil { + http.Error(w, fmtf.Sprintf("解析响应体出错: %v", err), http.StatusInternalServerError) + return + } + // 根据API响应构造消息和响应给客户端 + assistantMessageID, err := app.addMessage(structs.Message{ + ConversationID: msg.ConversationID, + ParentMessageID: userMessageID, + Text: responseStruct.Result, + Role: "assistant", + }) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 构造响应 + responseMap := map[string]interface{}{ + "response": responseStruct.Result, + "conversationId": msg.ConversationID, + "messageId": assistantMessageID, + "details": map[string]interface{}{ + "usage": map[string]int{ + "prompt_tokens": responseStruct.Usage.PromptTokens, + "completion_tokens": responseStruct.Usage.CompletionTokens, + "total_tokens": responseStruct.Usage.TotalTokens, + }, + }, + } + + // 设置响应头信息以反映速率限制状态 + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Ratelimit-Limit-Requests", rateLimitRequests) + w.Header().Set("X-Ratelimit-Limit-Tokens", rateLimitTokens) + w.Header().Set("X-Ratelimit-Remaining-Requests", remainingRequests) + 
w.Header().Set("X-Ratelimit-Remaining-Tokens", remainingTokens) + + // 发送JSON响应 + json.NewEncoder(w).Encode(responseMap) + } else { + // SSE响应模式 + // 设置SSE相关的响应头部 + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + var responseTextBuilder strings.Builder + var totalUsage structs.UsageInfo + + // 假设我们已经建立了与API的连接并且开始接收流式响应 + // reader代表从API接收数据的流 + reader := bufio.NewReader(resp.Body) + for { + // 读取流中的一行,即一个事件数据块 + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + // 流结束 + break + } + // 处理错误 + fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("读取流数据时发生错误: %v", err)) + flusher.Flush() + continue + } + + // 处理流式数据行 + if strings.HasPrefix(line, "data: ") { + eventDataJSON := line[6:] // 去掉"data: "前缀 + + var eventData struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + SentenceID int `json:"sentence_id,omitempty"` + IsEnd bool `json:"is_end,omitempty"` + IsTruncated bool `json:"is_truncated"` + Result string `json:"result"` + NeedClearHistory bool `json:"need_clear_history"` + BanRound int `json:"ban_round"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + // 解析JSON数据 + if err := json.Unmarshal([]byte(eventDataJSON), &eventData); err != nil { + fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("解析事件数据出错: %v", err)) + flusher.Flush() + continue + } + + // 这里处理解析后的事件数据 + responseTextBuilder.WriteString(eventData.Result) + totalUsage.PromptTokens += eventData.Usage.PromptTokens + totalUsage.CompletionTokens += eventData.Usage.CompletionTokens + + // 发送当前事件的响应数据,但不包含assistantMessageID + tempResponseMap := map[string]interface{}{ + "response": eventData.Result, 
+ "conversationId": msg.ConversationID, + "details": map[string]interface{}{ + "usage": eventData.Usage, + }, + } + tempResponseJSON, _ := json.Marshal(tempResponseMap) + fmtf.Fprintf(w, "data: %s\n\n", string(tempResponseJSON)) + flusher.Flush() + + // 如果这是最后一个消息 + if eventData.IsEnd { + break + } + } + } + + // 处理完所有事件后,生成并发送包含assistantMessageID的最终响应 + //fmt.Printf("处理完所有事件后,生成并发送包含assistantMessageID的最终响应\n") + responseText := responseTextBuilder.String() + assistantMessageID, err := app.addMessage(structs.Message{ + ConversationID: msg.ConversationID, + ParentMessageID: userMessageID, + Text: responseText, + Role: "assistant", + }) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + finalResponseMap := map[string]interface{}{ + "response": responseText, + "conversationId": msg.ConversationID, + "messageId": assistantMessageID, + "details": map[string]interface{}{ + "usage": totalUsage, + }, + } + finalResponseJSON, _ := json.Marshal(finalResponseMap) + fmt.Fprintf(w, "data: %s\n\n", string(finalResponseJSON)) + flusher.Flush() + } + +} diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index a901c3d..970b9d8 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -487,8 +487,16 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } if message.RealMessageType == "group_private" || message.MessageType == "private" { if config.GetUsePrivateSSE() { + //发气泡和按钮 - promptkeyboard := config.GetPromptkeyboard() + var promptkeyboard []string + if !config.GetUseAIPromptkeyboard() { + promptkeyboard = config.GetPromptkeyboard() + } else { + fmtf.Printf("ai生成气泡:%v", "Q"+newmsg+"A"+response) + promptkeyboard = GetPromptKeyboardAI("Q" + newmsg + "A" + response) + } + //最后一条了 messageSSE := structs.InterfaceBody{ Content: " " + "\n", diff --git a/applogic/promptkeyboard.go b/applogic/promptkeyboard.go new file mode 100644 index 0000000..c8b6180 --- /dev/null +++ b/applogic/promptkeyboard.go @@ 
-0,0 +1,67 @@ +package applogic + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/hoshinonyaruko/gensokyo-llm/config" +) + +// ResponseDataPromptKeyboard 用于解析外层响应 +type ResponseDataPromptKeyboard struct { + ConversationID string `json:"conversationId"` + MessageID string `json:"messageId"` + Response string `json:"response"` // 这里是嵌套的JSON字符串 +} + +// 你要扮演一个json生成器,根据我下一句提交的QA内容,推断我可能会继续问的问题,生成json数组格式的结果,如:输入Q我好累啊A要休息一下吗,返回["嗯,我想要休息","我想喝杯咖啡","你平时怎么休息呢"],返回需要是["","",""]需要2-3个结果 +func GetPromptKeyboardAI(msg string) []string { + url := config.GetAIPromptkeyboardPath() + requestBody, err := json.Marshal(map[string]interface{}{ + "message": msg, + "conversationId": "", + "parentMessageId": "", + "user_id": "", + }) + if err != nil { + fmt.Printf("Error marshalling request: %v\n", err) + return config.GetPromptkeyboard() + } + + resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + fmt.Printf("Error sending request: %v\n", err) + return config.GetPromptkeyboard() + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Error reading response body: %v\n", err) + return config.GetPromptkeyboard() + } + fmt.Printf("Response: %s\n", string(responseBody)) + + var responseData ResponseDataPromptKeyboard + if err := json.Unmarshal(responseBody, &responseData); err != nil { + fmt.Printf("Error unmarshalling response data: %v\n", err) + return config.GetPromptkeyboard() + } + + var keyboardPrompts []string + // 预处理响应数据,移除可能的换行符 + preprocessedResponse := strings.TrimSpace(responseData.Response) + + // 尝试直接解析JSON + err = json.Unmarshal([]byte(preprocessedResponse), &keyboardPrompts) + if err != nil { + fmt.Printf("Error unmarshalling nested response: %v\n", err) + return config.GetPromptkeyboard() + } + + return keyboardPrompts +} diff --git a/config/config.go b/config/config.go index 1cf47ed..9938065 100644 --- a/config/config.go +++ 
b/config/config.go @@ -88,6 +88,11 @@ type Settings struct { BlacklistResponseMessages []string `yaml:"blacklistResponseMessages"` NoContext bool `yaml:"noContext"` WithdrawCommand []string `yaml:"withdrawCommand"` + FunctionMode bool `yaml:"functionMode"` + FunctionPath string `yaml:"functionPath"` + UseFunctionPromptkeyboard bool `yaml:"useFunctionPromptkeyboard"` + AIPromptkeyboardPath string `yaml:"AIPromptkeyboardPath"` + UseAIPromptkeyboard bool `yaml:"useAIPromptkeyboard"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -585,7 +590,6 @@ func GetUsePrivateSSE() bool { } // GetPromptkeyboard 获取Promptkeyboard,如果超过3个成员则随机选择3个 - func GetPromptkeyboard() []string { mu.Lock() defer mu.Unlock() @@ -889,3 +893,53 @@ func GetWithdrawCommand() []string { } return nil } + +// 获取FunctionMode +func GetFunctionMode() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.FunctionMode + } + return false +} + +// 获取FunctionPath +func GetFunctionPath() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.FunctionPath + } + return "" +} + +// 获取UseFunctionPromptkeyboard +func GetUseFunctionPromptkeyboard() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.UseFunctionPromptkeyboard + } + return false +} + +// 获取UseAIPromptkeyboard +func GetUseAIPromptkeyboard() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.UseAIPromptkeyboard + } + return false +} + +// 获取AIPromptkeyboardPath +func GetAIPromptkeyboardPath() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.AIPromptkeyboardPath + } + return "" +} diff --git a/function/function.go b/function/function.go new file mode 100644 index 0000000..780ab05 --- /dev/null +++ b/function/function.go @@ -0,0 +1,70 @@ +package function + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/hoshinonyaruko/gensokyo-llm/config" + 
"github.com/hoshinonyaruko/gensokyo-llm/structs" +) + +// GetPromptkeyboard 请求并打印3个预测的问题 +func GetPromptkeyboard(msg string) bool { + url := config.GetFunctionPath() + wxFunction := structs.WXFunction{ + Name: "predict_followup_questions", + Description: "根据用户输入,预测用户可能接下来提出的三个相关问题", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "question": map[string]interface{}{ + "type": "string", + "description": "用户提出的初始问题", + }, + }, + "required": []string{"question"}, + }, + Responses: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "followup_questions": map[string]interface{}{ + "type": "array", + "items": map[string]string{"type": "string"}, + "description": "预测的后续问题列表", + }, + }, + }, + } + + request := structs.WXRequestMessageF{ + Text: msg, + WXFunction: wxFunction, + } + + requestBody, err := json.Marshal(request) + if err != nil { + fmt.Printf("Error marshalling request: %v\n", err) + return false + } + + resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + fmt.Printf("Error sending request: %v\n", err) + return false + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Error reading response body: %v\n", err) + return false + } + fmt.Printf("Response: %s\n", string(responseBody)) + + // 这里可以添加逻辑以解析和处理响应数据 + + return true // 根据实际情况可能需要调整返回值 +} diff --git a/main.go b/main.go index 9e32a4d..be9a705 100644 --- a/main.go +++ b/main.go @@ -134,7 +134,13 @@ func main() { http.HandleFunc("/conversation", app.ChatHandlerHunyuan) case 1: // 如果API类型是1,使用app.chatHandlerErnie - http.HandleFunc("/conversation", app.ChatHandlerErnie) + // 如果开启function模式 切换function端点 + if !config.GetFunctionMode() { + http.HandleFunc("/conversation", app.ChatHandlerErnie) + } else { + http.HandleFunc("/conversation", app.ChatHandlerErnieFunction) + } + case 2: // 如果API类型是2,使用app.chatHandlerChatGpt 
// WXRequestMessage is the JSON body a client posts to the /conversation
// endpoint for the Wenxin backend.
type WXRequestMessage struct {
	ConversationID  string `json:"conversationId"`
	ParentMessageID string `json:"parentMessageId"`
	Text            string `json:"message"`
	Role            string `json:"role"`
	CreatedAt       string `json:"created_at"`
}

// WXRequestMessageF is WXRequestMessage plus a function definition, used by
// the function-mode endpoint.
//
// NOTE(review): encoding/json never treats a struct value as empty, so
// omitempty on WXFunction has no effect; the "functions" key is always
// emitted.
type WXRequestMessageF struct {
	ConversationID  string     `json:"conversationId"`
	ParentMessageID string     `json:"parentMessageId"`
	Text            string     `json:"message"`
	Role            string     `json:"role"`
	CreatedAt       string     `json:"created_at"`
	WXFunction      WXFunction `json:"functions,omitempty"`
}

// WXMessageF is a chat message that can additionally carry a function call.
//
// NOTE(review): FunctionCall is a struct value, so omitempty does not
// suppress the "function_call" key.
type WXMessageF struct {
	Content      string         `json:"content"`
	Role         string         `json:"role"`
	FunctionCall WXFunctionCall `json:"function_call,omitempty"`
}

// WXRequestPayloadF is the request payload sent to the Wenxin API in function
// mode.
//
// NOTE(review): ToolChoice is a struct value, so omitempty does not suppress
// the "tool_choice" key.
type WXRequestPayloadF struct {
	Messages        []WXMessage  `json:"messages"`
	Functions       []WXFunction `json:"functions,omitempty"`
	Stream          bool         `json:"stream,omitempty"`
	Temperature     float64      `json:"temperature,omitempty"`
	TopP            float64      `json:"top_p,omitempty"`
	PenaltyScore    float64      `json:"penalty_score,omitempty"`
	System          string       `json:"system,omitempty"`
	Stop            []string     `json:"stop,omitempty"`
	MaxOutputTokens int          `json:"max_output_tokens,omitempty"`
	UserID          string       `json:"user_id,omitempty"`
	ResponseFormat  string       `json:"response_format,omitempty"`
	ToolChoice      ToolChoice   `json:"tool_choice,omitempty"`
}

// Function identifies, by name, the function selected in a ToolChoice.
type Function struct {
	Name string `json:"name"` // function name
}

// ToolChoice describes which tool, and which specific function, to use.
type ToolChoice struct {
	Type     string   `json:"type"`     // tool type; fixed to "function" here
	Function Function `json:"function"` // the function to use
}

// WXFunction describes the structure of a callable function.
type WXFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
	Responses   map[string]interface{} `json:"responses,omitempty"`
	Examples    [][]WXExample          `json:"examples,omitempty"`
}

// WXExample describes one example message of a function call.
type WXExample struct {
	Role         string          `json:"role"`
	Content      string          `json:"content"`
	Name         string          `json:"name,omitempty"`
	FunctionCall *WXFunctionCall `json:"function_call,omitempty"`
}

// WXFunctionCall describes a single function call.
type WXFunctionCall struct {
	Name      string                 `json:"name,omitempty"`
	Arguments map[string]interface{} `json:"arguments,omitempty"`
	Thought   string                 `json:"thought,omitempty"`
}
noContext : false #不开启上下文 withdrawCommand : ["撤回"] #撤回指令 + + functionMode : false #是否指定本agent使用func模式(目前仅支持千帆平台),效果不好,暂时不用. + functionPath : "" #调用另一个启用了func模式的gsk-llm联合工作的/conversation地址,效果不好,暂时不用. + useFunctionPromptkeyboard : false #使用func生成气泡,效果不好,暂时不用. + + AIPromptkeyboardPath : "" #调用另一个设置系统提示词的gsk-llm联合工作的/conversation地址,约定系统提示词需返回文本json数组(3个). + useAIPromptkeyboard : false #使用ai生成气泡. + #systemPrompt: [ + # "你要扮演一个json生成器,根据我下一句提交的QA内容,推断我可能会继续问的问题,生成json数组格式的结果,如:输入Q我好累啊A要休息一下吗,返回[\"嗯,我想要休息\",\"我想喝杯咖啡\",\"你平时怎么休息呢\"],返回需要是[\"\",\"\",\"\"]需要2-3个结果" + #] + #语言过滤 allowedLanguages : ["cmn"] #根据自身安全实力,酌情过滤,cmn代表中文,小写字母,[]空数组代表不限制. langResponseMessages : ["抱歉,我不会**这个语言呢","我不会**这门语言,请使用中文和我对话吧"] #定型文,**会自动替换为检测到的语言