Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(go): add support for multiple model versions #1575

Merged
merged 11 commits into from
Feb 5, 2025
74 changes: 59 additions & 15 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand Down Expand Up @@ -35,18 +34,11 @@ type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChun
// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error

// ModelCapabilities describes various capabilities of the model.
type ModelCapabilities struct {
Multiturn bool // the model can handle multiple request-response interactions
Media bool // the model supports media as well as text input
Tools bool // the model supports tools
SystemRole bool // the model supports a system prompt or role
}

// ModelMetadata is the metadata of the model, specifying things like nice user-visible label, capabilities, etc.
type ModelMetadata struct {
Label string
Supports ModelCapabilities
Versions []string
Info ModelInfo
}

// DefineModel registers the given generate function as an action, and returns a
Expand All @@ -62,18 +54,24 @@ func DefineModel(
// Always make sure there's at least minimal metadata.
metadata = &ModelMetadata{
Label: name,
Info: ModelInfo{
Label: name,
Supports: &ModelInfoSupports{},
},
Versions: []string{},
}
}
if metadata.Label != "" {
metadataMap["label"] = metadata.Label
}
supports := map[string]bool{
"media": metadata.Supports.Media,
"multiturn": metadata.Supports.Multiturn,
"systemRole": metadata.Supports.SystemRole,
"tools": metadata.Supports.Tools,
"media": metadata.Info.Supports.Media,
"multiturn": metadata.Info.Supports.Multiturn,
"systemRole": metadata.Info.Supports.SystemRole,
"tools": metadata.Info.Supports.Tools,
}
metadataMap["supports"] = supports
metadataMap["versions"] = metadata.Versions

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
"model": metadataMap,
Expand All @@ -100,8 +98,8 @@ type generateParams struct {
Request *ModelRequest
Model Model
Stream ModelStreamingCallback
History []*Message
SystemPrompt *Message
History []*Message
}

// GenerateOption configures params of the Generate call.
Expand Down Expand Up @@ -242,6 +240,19 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
if req.Model == nil {
return nil, errors.New("model is required")
}

var modelVersion string
if config, ok := req.Request.Config.(*GenerationCommonConfig); ok {
modelVersion = config.Version
}

if modelVersion != "" {
ok, err := validateModelVersion(r, modelVersion, req)
if !ok {
return nil, err
}
}

if req.History != nil {
prev := req.Request.Messages
req.Request.Messages = req.History
Expand All @@ -256,6 +267,39 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
return req.Model.Generate(ctx, r, req.Request, req.Stream)
}

// validateModelVersion checks in the registry the action of the
// given model version and determines whether its supported or not.
func validateModelVersion(r *registry.Registry, v string, req *generateParams) (bool, error) {
parts := strings.Split(req.Model.Name(), "/")
if len(parts) != 2 {
return false, errors.New("wrong model name")
}

a := core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, atype.Model, parts[0], parts[1])
if a == nil {
return false, errors.New("model action not defined")
}

if !modelVersionSupported(v, a.Desc().Metadata) {
return false, fmt.Errorf("version not supported: %s", v)
}

return true, nil
}

// modelVersionSupported iterates over model's metadata to find the requested
// supported model version
func modelVersionSupported(modelVersion string, modelMetadata map[string]any) bool {
if md, ok := modelMetadata["model"].(map[string]any); ok {
for _, v := range md["versions"].([]string) {
if modelVersion == v {
return true
}
}
}
return false
}

// GenerateText run generate request for this model. Returns generated text only.
func GenerateText(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (string, error) {
res, err := Generate(ctx, r, opts...)
Expand Down
20 changes: 12 additions & 8 deletions go/internal/doc-snippets/modelplugin/modelplugin.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package modelplugin

import (
Expand Down Expand Up @@ -30,15 +29,20 @@ func Init() error {
}

// [START definemodel]
name := "my-model"
genkit.DefineModel(g,
providerID, "my-model",
providerID, name,
&ai.ModelMetadata{
Label: "my-model",
Supports: ai.ModelCapabilities{
Multiturn: true, // Does the model support multi-turn chats?
SystemRole: true, // Does the model support syatem messages?
Media: false, // Can the model accept media input?
Tools: false, // Does the model support function calling (tools)?
Label: name,
Info: ai.ModelInfo{
Label: name,
Supports: &ai.ModelInfoSupports{
Multiturn: true, // Does the model support multi-turn chats?
SystemRole: true, // Does the model support syatem messages?
Media: false, // Can the model accept media input?
Tools: false, // Does the model support function calling (tools)?
},
Versions: []string{},
},
},
func(ctx context.Context,
Expand Down
16 changes: 10 additions & 6 deletions go/internal/doc-snippets/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,21 @@ func ollamaEx(ctx context.Context) error {
// [END init]

// [START definemodel]
name := "gemma2"
model := ollama.DefineModel(
g,
ollama.ModelDefinition{
Name: "gemma2",
Name: name,
Type: "chat", // "chat" or "generate"
},
&ai.ModelCapabilities{
Multiturn: true,
SystemRole: true,
Tools: false,
Media: false,
&ai.ModelInfo{
Label: name,
Supports: &ai.ModelInfoSupports{
Multiturn: true,
SystemRole: true,
Tools: false,
Media: false,
},
},
)
// [END definemodel]
Expand Down
57 changes: 37 additions & 20 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


// Parts of this file are copied into vertexai, because the code is identical
// except for the import path of the Gemini SDK.
//go:generate go run ../../internal/cmd/copy -dest ../vertexai googleai.go
Expand Down Expand Up @@ -31,17 +30,33 @@ const (
)

var state struct {
mu sync.Mutex
initted bool
// These happen to be the same.
gclient, pclient *genai.Client
mu sync.Mutex
initted bool
}

var (
knownCaps = map[string]ai.ModelCapabilities{
"gemini-1.0-pro": gemini.BasicText,
"gemini-1.5-pro": gemini.Multimodal,
"gemini-1.5-flash": gemini.Multimodal,
supportedModels = map[string]ai.ModelInfo{
"gemini-1.0-pro": {
Versions: []string{"gemini-pro", "gemini-1.0-pro-latest", "gemini-1.0-pro-001"},
Supports: &gemini.BasicText,
},

"gemini-1.5-flash": {
Versions: []string{"gemini-1.5-flash-latest", "gemini-1.5-flash-001", "gemini-1.5-flash-002"},
Supports: &gemini.Multimodal,
},

"gemini-1.5-pro": {
Versions: []string{"gemini-1.5-pro-latest", "gemini-1.5-pro-001", "gemini-1.5-pro-002"},
Supports: &gemini.Multimodal,
},

"gemini-1.5-flash-8b": {
Versions: []string{"gemini-1.5-flash-8b-latest", "gemini-1.5-flash-8b-001"},
Supports: &gemini.Multimodal,
},
}

knownEmbedders = []string{"text-embedding-004", "embedding-001"}
Expand Down Expand Up @@ -88,7 +103,8 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) {

opts := append([]option.ClientOption{
option.WithAPIKey(apiKey),
genai.WithClientInfo("genkit-go", internal.Version)},
genai.WithClientInfo("genkit-go", internal.Version),
},
cfg.ClientOptions...,
)
client, err := genai.NewClient(ctx, opts...)
Expand All @@ -98,8 +114,8 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) {
state.gclient = client
state.pclient = client
state.initted = true
for model, caps := range knownCaps {
defineModel(g, model, caps)
for model, details := range supportedModels {
defineModel(g, model, details)
}
for _, e := range knownEmbedders {
defineEmbedder(g, e)
Expand All @@ -113,30 +129,32 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) {
// The second argument describes the capability of the model.
// Use [IsDefinedModel] to determine if a model is already defined.
// After [Init] is called, only the known models are defined.
func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.Model, error) {
func DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic(provider + ".Init not called")
}
var mc ai.ModelCapabilities
if caps == nil {
var mi ai.ModelInfo
if info == nil {
var ok bool
mc, ok = knownCaps[name]
mi, ok = supportedModels[name]
if !ok {
return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelCapabilities", provider, name)
return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelInfo", provider, name)
}
} else {
mc = *caps
// TODO: unknown models could also specify versions?
mi = *info
}
return defineModel(g, name, mc), nil
return defineModel(g, name, mi), nil
}

// requires state.mu
func defineModel(g *genkit.Genkit, name string, caps ai.ModelCapabilities) ai.Model {
func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model {
meta := &ai.ModelMetadata{
Label: labelPrefix + " - " + name,
Supports: caps,
Info: info,
Versions: info.Versions,
}
return genkit.DefineModel(g, provider, name, meta, func(
ctx context.Context,
Expand Down Expand Up @@ -317,7 +335,6 @@ func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*gena
systemParts, err := convertParts(m.Content)
if err != nil {
return nil, err

}
// system prompts go into GenerativeModel.SystemInstruction field.
if m.Role == ai.RoleSystem {
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/internal/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import "github.com/firebase/genkit/go/ai"

var (
// BasicText describes model capabilities for text-only Gemini models.
BasicText = ai.ModelCapabilities{
BasicText = ai.ModelInfoSupports{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: false,
}

// Multimodal describes model capabilities for multimodal Gemini models.
Multimodal = ai.ModelCapabilities{
Multimodal = ai.ModelInfoSupports{
Multiturn: true,
Tools: true,
SystemRole: true,
Expand Down
Loading
Loading