-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(py): Core VertexAI plugin (#2014)
- Loading branch information
Showing
6 changed files
with
159 additions
and
83 deletions.
There are no files selected for viewing
84 changes: 2 additions & 82 deletions
84
py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,10 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
""" | ||
Google Cloud Vertex AI Plugin for Genkit. | ||
""" | ||
|
||
from typing import Any | ||
|
||
import vertexai | ||
from genkit.core.plugin_abc import Plugin | ||
from genkit.core.schema_types import ( | ||
GenerateRequest, | ||
GenerateResponse, | ||
Message, | ||
Role, | ||
TextPart, | ||
) | ||
from genkit.veneer.veneer import Genkit | ||
from vertexai.generative_models import Content, GenerativeModel, Part | ||
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name | ||
|
||
|
||
def package_name() -> str: | ||
return 'genkit.plugins.vertex_ai' | ||
|
||
|
||
def gemini(name: str) -> str: | ||
return f'vertexai/{name}' | ||
|
||
|
||
class VertexAI(Plugin): | ||
LOCATION = 'us-central1' | ||
VERTEX_AI_MODEL_NAME = gemini('gemini-1.5-flash') | ||
VERTEX_AI_GENERATIVE_MODEL_NAME = 'gemini-1.5-flash-002' | ||
|
||
def __init__(self, project_id: str | None = None): | ||
self.project_id = project_id | ||
vertexai.init(location=self.LOCATION, project=self.project_id) | ||
|
||
def attach_to_veneer(self, veneer: Genkit) -> None: | ||
self._add_model_to_veneer(veneer=veneer) | ||
|
||
def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None: | ||
return super()._add_model_to_veneer( | ||
veneer=veneer, | ||
name=self.VERTEX_AI_MODEL_NAME, | ||
metadata=self.vertex_ai_model_metadata, | ||
) | ||
|
||
@property | ||
def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]: | ||
return { | ||
'model': { | ||
'label': 'banana', | ||
'supports': {'multiturn': True}, | ||
} | ||
} | ||
|
||
def _model_callback(self, request: GenerateRequest) -> GenerateResponse: | ||
return self._handle_gemini_request(request=request) | ||
|
||
def _handle_gemini_request( | ||
self, request: GenerateRequest | ||
) -> GenerateResponse: | ||
gemini_msgs: list[Content] = [] | ||
for m in request.messages: | ||
gemini_parts: list[Part] = [] | ||
for p in m.content: | ||
if p.root.text is not None: | ||
gemini_parts.append(Part.from_text(p.root.text)) | ||
else: | ||
raise Exception('unsupported part type') | ||
gemini_msgs.append(Content(role=m.role.value, parts=gemini_parts)) | ||
response = self.vertex_ai_generative_model.generate_content( | ||
contents=gemini_msgs | ||
) | ||
return GenerateResponse( | ||
message=Message( | ||
role=Role.model, | ||
content=[TextPart(text=response.text)], | ||
) | ||
) | ||
|
||
@property | ||
def vertex_ai_generative_model(self) -> GenerativeModel: | ||
return GenerativeModel(self.VERTEX_AI_GENERATIVE_MODEL_NAME) | ||
|
||
|
||
__all__ = ['package_name', 'VertexAI', 'gemini'] | ||
__all__ = ['package_name', 'VertexAI', 'vertexai_name'] |
5 changes: 5 additions & 0 deletions
5
py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
GCLOUD_PROJECT = 'GCLOUD_PROJECT' | ||
DEFAULT_REGION = 'us-central1' |
89 changes: 89 additions & 0 deletions
89
py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from enum import StrEnum | ||
|
||
from genkit.core.schema_types import ( | ||
GenerateRequest, | ||
GenerateResponse, | ||
Message, | ||
ModelInfo, | ||
Role, | ||
Supports, | ||
TextPart, | ||
) | ||
from vertexai.generative_models import Content, GenerativeModel, Part | ||
|
||
|
||
class GeminiVersion(StrEnum): | ||
GEMINI_1_5_PRO = 'gemini-1.5-pro' | ||
GEMINI_1_5_FLASH = 'gemini-1.5-flash' | ||
GEMINI_2_0_FLASH_001 = 'gemini-2.0-flash-001' | ||
GEMINI_2_0_FLASH_LITE_PREVIEW = 'gemini-2.0-flash-lite-preview-02-05' | ||
GEMINI_2_0_PRO_EXP = 'gemini-2.0-pro-exp-02-05' | ||
|
||
|
||
SUPPORTED_MODELS = { | ||
GeminiVersion.GEMINI_1_5_PRO: ModelInfo( | ||
versions=[], | ||
label='Vertex AI - Gemini 1.5 Pro', | ||
supports=Supports( | ||
multiturn=True, media=True, tools=True, systemRole=True | ||
), | ||
), | ||
GeminiVersion.GEMINI_1_5_FLASH: ModelInfo( | ||
versions=[], | ||
label='Vertex AI - Gemini 1.5 Flash', | ||
supports=Supports( | ||
multiturn=True, media=True, tools=True, systemRole=True | ||
), | ||
), | ||
GeminiVersion.GEMINI_2_0_FLASH_001: ModelInfo( | ||
versions=[], | ||
label='Vertex AI - Gemini 2.0 Flash 001', | ||
supports=Supports( | ||
multiturn=True, media=True, tools=True, systemRole=True | ||
), | ||
), | ||
GeminiVersion.GEMINI_2_0_FLASH_LITE_PREVIEW: ModelInfo( | ||
versions=[], | ||
label='Vertex AI - Gemini 2.0 Flash Lite Preview 02-05', | ||
supports=Supports( | ||
multiturn=True, media=True, tools=True, systemRole=True | ||
), | ||
), | ||
GeminiVersion.GEMINI_2_0_PRO_EXP: ModelInfo( | ||
versions=[], | ||
label='Vertex AI - Gemini 2.0 Flash Pro Experimental 02-05', | ||
supports=Supports( | ||
multiturn=True, media=True, tools=True, systemRole=True | ||
), | ||
), | ||
} | ||
|
||
|
||
class Gemini: | ||
def __init__(self, version): | ||
self.version = version | ||
|
||
@property | ||
def gemini_model(self) -> GenerativeModel: | ||
return GenerativeModel(self.version) | ||
|
||
def handle_request(self, request: GenerateRequest) -> GenerateResponse: | ||
messages: list[Content] = [] | ||
for m in request.messages: | ||
parts: list[Part] = [] | ||
for p in m.content: | ||
if p.root.text is not None: | ||
parts.append(Part.from_text(p.root.text)) | ||
else: | ||
raise Exception('unsupported part type') | ||
messages.append(Content(role=m.role.value, parts=parts)) | ||
response = self.gemini_model.generate_content(contents=messages) | ||
return GenerateResponse( | ||
message=Message( | ||
role=Role.model, | ||
content=[TextPart(text=response.text)], | ||
) | ||
) |
59 changes: 59 additions & 0 deletions
59
py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Google Cloud Vertex AI Plugin for Genkit.""" | ||
|
||
import logging | ||
import os | ||
from typing import Any | ||
|
||
import vertexai | ||
from genkit.core.plugin_abc import Plugin | ||
from genkit.core.schema_types import GenerateRequest, GenerateResponse | ||
from genkit.plugins.vertex_ai import constants as const | ||
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion | ||
from genkit.veneer.veneer import Genkit | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
def vertexai_name(name: str) -> str: | ||
return f'vertexai/{name}' | ||
|
||
|
||
class VertexAI(Plugin): | ||
# This is 'gemini-1.5-pro' - the latest stable model | ||
VERTEX_AI_GENERATIVE_MODEL_NAME: str = GeminiVersion.GEMINI_1_5_FLASH.value | ||
|
||
def __init__( | ||
self, project_id: str | None = None, location: str | None = None | ||
): | ||
# If not set, projectId will be read by plugin | ||
project_id = ( | ||
project_id if project_id else os.getenv(const.GCLOUD_PROJECT) | ||
) | ||
location = location if location else const.DEFAULT_REGION | ||
|
||
self._gemini = Gemini(self.VERTEX_AI_GENERATIVE_MODEL_NAME) | ||
vertexai.init(project=project_id, location=location) | ||
|
||
def attach_to_veneer(self, veneer: Genkit) -> None: | ||
self._add_model_to_veneer(veneer=veneer) | ||
|
||
def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None: | ||
return super()._add_model_to_veneer( | ||
veneer=veneer, | ||
name=vertexai_name(self.VERTEX_AI_GENERATIVE_MODEL_NAME), | ||
metadata=self.vertex_ai_model_metadata, | ||
) | ||
|
||
@property | ||
def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]: | ||
return { | ||
'model': { | ||
'supports': {'multiturn': True}, | ||
} | ||
} | ||
|
||
def _model_callback(self, request: GenerateRequest) -> GenerateResponse: | ||
return self._gemini.handle_request(request=request) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters