feat(py): Core VertexAI plugin #2014

Merged · 1 commit · Feb 20, 2025
84 changes: 2 additions & 82 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
@@ -1,90 +1,10 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0


"""
Google Cloud Vertex AI Plugin for Genkit.
"""

from typing import Any

import vertexai
from genkit.core.plugin_abc import Plugin
from genkit.core.schema_types import (
    GenerateRequest,
    GenerateResponse,
    Message,
    Role,
    TextPart,
)
from genkit.veneer.veneer import Genkit
from vertexai.generative_models import Content, GenerativeModel, Part
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name


def package_name() -> str:
    return 'genkit.plugins.vertex_ai'


def gemini(name: str) -> str:
    return f'vertexai/{name}'


class VertexAI(Plugin):
    LOCATION = 'us-central1'
    VERTEX_AI_MODEL_NAME = gemini('gemini-1.5-flash')
    VERTEX_AI_GENERATIVE_MODEL_NAME = 'gemini-1.5-flash-002'

    def __init__(self, project_id: str | None = None):
        self.project_id = project_id
        vertexai.init(location=self.LOCATION, project=self.project_id)

    def attach_to_veneer(self, veneer: Genkit) -> None:
        self._add_model_to_veneer(veneer=veneer)

    def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None:
        return super()._add_model_to_veneer(
            veneer=veneer,
            name=self.VERTEX_AI_MODEL_NAME,
            metadata=self.vertex_ai_model_metadata,
        )

    @property
    def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]:
        return {
            'model': {
                'label': 'banana',
                'supports': {'multiturn': True},
            }
        }

    def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
        return self._handle_gemini_request(request=request)

    def _handle_gemini_request(
        self, request: GenerateRequest
    ) -> GenerateResponse:
        gemini_msgs: list[Content] = []
        for m in request.messages:
            gemini_parts: list[Part] = []
            for p in m.content:
                if p.root.text is not None:
                    gemini_parts.append(Part.from_text(p.root.text))
                else:
                    raise Exception('unsupported part type')
            gemini_msgs.append(Content(role=m.role.value, parts=gemini_parts))
        response = self.vertex_ai_generative_model.generate_content(
            contents=gemini_msgs
        )
        return GenerateResponse(
            message=Message(
                role=Role.model,
                content=[TextPart(text=response.text)],
            )
        )

    @property
    def vertex_ai_generative_model(self) -> GenerativeModel:
        return GenerativeModel(self.VERTEX_AI_GENERATIVE_MODEL_NAME)


__all__ = ['package_name', 'VertexAI', 'gemini']
__all__ = ['package_name', 'VertexAI', 'vertexai_name']
5 changes: 5 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py
@@ -0,0 +1,5 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

GCLOUD_PROJECT = 'GCLOUD_PROJECT'
DEFAULT_REGION = 'us-central1'
89 changes: 89 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
@@ -0,0 +1,89 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum

from genkit.core.schema_types import (
    GenerateRequest,
    GenerateResponse,
    Message,
    ModelInfo,
    Role,
    Supports,
    TextPart,
)
from vertexai.generative_models import Content, GenerativeModel, Part


class GeminiVersion(StrEnum):
    GEMINI_1_5_PRO = 'gemini-1.5-pro'
    GEMINI_1_5_FLASH = 'gemini-1.5-flash'
    GEMINI_2_0_FLASH_001 = 'gemini-2.0-flash-001'
    GEMINI_2_0_FLASH_LITE_PREVIEW = 'gemini-2.0-flash-lite-preview-02-05'
    GEMINI_2_0_PRO_EXP = 'gemini-2.0-pro-exp-02-05'


SUPPORTED_MODELS = {
    GeminiVersion.GEMINI_1_5_PRO: ModelInfo(
        versions=[],
        label='Vertex AI - Gemini 1.5 Pro',
        supports=Supports(
            multiturn=True, media=True, tools=True, systemRole=True
        ),
    ),
    GeminiVersion.GEMINI_1_5_FLASH: ModelInfo(
        versions=[],
        label='Vertex AI - Gemini 1.5 Flash',
        supports=Supports(
            multiturn=True, media=True, tools=True, systemRole=True
        ),
    ),
    GeminiVersion.GEMINI_2_0_FLASH_001: ModelInfo(
        versions=[],
        label='Vertex AI - Gemini 2.0 Flash 001',
        supports=Supports(
            multiturn=True, media=True, tools=True, systemRole=True
        ),
    ),
    GeminiVersion.GEMINI_2_0_FLASH_LITE_PREVIEW: ModelInfo(
        versions=[],
        label='Vertex AI - Gemini 2.0 Flash Lite Preview 02-05',
        supports=Supports(
            multiturn=True, media=True, tools=True, systemRole=True
        ),
    ),
    GeminiVersion.GEMINI_2_0_PRO_EXP: ModelInfo(
        versions=[],
        label='Vertex AI - Gemini 2.0 Pro Experimental 02-05',
        supports=Supports(
            multiturn=True, media=True, tools=True, systemRole=True
        ),
    ),
}


class Gemini:
    def __init__(self, version):
        self.version = version

    @property
    def gemini_model(self) -> GenerativeModel:
        return GenerativeModel(self.version)

    def handle_request(self, request: GenerateRequest) -> GenerateResponse:
        messages: list[Content] = []
        for m in request.messages:
            parts: list[Part] = []
            for p in m.content:
                if p.root.text is not None:
                    parts.append(Part.from_text(p.root.text))
                else:
                    raise Exception('unsupported part type')
            messages.append(Content(role=m.role.value, parts=parts))
        response = self.gemini_model.generate_content(contents=messages)
        return GenerateResponse(
            message=Message(
                role=Role.model,
                content=[TextPart(text=response.text)],
            )
        )
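
For orientation, here is a minimal, hedged sketch of driving `Gemini.handle_request` on its own. The keyword constructors for `GenerateRequest`, `Message`, and `TextPart`, and the `Role.user` member, are assumptions inferred from the imports in this file rather than APIs confirmed by the PR, and `vertexai.init()` must already have been called for a valid project.

```python
# Hedged sketch, not part of the PR: exercises Gemini.handle_request directly.
# GenerateRequest/Message/TextPart keyword fields and Role.user are assumed
# from the genkit.core.schema_types imports above.
from genkit.core.schema_types import GenerateRequest, Message, Role, TextPart
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion

gemini = Gemini(GeminiVersion.GEMINI_1_5_FLASH)
request = GenerateRequest(
    messages=[
        Message(role=Role.user, content=[TextPart(text='Say hello.')]),
    ]
)
response = gemini.handle_request(request)
print(response.message)  # the model reply as a Role.model Message
```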
59 changes: 59 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Google Cloud Vertex AI Plugin for Genkit."""

import logging
import os
from typing import Any

import vertexai
from genkit.core.plugin_abc import Plugin
from genkit.core.schema_types import GenerateRequest, GenerateResponse
from genkit.plugins.vertex_ai import constants as const
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion
from genkit.veneer.veneer import Genkit

LOG = logging.getLogger(__name__)


def vertexai_name(name: str) -> str:
    return f'vertexai/{name}'


class VertexAI(Plugin):
    # Default model: 'gemini-1.5-flash'.
    VERTEX_AI_GENERATIVE_MODEL_NAME: str = GeminiVersion.GEMINI_1_5_FLASH.value

    def __init__(
        self, project_id: str | None = None, location: str | None = None
    ):
        # If project_id is not given, it is read from the GCLOUD_PROJECT
        # environment variable; location falls back to the default region.
        project_id = (
            project_id if project_id else os.getenv(const.GCLOUD_PROJECT)
        )
        location = location if location else const.DEFAULT_REGION

        self._gemini = Gemini(self.VERTEX_AI_GENERATIVE_MODEL_NAME)
        vertexai.init(project=project_id, location=location)

    def attach_to_veneer(self, veneer: Genkit) -> None:
        self._add_model_to_veneer(veneer=veneer)

    def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None:
        return super()._add_model_to_veneer(
            veneer=veneer,
            name=vertexai_name(self.VERTEX_AI_GENERATIVE_MODEL_NAME),
            metadata=self.vertex_ai_model_metadata,
        )

    @property
    def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]:
        return {
            'model': {
                'supports': {'multiturn': True},
            }
        }

    def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
        return self._gemini.handle_request(request=request)
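
A short, hedged sketch of how the plugin resolves its configuration: explicit arguments win, otherwise the `GCLOUD_PROJECT` environment variable and the default region are used, and registered model names carry the `vertexai/` prefix. The project ID below is a placeholder, and application default credentials are assumed to be set up.

```python
# Hedged sketch, not part of the PR; 'my-gcp-project' is a placeholder value.
import os

from genkit.plugins.vertex_ai import VertexAI, vertexai_name

# Option 1: pass the project and location explicitly.
plugin = VertexAI(project_id='my-gcp-project', location='us-central1')

# Option 2: rely on GCLOUD_PROJECT and the default region (us-central1).
os.environ['GCLOUD_PROJECT'] = 'my-gcp-project'
plugin = VertexAI()

# Model actions are registered under the 'vertexai/' prefix.
assert vertexai_name('gemini-1.5-flash') == 'vertexai/gemini-1.5-flash'
```

Attaching the plugin to a `Genkit` instance, as the `hello.py` sample below does, then registers the default Gemini model through `attach_to_veneer`.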
1 change: 1 addition & 0 deletions py/samples/hello/README.md
@@ -1,6 +1,7 @@
# Hello world

## Setup environment
Use `gcloud auth application-default login` to authenticate with Vertex AI.

```bash
uv venv
4 changes: 3 additions & 1 deletion py/samples/hello/src/hello.py
@@ -10,7 +10,9 @@
from genkit.veneer.veneer import Genkit
from pydantic import BaseModel, Field

ai = Genkit(plugins=[VertexAI()], model=VertexAI.VERTEX_AI_MODEL_NAME)
ai = Genkit(
    plugins=[VertexAI()], model=VertexAI.VERTEX_AI_GENERATIVE_MODEL_NAME
)


class MyInput(BaseModel):