Skip to content

Commit

Permalink
feat(py): Core VertexAI plugin (#2014)
Browse files Browse the repository at this point in the history
  • Loading branch information
Irillit authored Feb 20, 2025
1 parent 87a524b commit bffc50d
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 83 deletions.
84 changes: 2 additions & 82 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,10 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0


"""
Google Cloud Vertex AI Plugin for Genkit.
"""

from typing import Any

import vertexai
from genkit.core.plugin_abc import Plugin
from genkit.core.schema_types import (
GenerateRequest,
GenerateResponse,
Message,
Role,
TextPart,
)
from genkit.veneer.veneer import Genkit
from vertexai.generative_models import Content, GenerativeModel, Part
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name


def package_name() -> str:
return 'genkit.plugins.vertex_ai'


def gemini(name: str) -> str:
return f'vertexai/{name}'


class VertexAI(Plugin):
LOCATION = 'us-central1'
VERTEX_AI_MODEL_NAME = gemini('gemini-1.5-flash')
VERTEX_AI_GENERATIVE_MODEL_NAME = 'gemini-1.5-flash-002'

def __init__(self, project_id: str | None = None):
self.project_id = project_id
vertexai.init(location=self.LOCATION, project=self.project_id)

def attach_to_veneer(self, veneer: Genkit) -> None:
self._add_model_to_veneer(veneer=veneer)

def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None:
return super()._add_model_to_veneer(
veneer=veneer,
name=self.VERTEX_AI_MODEL_NAME,
metadata=self.vertex_ai_model_metadata,
)

@property
def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]:
return {
'model': {
'label': 'banana',
'supports': {'multiturn': True},
}
}

def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
return self._handle_gemini_request(request=request)

def _handle_gemini_request(
self, request: GenerateRequest
) -> GenerateResponse:
gemini_msgs: list[Content] = []
for m in request.messages:
gemini_parts: list[Part] = []
for p in m.content:
if p.root.text is not None:
gemini_parts.append(Part.from_text(p.root.text))
else:
raise Exception('unsupported part type')
gemini_msgs.append(Content(role=m.role.value, parts=gemini_parts))
response = self.vertex_ai_generative_model.generate_content(
contents=gemini_msgs
)
return GenerateResponse(
message=Message(
role=Role.model,
content=[TextPart(text=response.text)],
)
)

@property
def vertex_ai_generative_model(self) -> GenerativeModel:
return GenerativeModel(self.VERTEX_AI_GENERATIVE_MODEL_NAME)


__all__ = ['package_name', 'VertexAI', 'gemini']
__all__ = ['package_name', 'VertexAI', 'vertexai_name']
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

GCLOUD_PROJECT = 'GCLOUD_PROJECT'
DEFAULT_REGION = 'us-central1'
89 changes: 89 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum

from genkit.core.schema_types import (
GenerateRequest,
GenerateResponse,
Message,
ModelInfo,
Role,
Supports,
TextPart,
)
from vertexai.generative_models import Content, GenerativeModel, Part


class GeminiVersion(StrEnum):
GEMINI_1_5_PRO = 'gemini-1.5-pro'
GEMINI_1_5_FLASH = 'gemini-1.5-flash'
GEMINI_2_0_FLASH_001 = 'gemini-2.0-flash-001'
GEMINI_2_0_FLASH_LITE_PREVIEW = 'gemini-2.0-flash-lite-preview-02-05'
GEMINI_2_0_PRO_EXP = 'gemini-2.0-pro-exp-02-05'


SUPPORTED_MODELS = {
GeminiVersion.GEMINI_1_5_PRO: ModelInfo(
versions=[],
label='Vertex AI - Gemini 1.5 Pro',
supports=Supports(
multiturn=True, media=True, tools=True, systemRole=True
),
),
GeminiVersion.GEMINI_1_5_FLASH: ModelInfo(
versions=[],
label='Vertex AI - Gemini 1.5 Flash',
supports=Supports(
multiturn=True, media=True, tools=True, systemRole=True
),
),
GeminiVersion.GEMINI_2_0_FLASH_001: ModelInfo(
versions=[],
label='Vertex AI - Gemini 2.0 Flash 001',
supports=Supports(
multiturn=True, media=True, tools=True, systemRole=True
),
),
GeminiVersion.GEMINI_2_0_FLASH_LITE_PREVIEW: ModelInfo(
versions=[],
label='Vertex AI - Gemini 2.0 Flash Lite Preview 02-05',
supports=Supports(
multiturn=True, media=True, tools=True, systemRole=True
),
),
GeminiVersion.GEMINI_2_0_PRO_EXP: ModelInfo(
versions=[],
label='Vertex AI - Gemini 2.0 Flash Pro Experimental 02-05',
supports=Supports(
multiturn=True, media=True, tools=True, systemRole=True
),
),
}


class Gemini:
def __init__(self, version):
self.version = version

@property
def gemini_model(self) -> GenerativeModel:
return GenerativeModel(self.version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
messages: list[Content] = []
for m in request.messages:
parts: list[Part] = []
for p in m.content:
if p.root.text is not None:
parts.append(Part.from_text(p.root.text))
else:
raise Exception('unsupported part type')
messages.append(Content(role=m.role.value, parts=parts))
response = self.gemini_model.generate_content(contents=messages)
return GenerateResponse(
message=Message(
role=Role.model,
content=[TextPart(text=response.text)],
)
)
59 changes: 59 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Google Cloud Vertex AI Plugin for Genkit."""

import logging
import os
from typing import Any

import vertexai
from genkit.core.plugin_abc import Plugin
from genkit.core.schema_types import GenerateRequest, GenerateResponse
from genkit.plugins.vertex_ai import constants as const
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion
from genkit.veneer.veneer import Genkit

LOG = logging.getLogger(__name__)


def vertexai_name(name: str) -> str:
return f'vertexai/{name}'


class VertexAI(Plugin):
# This is 'gemini-1.5-pro' - the latest stable model
VERTEX_AI_GENERATIVE_MODEL_NAME: str = GeminiVersion.GEMINI_1_5_FLASH.value

def __init__(
self, project_id: str | None = None, location: str | None = None
):
# If not set, projectId will be read by plugin
project_id = (
project_id if project_id else os.getenv(const.GCLOUD_PROJECT)
)
location = location if location else const.DEFAULT_REGION

self._gemini = Gemini(self.VERTEX_AI_GENERATIVE_MODEL_NAME)
vertexai.init(project=project_id, location=location)

def attach_to_veneer(self, veneer: Genkit) -> None:
self._add_model_to_veneer(veneer=veneer)

def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None:
return super()._add_model_to_veneer(
veneer=veneer,
name=vertexai_name(self.VERTEX_AI_GENERATIVE_MODEL_NAME),
metadata=self.vertex_ai_model_metadata,
)

@property
def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]:
return {
'model': {
'supports': {'multiturn': True},
}
}

def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
return self._gemini.handle_request(request=request)
1 change: 1 addition & 0 deletions py/samples/hello/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Hello world

## Setup environment
Use `gcloud auth application-default login` to connect to the VertexAI.

```bash
uv venv
Expand Down
4 changes: 3 additions & 1 deletion py/samples/hello/src/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from genkit.veneer.veneer import Genkit
from pydantic import BaseModel, Field

ai = Genkit(plugins=[VertexAI()], model=VertexAI.VERTEX_AI_MODEL_NAME)
ai = Genkit(
plugins=[VertexAI()], model=VertexAI.VERTEX_AI_GENERATIVE_MODEL_NAME
)


class MyInput(BaseModel):
Expand Down

0 comments on commit bffc50d

Please sign in to comment.