Skip to content

Commit

Permalink
Merge pull request #857 from trheyi/main
Browse files Browse the repository at this point in the history
Refactor OpenAI message parsing for improved streaming support
  • Loading branch information
trheyi authored Feb 9, 2025
2 parents c575bd9 + 38ee806 commit dafb181
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 73 deletions.
266 changes: 193 additions & 73 deletions neo/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,6 @@ func NewAny(content interface{}) (*Message, error) {
}

// NewOpenAI create a new message from OpenAI response
// @todo:
//
// this function need to be refactored
func NewOpenAI(data []byte, isThinking bool) *Message {

// For Debug
Expand All @@ -203,121 +200,244 @@ func NewOpenAI(data []byte, isThinking bool) *Message {
msg := New()
text := string(data)
data = []byte(strings.TrimPrefix(text, "data: "))
switch {

case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"tool_calls"`) && !strings.Contains(text, `"tool_calls":null`):
var toolCalls openai.ToolCalls
if err := jsoniter.Unmarshal(data, &toolCalls); err != nil {
switch {
case strings.Contains(text, `"object":"chat.completion.chunk"`): // Delta content
var chunk openai.ChatCompletionChunk
err := jsoniter.Unmarshal(data, &chunk)
if err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
}

// Empty content, then it is a pending message
if len(chunk.Choices) == 0 {
msg.Pending = true
return msg
}

msg.Type = "tool_calls_native"
if len(toolCalls.Choices) > 0 && len(toolCalls.Choices[0].Delta.ToolCalls) > 0 {
id := toolCalls.Choices[0].Delta.ToolCalls[0].ID
function := toolCalls.Choices[0].Delta.ToolCalls[0].Function.Name
arguments := toolCalls.Choices[0].Delta.ToolCalls[0].Function.Arguments
// Tool calls
if len(chunk.Choices[0].Delta.ToolCalls) > 0 {
msg.Type = "tool_calls_native"
id := chunk.Choices[0].Delta.ToolCalls[0].ID
function := chunk.Choices[0].Delta.ToolCalls[0].Function.Name
arguments := chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments
text := arguments
if id != "" {
text = fmt.Sprintf(`{"id": "%s", "function": "%s", "arguments": %s`, id, function, arguments)
}
msg.Text = text
}

case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"content":`):
var message openai.MessageWithReasoningContent
if err := jsoniter.Unmarshal(data, &message); err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
msg.Text = text
msg.IsDone = chunk.Choices[0].FinishReason == "tool_calls" // is done when tool calls are finished
fmt.Printf("arguments: %s, finish_reason: %#v\n", arguments, chunk.Choices[0].FinishReason)
return msg
}

msg.Type = "text"
if len(message.Choices) > 0 {
if reasoningContent, ok := message.Choices[0].Delta["reasoning_content"].(string); ok {
msg.Text = reasoningContent
msg.Type = "think"
return msg
}

if content, ok := message.Choices[0].Delta["content"].(string); ok && content != "" {
msg.Text = content
msg.Type = "text"
return msg
}

if isThinking {
msg.Type = "think"
msg.Text = ""
return msg
}

msg.Text = ""
// Text content
if chunk.Choices[0].Delta.Content != "" {
msg.Type = "text"
msg.Text = chunk.Choices[0].Delta.Content
msg.IsDone = chunk.Choices[0].FinishReason == "stop" // is done when the content is finished
return msg
}

case strings.Index(text, `{"code":`) == 0 || strings.Index(text, `"statusCode":`) > 0:
var errorMessage openai.Error
if err := jsoniter.UnmarshalFromString(text, &errorMessage); err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
// Done messages
if chunk.Choices[0].FinishReason == "stop" || chunk.Choices[0].FinishReason == "tool_calls" {
msg.IsDone = true
return msg
}
msg.Type = "error"
msg.Text = errorMessage.Message
msg.IsDone = true
break

case strings.Contains(text, `{"error":{`):
var errorMessage openai.ErrorMessage
if err := jsoniter.Unmarshal(data, &errorMessage); err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
// Reasoning content
if chunk.Choices[0].Delta.ReasoningContent != "" {
msg.Type = "think"
msg.Text = chunk.Choices[0].Delta.ReasoningContent
return msg
}
msg.Type = "error"
msg.Text = errorMessage.Error.Message
msg.IsDone = true
break
// Content is empty and is thinking, then it is a thinking message pending
if isThinking {
msg.Type = "think"
msg.Text = ""
return msg
}

msg.Text = ""
return msg

case strings.Contains(text, `"usage":`) && !strings.Contains(text, `"chat.completion.chunk`):
case strings.Contains(text, `"usage":`): // usage content
msg.IsDone = true
break

case strings.Contains(text, `[DONE]`):
msg.IsDone = true
return msg

case strings.Contains(text, `"finish_reason":"stop"`):
msg.IsDone = true
case len(data) > 2 && data[0] == '{' && data[len(data)-1] == '}': // JSON content (error)

case strings.Contains(text, `"finish_reason":"tool_calls"`):
var error openai.Error
var errorMessage openai.ErrorMessage
if strings.Contains(string(data), `"error":`) {
if err := jsoniter.Unmarshal(data, &errorMessage); err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
return msg
}
error = errorMessage.Error
} else {
err := jsoniter.Unmarshal(data, &error)
if err != nil {
color.Red("JSON parse error: %s", err.Error())
color.White(string(data))
msg.Text = "JSON parse error\n" + string(data)
msg.Type = "error"
msg.IsDone = true
return msg
}
}

message := error.Message
if message == "" {
message = "Unknown error occurred\n" + string(data)
}

msg.Type = "error"
msg.Text = message
msg.IsDone = true
return msg

// Not a data message
case !strings.Contains(text, `data: `):
case !strings.Contains(text, `data: `): // unknown message or uncompleted message
msg.Pending = true
msg.Text = text
return msg

default:
default: // unknown message
str := strings.TrimPrefix(strings.Trim(string(data), "\""), "data: ")
msg.Type = "error"
msg.Text = str
return msg
}

return msg

// switch {

// case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"tool_calls"`) && !strings.Contains(text, `"tool_calls":null`):
// var toolCalls openai.ToolCalls
// if err := jsoniter.Unmarshal(data, &toolCalls); err != nil {
// color.Red("JSON parse error: %s", err.Error())
// color.White(string(data))
// msg.Text = "JSON parse error\n" + string(data)
// msg.Type = "error"
// msg.IsDone = true
// return msg
// }

// msg.Type = "tool_calls_native"
// if len(toolCalls.Choices) > 0 && len(toolCalls.Choices[0].Delta.ToolCalls) > 0 {
// id := toolCalls.Choices[0].Delta.ToolCalls[0].ID
// function := toolCalls.Choices[0].Delta.ToolCalls[0].Function.Name
// arguments := toolCalls.Choices[0].Delta.ToolCalls[0].Function.Arguments
// text := arguments
// if id != "" {
// text = fmt.Sprintf(`{"id": "%s", "function": "%s", "arguments": %s`, id, function, arguments)
// }
// msg.Text = text
// }

// case strings.Contains(text, `"delta":{`) && strings.Contains(text, `"content":`):
// var message openai.MessageWithReasoningContent
// if err := jsoniter.Unmarshal(data, &message); err != nil {
// color.Red("JSON parse error: %s", err.Error())
// color.White(string(data))
// msg.Text = "JSON parse error\n" + string(data)
// msg.Type = "error"
// msg.IsDone = true
// return msg
// }

// msg.Type = "text"
// if len(message.Choices) > 0 {
// if reasoningContent, ok := message.Choices[0].Delta["reasoning_content"].(string); ok {
// msg.Text = reasoningContent
// msg.Type = "think"
// return msg
// }

// if content, ok := message.Choices[0].Delta["content"].(string); ok && content != "" {
// msg.Text = content
// msg.Type = "text"
// return msg
// }

// if isThinking {
// msg.Type = "think"
// msg.Text = ""
// return msg
// }

// msg.Text = ""
// return msg
// }

// case strings.Index(text, `{"code":`) == 0 || strings.Index(text, `"statusCode":`) > 0:
// var errorMessage openai.Error
// if err := jsoniter.UnmarshalFromString(text, &errorMessage); err != nil {
// color.Red("JSON parse error: %s", err.Error())
// color.White(string(data))
// msg.Text = "JSON parse error\n" + string(data)
// msg.Type = "error"
// msg.IsDone = true
// return msg
// }
// msg.Type = "error"
// msg.Text = errorMessage.Message
// msg.IsDone = true
// break

// case strings.Contains(text, `{"error":{`):
// var errorMessage openai.ErrorMessage
// if err := jsoniter.Unmarshal(data, &errorMessage); err != nil {
// color.Red("JSON parse error: %s", err.Error())
// color.White(string(data))
// msg.Text = "JSON parse error\n" + string(data)
// msg.Type = "error"
// msg.IsDone = true
// return msg
// }
// msg.Type = "error"
// msg.Text = errorMessage.Error.Message
// msg.IsDone = true
// break

// case strings.Contains(text, `"usage":`) && !strings.Contains(text, `"chat.completion.chunk`):
// msg.IsDone = true
// break

// case strings.Contains(text, `[DONE]`):
// msg.IsDone = true

// case strings.Contains(text, `"finish_reason":"stop"`):
// msg.IsDone = true

// case strings.Contains(text, `"finish_reason":"tool_calls"`):
// msg.IsDone = true

// // Not a data message
// case !strings.Contains(text, `data: `):
// msg.Pending = true
// msg.Text = text

// default:
// str := strings.TrimPrefix(strings.Trim(string(data), "\""), "data: ")
// msg.Type = "error"
// msg.Text = str
// }

}

// String returns the string representation
Expand Down
67 changes: 67 additions & 0 deletions openai/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,73 @@ type MessageWithReasoningContent struct {
} `json:"choices,omitempty"`
}

// ChatCompletionChunk is one server-sent event payload of an OpenAI
// streaming chat completion (`"object": "chat.completion.chunk"`).
// Each streamed line decodes into one of these; the full answer is the
// concatenation of the per-chunk deltas.
type ChatCompletionChunk struct {
	ID                string                      `json:"id"`                           // completion identifier, shared by every chunk of one response
	Object            string                      `json:"object"`                       // "chat.completion.chunk" for streamed responses
	Created           int64                       `json:"created"`                      // creation time as a Unix timestamp (seconds)
	Model             string                      `json:"model"`                        // model that produced the completion
	SystemFingerprint string                      `json:"system_fingerprint,omitempty"` // backend configuration fingerprint, when the provider sends one
	Choices           []ChatCompletionChunkChoice `json:"choices"`                      // may be empty (e.g. keep-alive chunks); callers must check len before indexing
}

// ChatCompletionChunkChoice represents a single choice within a streaming
// chunk. With the default n=1 request there is exactly one choice per chunk.
type ChatCompletionChunkChoice struct {
	Index        int                      `json:"index"`                   // position of this choice among the n requested completions
	Delta        ChatCompletionChunkDelta `json:"delta"`                   // the incremental content carried by this chunk
	LogProbs     *LogProbs                `json:"logprobs,omitempty"`      // token log probabilities, present only when requested
	FinishReason string                   `json:"finish_reason,omitempty"` // "" until the final chunk (JSON null decodes to ""); then e.g. "stop" or "tool_calls"
}

// ChatCompletionChunkDelta represents the incremental payload of one chunk.
// At most one of the content-bearing fields is typically set per chunk.
type ChatCompletionChunkDelta struct {
	Role             string        `json:"role,omitempty"`              // role of the message, usually only present on the first chunk
	Content          string        `json:"content,omitempty"`           // incremental fragment of the visible answer text
	ReasoningContent string        `json:"reasoning_content,omitempty"` // incremental "thinking" text — NOTE(review): non-standard vendor extension field; confirm which providers emit it
	ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`        // incremental tool-call fragments (id/name first, then argument pieces)
	FunctionCall     *FunctionCall `json:"function_call,omitempty"`     // legacy single function-call field, superseded by tool_calls
}

// LogProbs represents the log-probability information attached to a choice
// when the request asked for logprobs.
type LogProbs struct {
	Content []ContentLogProb `json:"content,omitempty"` // one entry per generated token
}

// ContentLogProb represents one generated token together with its log
// probability and, when requested, the most likely alternative tokens.
type ContentLogProb struct {
	Token       string    `json:"token"`                  // the token text
	LogProb     float64   `json:"logprob"`                // natural-log probability of this token
	Bytes       []int     `json:"bytes,omitempty"`        // UTF-8 byte values of the token, when provided
	TopLogProbs []LogProb `json:"top_logprobs,omitempty"` // top alternative tokens at this position
}

// LogProb represents a candidate token and its log probability, used for
// the top-alternatives list of a generated position.
type LogProb struct {
	Token   string  `json:"token"`           // the candidate token text
	LogProb float64 `json:"logprob"`         // natural-log probability of the candidate
	Bytes   []int   `json:"bytes,omitempty"` // UTF-8 byte values of the token, when provided
}

// ToolCall represents one (possibly partial) tool invocation emitted in a
// streaming delta. In a stream the first fragment carries ID/Type/Function.Name
// and later fragments append to Function.Arguments.
type ToolCall struct {
	Index    int      `json:"index"`    // position of this tool call when several run in parallel
	ID       string   `json:"id"`       // call identifier; empty on continuation fragments
	Type     string   `json:"type"`     // tool kind, e.g. "function"
	Function Function `json:"function"` // the function name and (incremental) JSON arguments
}

// FunctionCall represents the legacy (pre-tool_calls) function-call payload
// of a delta: the function name and its JSON-encoded arguments.
type FunctionCall struct {
	Name      string `json:"name"`      // function to invoke
	Arguments string `json:"arguments"` // JSON-encoded argument object (may be a partial fragment when streamed)
}

// Function represents the function part of a ToolCall: its name and the
// JSON-encoded arguments, which arrive as fragments during streaming.
type Function struct {
	Name      string `json:"name"`      // function to invoke; empty on continuation fragments
	Arguments string `json:"arguments"` // JSON-encoded argument object, streamed incrementally
}

// ToolCalls is the response from OpenAI
type ToolCalls struct {
ID string `json:"id,omitempty"`
Expand Down

0 comments on commit dafb181

Please sign in to comment.