-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57
- Loading branch information
1 parent
6ab103f
commit e33000b
Showing
9 changed files
with
634 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,317 @@ | ||
package applogic | ||
|
||
import ( | ||
"bufio" | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"log" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/hoshinonyaruko/gensokyo-llm/config" | ||
"github.com/hoshinonyaruko/gensokyo-llm/fmtf" | ||
"github.com/hoshinonyaruko/gensokyo-llm/structs" | ||
"github.com/hoshinonyaruko/gensokyo-llm/utils" | ||
) | ||
|
||
//var mutexErnie sync.Mutex | ||
|
||
func (app *App) ChatHandlerErnieFunction(w http.ResponseWriter, r *http.Request) { | ||
if r.Method != "POST" { | ||
http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed) | ||
return | ||
} | ||
|
||
var msg structs.WXRequestMessageF | ||
err := json.NewDecoder(r.Body).Decode(&msg) | ||
if err != nil { | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
return | ||
} | ||
msg.Role = "user" | ||
//颠倒用户输入 | ||
if config.GetReverseUserPrompt() { | ||
msg.Text = utils.ReverseString(msg.Text) | ||
} | ||
|
||
if msg.ConversationID == "" { | ||
msg.ConversationID = utils.GenerateUUID() | ||
app.createConversation(msg.ConversationID) | ||
} | ||
|
||
//转换一下 | ||
tempmsg := structs.Message{ | ||
ConversationID: msg.ConversationID, | ||
ParentMessageID: msg.ParentMessageID, | ||
Text: msg.Text, | ||
Role: msg.Role, | ||
CreatedAt: msg.CreatedAt, | ||
} | ||
|
||
userMessageID, err := app.addMessage(tempmsg) | ||
if err != nil { | ||
http.Error(w, err.Error(), http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
// 构建请求负载 | ||
var payload structs.WXRequestPayloadF | ||
|
||
// 添加当前用户消息 | ||
payload.Messages = append(payload.Messages, structs.WXMessage{ | ||
Content: msg.Text, | ||
Role: "user", | ||
}) | ||
|
||
TopP := config.GetWenxinTopp() | ||
PenaltyScore := config.GetWnxinPenaltyScore() | ||
MaxOutputTokens := config.GetWenxinMaxOutputTokens() | ||
|
||
// 设置其他可选参数 | ||
payload.TopP = TopP | ||
payload.PenaltyScore = PenaltyScore | ||
payload.MaxOutputTokens = MaxOutputTokens | ||
// 增加function | ||
payload.Functions = append(payload.Functions, msg.WXFunction) | ||
//payload.ResponseFormat = "json_object" | ||
payload.ToolChoice.Type = "function" | ||
payload.ToolChoice.Function.Name = "predict_followup_questions" | ||
|
||
// 是否sse | ||
if config.GetuseSse() { | ||
payload.Stream = true | ||
} | ||
|
||
// 获取系统提示词,并设置system字段,如果它不为空 | ||
systemPromptContent := config.SystemPrompt() // 确保函数名正确 | ||
if systemPromptContent != "0" { | ||
payload.System = systemPromptContent // 直接在请求负载中设置system字段 | ||
} | ||
|
||
// 获取访问凭证和API路径 | ||
accessToken := config.GetWenxinAccessToken() | ||
apiPath := config.GetWenxinApiPath() | ||
|
||
// 构建请求URL | ||
url := fmtf.Sprintf("%s?access_token=%s", apiPath, accessToken) | ||
fmtf.Printf("%v\n", url) | ||
|
||
// 序列化请求负载 | ||
jsonData, err := json.Marshal(payload) | ||
if err != nil { | ||
log.Fatalf("Error occurred during marshaling. Error: %s", err.Error()) | ||
} | ||
|
||
fmtf.Printf("%v\n", string(jsonData)) | ||
|
||
// 创建并发送POST请求 | ||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) | ||
if err != nil { | ||
log.Fatalf("Error occurred during request creation. Error: %s", err.Error()) | ||
} | ||
req.Header.Set("Content-Type", "application/json") | ||
|
||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
log.Fatalf("Error occurred during sending the request. Error: %s", err.Error()) | ||
} | ||
defer resp.Body.Close() | ||
|
||
// 读取响应头中的速率限制信息 | ||
rateLimitRequests := resp.Header.Get("X-Ratelimit-Limit-Requests") | ||
rateLimitTokens := resp.Header.Get("X-Ratelimit-Limit-Tokens") | ||
remainingRequests := resp.Header.Get("X-Ratelimit-Remaining-Requests") | ||
remainingTokens := resp.Header.Get("X-Ratelimit-Remaining-Tokens") | ||
|
||
fmtf.Printf("RateLimit: Requests %s, Tokens %s, Remaining Requests %s, Remaining Tokens %s\n", | ||
rateLimitRequests, rateLimitTokens, remainingRequests, remainingTokens) | ||
|
||
// 检查是否不使用SSE | ||
if !config.GetuseSse() { | ||
// 读取整个响应体到内存中 | ||
bodyBytes, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
log.Fatalf("Error occurred during response body reading. Error: %s", err) | ||
} | ||
|
||
// 首先尝试解析为简单的map来查看响应概览 | ||
var response map[string]interface{} | ||
if err := json.Unmarshal(bodyBytes, &response); err != nil { | ||
log.Fatalf("Error occurred during response decoding to map. Error: %s", err) | ||
} | ||
fmtf.Printf("%v\n", response) | ||
|
||
// 然后尝试解析为具体的结构体以获取详细信息 | ||
var responseStruct struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int `json:"created"` | ||
SentenceID int `json:"sentence_id,omitempty"` | ||
IsEnd bool `json:"is_end,omitempty"` | ||
IsTruncated bool `json:"is_truncated"` | ||
Result string `json:"result"` | ||
NeedClearHistory bool `json:"need_clear_history"` | ||
BanRound int `json:"ban_round"` | ||
Usage struct { | ||
PromptTokens int `json:"prompt_tokens"` | ||
CompletionTokens int `json:"completion_tokens"` | ||
TotalTokens int `json:"total_tokens"` | ||
} `json:"usage"` | ||
} | ||
|
||
if err := json.Unmarshal(bodyBytes, &responseStruct); err != nil { | ||
http.Error(w, fmtf.Sprintf("解析响应体出错: %v", err), http.StatusInternalServerError) | ||
return | ||
} | ||
// 根据API响应构造消息和响应给客户端 | ||
assistantMessageID, err := app.addMessage(structs.Message{ | ||
ConversationID: msg.ConversationID, | ||
ParentMessageID: userMessageID, | ||
Text: responseStruct.Result, | ||
Role: "assistant", | ||
}) | ||
|
||
if err != nil { | ||
http.Error(w, err.Error(), http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
// 构造响应 | ||
responseMap := map[string]interface{}{ | ||
"response": responseStruct.Result, | ||
"conversationId": msg.ConversationID, | ||
"messageId": assistantMessageID, | ||
"details": map[string]interface{}{ | ||
"usage": map[string]int{ | ||
"prompt_tokens": responseStruct.Usage.PromptTokens, | ||
"completion_tokens": responseStruct.Usage.CompletionTokens, | ||
"total_tokens": responseStruct.Usage.TotalTokens, | ||
}, | ||
}, | ||
} | ||
|
||
// 设置响应头信息以反映速率限制状态 | ||
w.Header().Set("Content-Type", "application/json") | ||
w.Header().Set("X-Ratelimit-Limit-Requests", rateLimitRequests) | ||
w.Header().Set("X-Ratelimit-Limit-Tokens", rateLimitTokens) | ||
w.Header().Set("X-Ratelimit-Remaining-Requests", remainingRequests) | ||
w.Header().Set("X-Ratelimit-Remaining-Tokens", remainingTokens) | ||
|
||
// 发送JSON响应 | ||
json.NewEncoder(w).Encode(responseMap) | ||
} else { | ||
// SSE响应模式 | ||
// 设置SSE相关的响应头部 | ||
w.Header().Set("Content-Type", "text/event-stream") | ||
w.Header().Set("Cache-Control", "no-cache") | ||
w.Header().Set("Connection", "keep-alive") | ||
|
||
flusher, ok := w.(http.Flusher) | ||
if !ok { | ||
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
var responseTextBuilder strings.Builder | ||
var totalUsage structs.UsageInfo | ||
|
||
// 假设我们已经建立了与API的连接并且开始接收流式响应 | ||
// reader代表从API接收数据的流 | ||
reader := bufio.NewReader(resp.Body) | ||
for { | ||
// 读取流中的一行,即一个事件数据块 | ||
line, err := reader.ReadString('\n') | ||
if err != nil { | ||
if err == io.EOF { | ||
// 流结束 | ||
break | ||
} | ||
// 处理错误 | ||
fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("读取流数据时发生错误: %v", err)) | ||
flusher.Flush() | ||
continue | ||
} | ||
|
||
// 处理流式数据行 | ||
if strings.HasPrefix(line, "data: ") { | ||
eventDataJSON := line[6:] // 去掉"data: "前缀 | ||
|
||
var eventData struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int `json:"created"` | ||
SentenceID int `json:"sentence_id,omitempty"` | ||
IsEnd bool `json:"is_end,omitempty"` | ||
IsTruncated bool `json:"is_truncated"` | ||
Result string `json:"result"` | ||
NeedClearHistory bool `json:"need_clear_history"` | ||
BanRound int `json:"ban_round"` | ||
Usage struct { | ||
PromptTokens int `json:"prompt_tokens"` | ||
CompletionTokens int `json:"completion_tokens"` | ||
TotalTokens int `json:"total_tokens"` | ||
} `json:"usage"` | ||
} | ||
// 解析JSON数据 | ||
if err := json.Unmarshal([]byte(eventDataJSON), &eventData); err != nil { | ||
fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("解析事件数据出错: %v", err)) | ||
flusher.Flush() | ||
continue | ||
} | ||
|
||
// 这里处理解析后的事件数据 | ||
responseTextBuilder.WriteString(eventData.Result) | ||
totalUsage.PromptTokens += eventData.Usage.PromptTokens | ||
totalUsage.CompletionTokens += eventData.Usage.CompletionTokens | ||
|
||
// 发送当前事件的响应数据,但不包含assistantMessageID | ||
tempResponseMap := map[string]interface{}{ | ||
"response": eventData.Result, | ||
"conversationId": msg.ConversationID, | ||
"details": map[string]interface{}{ | ||
"usage": eventData.Usage, | ||
}, | ||
} | ||
tempResponseJSON, _ := json.Marshal(tempResponseMap) | ||
fmtf.Fprintf(w, "data: %s\n\n", string(tempResponseJSON)) | ||
flusher.Flush() | ||
|
||
// 如果这是最后一个消息 | ||
if eventData.IsEnd { | ||
break | ||
} | ||
} | ||
} | ||
|
||
// 处理完所有事件后,生成并发送包含assistantMessageID的最终响应 | ||
//fmt.Printf("处理完所有事件后,生成并发送包含assistantMessageID的最终响应\n") | ||
responseText := responseTextBuilder.String() | ||
assistantMessageID, err := app.addMessage(structs.Message{ | ||
ConversationID: msg.ConversationID, | ||
ParentMessageID: userMessageID, | ||
Text: responseText, | ||
Role: "assistant", | ||
}) | ||
|
||
if err != nil { | ||
http.Error(w, err.Error(), http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
finalResponseMap := map[string]interface{}{ | ||
"response": responseText, | ||
"conversationId": msg.ConversationID, | ||
"messageId": assistantMessageID, | ||
"details": map[string]interface{}{ | ||
"usage": totalUsage, | ||
}, | ||
} | ||
finalResponseJSON, _ := json.Marshal(finalResponseMap) | ||
fmt.Fprintf(w, "data: %s\n\n", string(finalResponseJSON)) | ||
flusher.Flush() | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.