Skip to content

Commit

Permalink
Add Support for Embeddings Endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ekatiyar committed Aug 6, 2024
1 parent 9d1fbf6 commit acc7125
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 0 deletions.
68 changes: 68 additions & 0 deletions api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ func ModelListHandler(c *gin.Context) {
Object: "model",
OwnedBy: "openai",
},
openai.Model{
CreatedAt: 1686935002,
ID: openai.GPT3Ada002,
Object: "model",
OwnedBy: "openai",
},
},
})
}
Expand Down Expand Up @@ -154,3 +160,65 @@ func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

// EmbeddingProxyHandler translates an OpenAI-style /v1/embeddings request into
// a Gemini embedding call and writes back an OpenAI-compatible response.
//
// The caller's key from the "Authorization: Bearer ..." header is forwarded to
// Gemini as the genai API key.
func EmbeddingProxyHandler(c *gin.Context) {
	// Extract the Bearer token. A missing or malformed header is an
	// authentication failure, so answer 401 rather than a generic 400.
	authorizationHeader := c.GetHeader("Authorization")
	var openaiAPIKey string
	if _, err := fmt.Sscanf(authorizationHeader, "Bearer %s", &openaiAPIKey); err != nil {
		c.JSON(http.StatusUnauthorized, openai.APIError{
			Code:    http.StatusUnauthorized,
			Message: err.Error(),
		})
		return
	}

	// Decode and validate the JSON request body.
	req := &adapter.EmbeddingRequest{}
	if err := c.ShouldBindJSON(req); err != nil {
		c.JSON(http.StatusBadRequest, openai.APIError{
			Code:    http.StatusBadRequest,
			Message: err.Error(),
		})
		return
	}

	// Reject unsupported models and convert the input strings to Gemini content.
	messages, err := req.ToGenaiMessages()
	if err != nil {
		c.JSON(http.StatusBadRequest, openai.APIError{
			Code:    http.StatusBadRequest,
			Message: err.Error(),
		})
		return
	}

	// Failing to construct the upstream client is a server-side problem, not a
	// client error: report 500 instead of 400.
	ctx := c.Request.Context()
	client, err := genai.NewClient(ctx, option.WithAPIKey(openaiAPIKey))
	if err != nil {
		log.Printf("new genai client error %v\n", err)
		c.JSON(http.StatusInternalServerError, openai.APIError{
			Code:    http.StatusInternalServerError,
			Message: err.Error(),
		})
		return
	}
	defer client.Close()

	model := req.ToGenaiModel()
	gemini := adapter.NewGeminiAdapter(client, model)

	// Upstream generation failures are likewise server-side errors.
	resp, err := gemini.GenerateEmbedding(ctx, messages)
	if err != nil {
		log.Printf("genai generate content error %v\n", err)
		c.JSON(http.StatusInternalServerError, openai.APIError{
			Code:    http.StatusInternalServerError,
			Message: err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, resp)
}
3 changes: 3 additions & 0 deletions api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ func Register(router *gin.Engine) {

// openai chat
router.POST("/v1/chat/completions", ChatProxyHandler)

// openai embeddings
router.POST("/v1/embeddings", EmbeddingProxyHandler)
}
35 changes: 35 additions & 0 deletions pkg/adapter/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
Gemini1Pro = "gemini-1.0-pro-latest"
Gemini1Dot5Pro = "gemini-1.5-pro-latest"
Gemini1Dot5Flash = "gemini-1.5-flash-latest"
TextEmbedding004 = "text-embedding-004"

genaiRoleUser = "user"
genaiRoleModel = "model"
Expand Down Expand Up @@ -239,3 +240,37 @@ func setGenaiModelByOpenaiRequest(model *genai.GenerativeModel, req *ChatComplet
},
}
}

// GenerateEmbedding sends all given contents to the Gemini embedding model in
// a single batch request and converts the result into an OpenAI-style
// EmbeddingResponse, preserving input order.
func (g *GeminiAdapter) GenerateEmbedding(
	ctx context.Context,
	messages []*genai.Content,
) (*openai.EmbeddingResponse, error) {
	embedModel := g.client.EmbeddingModel(g.model)

	// Fold every message into one batch so a single API call covers all inputs.
	batch := embedModel.NewBatch()
	for _, msg := range messages {
		batch = batch.AddContent(msg.Parts...)
	}

	res, err := embedModel.BatchEmbedContents(ctx, batch)
	if err != nil {
		return nil, errors.Wrap(err, "genai generate embeddings error")
	}

	// Translate each Gemini embedding into the OpenAI wire format; the OpenAI
	// schema carries an explicit per-item index.
	data := make([]openai.Embedding, len(res.Embeddings))
	for i, e := range res.Embeddings {
		data[i] = openai.Embedding{
			Object:    "embedding",
			Embedding: e.Values,
			Index:     i,
		}
	}

	return &openai.EmbeddingResponse{
		Object: "list",
		Data:   data,
		Model:  openai.EmbeddingModel(g.model),
	}, nil
}
53 changes: 53 additions & 0 deletions pkg/adapter/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func (req *ChatCompletionRequest) ToGenaiModel() string {
func (req *ChatCompletionRequest) ToGenaiMessages() ([]*genai.Content, error) {
if req.Model == openai.GPT4VisionPreview {
return req.toVisionGenaiContent()
} else if req.Model == openai.GPT3Ada002 {
return nil, errors.New("Chat Completion is not supported for embedding model")
}

return req.toStringGenaiContent()
Expand Down Expand Up @@ -176,3 +178,54 @@ type CompletionResponse struct {
Model string `json:"model"`
Choices []CompletionChoice `json:"choices"`
}

// StringArray is a []string that also accepts a single JSON string when
// unmarshaling, mirroring the OpenAI embeddings API's "input" field, which may
// be either a string or an array of strings.
type StringArray []string

// UnmarshalJSON implements the json.Unmarshaler interface for StringArray.
// A JSON array decodes directly; a lone JSON string is normalized to a
// one-element slice.
func (s *StringArray) UnmarshalJSON(data []byte) error {
	// Guard against empty input: indexing data[0] without this check would
	// panic if UnmarshalJSON is ever invoked with a zero-length slice.
	if len(data) == 0 {
		return errors.New("unexpected end of JSON input")
	}

	// Check if the data is a JSON array
	if data[0] == '[' {
		var arr []string
		if err := json.Unmarshal(data, &arr); err != nil {
			return err
		}
		*s = arr
		return nil
	}

	// Otherwise expect a single JSON string
	var str string
	if err := json.Unmarshal(data, &str); err != nil {
		return err
	}
	*s = StringArray{str} // Wrap the string in a slice
	return nil
}

// EmbeddingRequest represents a request structure for embeddings API.
// It mirrors the OpenAI /v1/embeddings payload; "input" may arrive as either
// a single JSON string or an array of strings (see StringArray).
type EmbeddingRequest struct {
	// Model is the requested model name; validation rejects requests without it.
	Model string `json:"model" binding:"required"`
	// Messages holds the text(s) to embed; at least one entry is required.
	Messages StringArray `json:"input" binding:"required,min=1"`
}

// ToGenaiMessages converts the embedding request's input strings into Gemini
// content objects, one per input string. Any model other than the Ada-002
// alias is rejected.
func (req *EmbeddingRequest) ToGenaiMessages() ([]*genai.Content, error) {
	if req.Model != openai.GPT3Ada002 {
		return nil, errors.New("Embedding is not supported for chat model " + req.Model)
	}

	contents := make([]*genai.Content, 0, len(req.Messages))
	for _, text := range req.Messages {
		contents = append(contents, &genai.Content{
			Parts: []genai.Part{genai.Text(text)},
		})
	}

	return contents, nil
}

// ToGenaiModel returns the Gemini embedding model to call. Every request is
// served by text-embedding-004 regardless of the model name the caller sent
// (the requested name is only validated, in ToGenaiMessages).
func (req *EmbeddingRequest) ToGenaiModel() string {
	return TextEmbedding004
}

0 comments on commit acc7125

Please sign in to comment.