Skip to content

Commit

Permalink
feat(py): added support for action execution context
Browse files Browse the repository at this point in the history
  • Loading branch information
kirgrim committed Feb 18, 2025
1 parent b6e1c2a commit ecb256a
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 11 deletions.
17 changes: 15 additions & 2 deletions py/packages/genkit/src/genkit/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class ActionMetadataKey(str, Enum):
RETURN = 'return'


class ActionExecutionContext(BaseModel):
"""The context to the action callback."""

model_config = ConfigDict(extra='forbid', populate_by_name=True)


class Action:
"""An action is a Typed JSON-based RPC-over-HTTP remote-callable function
that supports metadata, streaming, reflection and discovery.
Expand All @@ -66,6 +72,7 @@ def __init__(
description: str | None = None,
metadata: dict[ActionMetadataKey, Any] | None = None,
span_metadata: dict[str, str] | None = None,
fn_context: ActionExecutionContext | None = None,
):
"""Initialize an action.
Expand All @@ -80,6 +87,7 @@ def __init__(
# TODO(Tatsiana Havina): separate a long constructor into methods.
self.kind: ActionKind = kind
self.name = name
self.fn_context = fn_context

def tracing_wrapper(*args, **kwargs):
"""Wraps the callable function in a tracing span and adds metadata
Expand All @@ -101,6 +109,13 @@ def tracing_wrapper(*args, **kwargs):
else:
span.set_attribute('genkit:input', json.dumps(args[0]))

if self.fn_context is not None:
if not isinstance(fn_context, ActionExecutionContext):
raise TypeError(
"Action Execution context must be of type 'ActionExecutionContext'"
)
kwargs['context'] = self.fn_context

output = fn(*args, **kwargs)

span.set_attribute('genkit:state', 'success')
Expand All @@ -122,8 +137,6 @@ def tracing_wrapper(*args, **kwargs):
k for k in input_spec.annotations if k != ActionMetadataKey.RETURN
]

if len(action_args) > 1:
raise Exception('can only have one arg')
if len(action_args) > 0:
type_adapter = TypeAdapter(input_spec.annotations[action_args[0]])
self.input_schema = type_adapter.json_schema()
Expand Down
17 changes: 14 additions & 3 deletions py/packages/genkit/src/genkit/core/plugin_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import abc
import typing

from genkit.core.action import ActionExecutionContext
from genkit.core.schema_types import GenerateRequest, GenerateResponse

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -38,7 +39,11 @@ def attach_to_veneer(self, veneer: Genkit) -> None:
pass

def _add_model_to_veneer(
self, veneer: Genkit, name: str, metadata: dict | None = None
self,
veneer: Genkit,
name: str,
metadata: dict | None = None,
fn_context: ActionExecutionContext | None = None,
) -> None:
"""
Defines plugin's model in the Genkit Registry
Expand All @@ -57,11 +62,16 @@ def _add_model_to_veneer(
if not metadata:
metadata = {}
veneer.define_model(
name=name, fn=self._model_callback, metadata=metadata
name=name,
fn=self._model_callback,
metadata=metadata,
fn_context=fn_context,
)

@abc.abstractmethod
def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
def _model_callback(
self, request: GenerateRequest, context: ActionExecutionContext | None
) -> GenerateResponse:
"""
Wrapper around any plugin's model callback.
Expand All @@ -71,6 +81,7 @@ def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
Args:
request: incoming request as generic
`genkit.core.schemas.GenerateRequest` instance
context: Action's execution context data if any
Returns:
Model response represented as generic
Expand Down
9 changes: 7 additions & 2 deletions py/packages/genkit/src/genkit/veneer/veneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from genkit.ai.model import ModelFn
from genkit.ai.prompt import PromptFn
from genkit.core.action import Action, ActionKind
from genkit.core.action import Action, ActionExecutionContext, ActionKind
from genkit.core.plugin_abc import Plugin
from genkit.core.reflection import make_reflection_server
from genkit.core.registry import Registry
Expand Down Expand Up @@ -110,9 +110,14 @@ def define_model(
name: str,
fn: ModelFn,
metadata: dict[str, Any] | None = None,
fn_context: ActionExecutionContext | None = None,
) -> None:
action = Action(
name=name, kind=ActionKind.MODEL, fn=fn, metadata=metadata
name=name,
kind=ActionKind.MODEL,
fn=fn,
metadata=metadata,
fn_context=fn_context,
)
self.registry.register_action(action)

Expand Down
17 changes: 14 additions & 3 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Role,
TextPart,
)
from genkit.plugins.vertex_ai.models import VertexAIActionExecutionContext
from genkit.veneer.veneer import Genkit
from vertexai.generative_models import Content, GenerativeModel, Part

Expand Down Expand Up @@ -46,6 +47,7 @@ def _add_model_to_veneer(self, veneer: Genkit, **kwargs) -> None:
veneer=veneer,
name=self.VERTEX_AI_MODEL_NAME,
metadata=self.vertex_ai_model_metadata,
fn_context=VertexAIActionExecutionContext(),
)

@property
Expand All @@ -57,11 +59,20 @@ def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]:
}
}

def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
return self._handle_gemini_request(request=request)
def _model_callback(
self,
request: GenerateRequest,
context: VertexAIActionExecutionContext,
) -> GenerateResponse:
return self._handle_gemini_request(
request=request,
context=context,
)

def _handle_gemini_request(
self, request: GenerateRequest
self,
request: GenerateRequest,
context: VertexAIActionExecutionContext,
) -> GenerateResponse:
gemini_msgs: list[Content] = []
for m in request.messages:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,18 @@
Google Cloud Vertex AI Models for Genkit.
"""

from genkit.core.action import ActionExecutionContext


def package_name() -> str:
return 'genkit.plugins.vertex_ai.models'


__all__ = ['package_name']
class VertexAIActionExecutionContext(ActionExecutionContext):
pass


__all__ = [
'package_name',
'VertexAIActionExecutionContext',
]

0 comments on commit ecb256a

Please sign in to comment.