Skip to content

Commit

Permalink
feat: support gemini-vision-pro
Browse files Browse the repository at this point in the history
  • Loading branch information
songquanpeng committed Dec 24, 2023
1 parent f3c07e1 commit 1c89221
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 10 deletions.
35 changes: 35 additions & 0 deletions common/image/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,22 @@ import (
_ "golang.org/x/image/webp"
)

func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}

func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
Expand All @@ -28,6 +43,26 @@ func GetImageSizeFromUrl(url string) (width int, height int, err error) {
return img.Width, img.Height, nil
}

func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}

var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
Expand Down
1 change: 1 addition & 0 deletions common/model-ratio.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ var ModelRatio = map[string]float64{
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
Expand Down
9 changes: 9 additions & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,15 @@ func init() {
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Expand Down
31 changes: 31 additions & 0 deletions controller/relay-gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/common/image"
"strings"

"github.com/gin-gonic/gin"
)

// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn

const (
GeminiVisionMaxImageNum = 16
)

type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
Expand Down Expand Up @@ -97,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts

// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
Expand Down
10 changes: 2 additions & 8 deletions controller/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
Expand All @@ -197,9 +194,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
Expand Down Expand Up @@ -396,9 +390,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
req.Header.Set("x-goog-api-key", apiKey)
case APITypeGemini:
// do not set Authorization header
req.Header.Set("x-goog-api-key", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
Expand Down
59 changes: 58 additions & 1 deletion controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ type ImageContent struct {
ImageURL *ImageURL `json:"image_url,omitempty"`
}

const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)

type OpenAIMessageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}

func (m Message) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}

func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
Expand All @@ -44,7 +60,7 @@ func (m Message) StringContent() string {
if !ok {
continue
}
if contentMap["type"] == "text" {
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
Expand All @@ -55,6 +71,47 @@ func (m Message) StringContent() string {
return ""
}

func (m Message) ParseContent() []OpenAIMessageContent {
var contentList []OpenAIMessageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeImageURL,
ImageURL: &ImageURL{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}

const (
RelayModeUnknown = iota
RelayModeChatCompletions
Expand Down
2 changes: 2 additions & 0 deletions middleware/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"runtime/debug"
)

func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
Expand Down
2 changes: 1 addition & 1 deletion web/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ const EditChannel = () => {
localModels = ['hunyuan'];
break;
case 24:
localModels = ['gemini-pro'];
localModels = ['gemini-pro', 'gemini-pro-vision'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
Expand Down

0 comments on commit 1c89221

Please sign in to comment.