From 8c29807b5a35db7b986231ff31c3ea92776a2392 Mon Sep 17 00:00:00 2001 From: Yesudeep Mangalapilly Date: Wed, 12 Feb 2025 00:19:06 -0800 Subject: [PATCH] fix: formatting, linting, and schema code generation #1935 ISSUE: https://github.com/firebase/genkit/issues/1935 CHANGELOG: - [ ] Addresses error handling and parallelized formatting for TOML files. - [ ] Configures the ruff formatter to use the requested format. - [ ] Fixes snake_case for several Python variables. - [ ] Configures the schemas.py code generator to generate snake_case field names to avoid lint reports and wraps the actual name in Field() metadata. - [ ] Cleans up some lint. - [ ] Removes govulncheck from the pre-commit hook while retaining it during pre-push (easy to skip this using `-f --no-verify`). --- py/bin/format_toml_files | 91 ++++++----- py/bin/format_toml_files_filtered_serial | 69 ++++++++ py/bin/format_toml_files_serial | 14 ++ py/bin/generate_schema_types | 20 +-- py/bin/sanitize_schemas.py | 7 +- py/bin/setup | 2 +- py/captainhook.json | 13 +- py/packages/genkit/src/genkit/ai/model.py | 3 +- py/packages/genkit/src/genkit/ai/prompt.py | 7 +- py/packages/genkit/src/genkit/core/action.py | 15 +- .../genkit/src/genkit/core/reflection.py | 7 +- .../genkit/src/genkit/core/registry.py | 3 +- py/packages/genkit/src/genkit/core/schemas.py | 154 ++++++++++-------- py/packages/genkit/src/genkit/core/tracing.py | 17 +- .../genkit/src/genkit/veneer/veneer.py | 34 ++-- .../src/genkit/plugins/vertex_ai/__init__.py | 23 +-- py/pyproject.toml | 56 ++++++- py/samples/hello/hello.py | 6 +- 18 files changed, 343 insertions(+), 198 deletions(-) create mode 100755 py/bin/format_toml_files_filtered_serial create mode 100755 py/bin/format_toml_files_serial diff --git a/py/bin/format_toml_files b/py/bin/format_toml_files index eef13cf5f..5d131e7e6 100755 --- a/py/bin/format_toml_files +++ b/py/bin/format_toml_files @@ -9,52 +9,63 @@ set -euo pipefail GIT_ROOT=$(git rev-parse --show-toplevel) -if command -v 
rust-parallel >/dev/null 2>&1; then +if command -v taplo >/dev/null 2>&1; then + if [ ! -f "${GIT_ROOT}/py/taplo.toml" ]; then + echo "error: config file not found at ${GIT_ROOT}/py/taplo.toml" + exit 1 + fi + + FORMATTER_COMMAND="taplo format --config ${GIT_ROOT}/py/taplo.toml" + if command -v rust-parallel >/dev/null 2>&1; then + FORMATTER_COMMAND="rust-parallel -j4 ${FORMATTER_COMMAND}" + fi + if command -v fd >/dev/null 2>&1; then + echo "Using fd" fd -e toml \ - --exclude 'py/**/*.egg-info/**' \ - --exclude 'py/**/.dist/**' \ - --exclude 'py/**/.next/**' \ - --exclude 'py/**/.output/**' \ - --exclude 'py/**/.pytest_cache/**' \ - --exclude 'py/**/.venv/**' \ - --exclude 'py/**/__pycache__/**' \ - --exclude 'py/**/build/**' \ - --exclude 'py/**/develop-eggs/**' \ - --exclude 'py/**/dist/**' \ - --exclude 'py/**/eggs/**' \ - --exclude 'py/**/node_modules/**' \ - --exclude 'py/**/sdist/**' \ - --exclude 'py/**/site/**' \ - --exclude 'py/**/target/**' \ - --exclude 'py/**/venv/**' \ - --exclude 'py/**/wheels/**' | - rust-parallel -j4 \ - taplo format --config "${GIT_ROOT}/py/taplo.toml" + --exclude '**/*.egg-info/**' \ + --exclude '**/.dist/**' \ + --exclude '**/.next/**' \ + --exclude '**/.output/**' \ + --exclude '**/.pytest_cache/**' \ + --exclude '**/.venv/**' \ + --exclude '**/__pycache__/**' \ + --exclude '**/bazel-*/**' \ + --exclude '**/build/**' \ + --exclude '**/develop-eggs/**' \ + --exclude '**/dist/**' \ + --exclude '**/eggs/**' \ + --exclude '**/node_modules/**' \ + --exclude '**/sdist/**' \ + --exclude '**/site/**' \ + --exclude '**/target/**' \ + --exclude '**/venv/**' \ + --exclude '**/wheels/**' | + ${FORMATTER_COMMAND} else echo "Using find" find "${GIT_ROOT}" -name "*.toml" \ - ! -path 'py/**/*.egg-info/**' \ - ! -path 'py/**/.dist/**' \ - ! -path 'py/**/.next/**' \ - ! -path 'py/**/.output/**' \ - ! -path 'py/**/.pytest_cache/**' \ - ! -path 'py/**/.venv/**' \ - ! -path 'py/**/__pycache__/**' \ - ! -path 'py/**/build/**' \ - ! 
-path 'py/**/develop-eggs/**' \ - ! -path 'py/**/dist/**' \ - ! -path 'py/**/eggs/**' \ - ! -path 'py/**/node_modules/**' \ - ! -path 'py/**/sdist/**' \ - ! -path 'py/**/site/**' \ - ! -path 'py/**/target/**' \ - ! -path 'py/**/venv/**' \ - ! -path 'py/**/wheels/**' \ + ! -path '**/*.egg-info/**' \ + ! -path '**/.dist/**' \ + ! -path '**/.next/**' \ + ! -path '**/.output/**' \ + ! -path '**/.pytest_cache/**' \ + ! -path '**/.venv/**' \ + ! -path '**/__pycache__/**' \ + ! -path '**/bazel-*/**' \ + ! -path '**/build/**' \ + ! -path '**/develop-eggs/**' \ + ! -path '**/dist/**' \ + ! -path '**/eggs/**' \ + ! -path '**/node_modules/**' \ + ! -path '**/sdist/**' \ + ! -path '**/site/**' \ + ! -path '**/target/**' \ + ! -path '**/venv/**' \ + ! -path '**/wheels/**' \ -print0 | - rust-parallel -j4 \ - taplo format --config "${GIT_ROOT}/py/taplo.toml" + ${FORMATTER_COMMAND} fi else - echo "Please install GNU parallel to use this script" + echo "Please install taplo to use this script" fi diff --git a/py/bin/format_toml_files_filtered_serial b/py/bin/format_toml_files_filtered_serial new file mode 100755 index 000000000..e0e7131fb --- /dev/null +++ b/py/bin/format_toml_files_filtered_serial @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +# +# Format all TOML files in the project. +# +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +GIT_ROOT=$(git rev-parse --show-toplevel) + +if command -v taplo >/dev/null 2>&1; then + if [ ! 
-f "${GIT_ROOT}/py/taplo.toml" ]; then + echo "error: config file not found at ${GIT_ROOT}/py/taplo.toml" + exit 1 + fi + + FORMATTER_COMMAND="taplo format --config ${GIT_ROOT}/py/taplo.toml" + #if command -v rust-parallel >/dev/null 2>&1; then + # FORMATTER_COMMAND="rust-parallel -j4 ${FORMATTER_COMMAND}" + #fi + + if command -v fd >/dev/null 2>&1; then + echo "Using fd" + fd -e toml \ + --exclude 'py/**/*.egg-info/**' \ + --exclude 'py/**/.dist/**' \ + --exclude 'py/**/.next/**' \ + --exclude 'py/**/.output/**' \ + --exclude 'py/**/.pytest_cache/**' \ + --exclude 'py/**/.venv/**' \ + --exclude 'py/**/__pycache__/**' \ + --exclude 'py/**/build/**' \ + --exclude 'py/**/develop-eggs/**' \ + --exclude 'py/**/dist/**' \ + --exclude 'py/**/eggs/**' \ + --exclude 'py/**/node_modules/**' \ + --exclude 'py/**/sdist/**' \ + --exclude 'py/**/site/**' \ + --exclude 'py/**/target/**' \ + --exclude 'py/**/venv/**' \ + --exclude 'py/**/wheels/**' | + ${FORMATTER_COMMAND} + else + echo "Using find" + find "${GIT_ROOT}" -name "*.toml" \ + ! -path 'py/**/*.egg-info/**' \ + ! -path 'py/**/.dist/**' \ + ! -path 'py/**/.next/**' \ + ! -path 'py/**/.output/**' \ + ! -path 'py/**/.pytest_cache/**' \ + ! -path 'py/**/.venv/**' \ + ! -path 'py/**/__pycache__/**' \ + ! -path 'py/**/build/**' \ + ! -path 'py/**/develop-eggs/**' \ + ! -path 'py/**/dist/**' \ + ! -path 'py/**/eggs/**' \ + ! -path 'py/**/node_modules/**' \ + ! -path 'py/**/sdist/**' \ + ! -path 'py/**/site/**' \ + ! -path 'py/**/target/**' \ + ! -path 'py/**/venv/**' \ + ! -path 'py/**/wheels/**' \ + -print0 | + ${FORMATTER_COMMAND} + fi +else + echo "Please install taplo to use this script" +fi diff --git a/py/bin/format_toml_files_serial b/py/bin/format_toml_files_serial new file mode 100755 index 000000000..af4343cd4 --- /dev/null +++ b/py/bin/format_toml_files_serial @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# +# Format all TOML files in the project. 
+# +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +GIT_ROOT=$(git rev-parse --show-toplevel) + +taplo format --config "${GIT_ROOT}/py/taplo.toml" + +exit $? diff --git a/py/bin/generate_schema_types b/py/bin/generate_schema_types index 03ff2be78..1e770e8d3 100755 --- a/py/bin/generate_schema_types +++ b/py/bin/generate_schema_types @@ -6,30 +6,30 @@ set -euo pipefail TOP_DIR=$(git rev-parse --show-toplevel) -SCHEMA_FILE="$TOP_DIR/py/packages/genkit/src/genkit/core/schemas.py" +SCHEMA_FILE="${TOP_DIR}/py/packages/genkit/src/genkit/core/schemas.py" # Generate types using configuration from pyproject.toml -uv run --directory "$TOP_DIR/py" datamodel-codegen +uv run --directory "${TOP_DIR}/py" datamodel-codegen # This isn't causing runtime errors at the moment so letting it be. -#sed -i '' '/^class Model(RootModel\[Any\]):$/,/^ root: Any$/d' "$SCHEMA_FILE" +#sed -i '' '/^class Model(RootModel\[Any\]):$/,/^ root: Any$/d' "${SCHEMA_FILE}" # Sanitize the generated schema. -python3 "${TOP_DIR}/py/bin/sanitize_schemas.py" "$SCHEMA_FILE" +python3 "${TOP_DIR}/py/bin/sanitize_schemas.py" "${SCHEMA_FILE}" # Add a generated by `generate_schema_types` comment. sed -i '' '1i\ # DO NOT EDIT: Generated by `generate_schema_types` from `genkit-schemas.json`. -' "$SCHEMA_FILE" +' "${SCHEMA_FILE}" # Add license header. addlicense \ -c "Google LLC" \ -s=only \ - "$SCHEMA_FILE" + "${SCHEMA_FILE}" # Checks and formatting. 
-uv run --directory "$TOP_DIR/py" \ - ruff check --fix "$SCHEMA_FILE" -uv run --directory "$TOP_DIR/py" \ - ruff format "$SCHEMA_FILE" +uv run --directory "${TOP_DIR}/py" \ + ruff format "${TOP_DIR}" +uv run --directory "${TOP_DIR}/py" \ + ruff check --fix "${SCHEMA_FILE}" diff --git a/py/bin/sanitize_schemas.py b/py/bin/sanitize_schemas.py index 198a96948..94c7aaa00 100644 --- a/py/bin/sanitize_schemas.py +++ b/py/bin/sanitize_schemas.py @@ -24,8 +24,9 @@ def is_rootmodel_class(self, node: ast.ClassDef) -> bool: return True return False - def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: - """Visit class definitions and remove model_config if class inherits from RootModel.""" + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 + """Visit class definitions and remove model_config if class + inherits from RootModel.""" if self.is_rootmodel_class(node): # Filter out model_config assignments new_body = [] @@ -48,7 +49,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: def process_file(filename: str) -> None: """Process a Python file to remove model_config from RootModel classes.""" - with open(filename, 'r') as f: + with open(filename) as f: source = f.read() tree = ast.parse(source) diff --git a/py/bin/setup b/py/bin/setup index 77180dc10..e3bcd8942 100755 --- a/py/bin/setup +++ b/py/bin/setup @@ -171,7 +171,7 @@ function genkit::install_google_cloud_sdk() { # This depends on Python 3.11 and installs it for the user on some systems. if command -v gcloud &>/dev/null; then gcloud config set disable_usage_reporting true - gcloud components update + yes | gcloud components update return 0 fi diff --git a/py/captainhook.json b/py/captainhook.json index 56bbad814..236c8ba84 100644 --- a/py/captainhook.json +++ b/py/captainhook.json @@ -53,17 +53,6 @@ }, { "run": "go test go/..." 
- }, - { - "run": "govulncheck -C go ./...", - "conditions": [ - { - "run": "CaptainHook::FileChanged.Any", - "options": { - "files": ["go/go.mod", "go.sum", "*.go"] - } - } - ] } ] }, @@ -111,7 +100,7 @@ { "run": "CaptainHook::FileChanged.Any", "options": { - "files": ["go/go.mod", "go.sum", "*.go"] + "files": ["go/**/*.go", "go/**/go.mod", "go/**/go.sum"] } } ] diff --git a/py/packages/genkit/src/genkit/ai/model.py b/py/packages/genkit/src/genkit/ai/model.py index b052cfbfc..c30d1825d 100644 --- a/py/packages/genkit/src/genkit/ai/model.py +++ b/py/packages/genkit/src/genkit/ai/model.py @@ -1,7 +1,8 @@ # Copyright 2025 Google LLC # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from collections.abc import Callable + from genkit.core.schemas import GenerateRequest, GenerateResponse ModelFn = Callable[[GenerateRequest], GenerateResponse] diff --git a/py/packages/genkit/src/genkit/ai/prompt.py b/py/packages/genkit/src/genkit/ai/prompt.py index 438492c88..135dca765 100644 --- a/py/packages/genkit/src/genkit/ai/prompt.py +++ b/py/packages/genkit/src/genkit/ai/prompt.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional, Any -from genkit.core.schemas import GenerateRequest +from collections.abc import Callable +from typing import Any +from genkit.core.schemas import GenerateRequest -PromptFn = Callable[[Optional[Any]], GenerateRequest] +PromptFn = Callable[[Any | None], GenerateRequest] diff --git a/py/packages/genkit/src/genkit/core/action.py b/py/packages/genkit/src/genkit/core/action.py index 4c4f6a19d..633a1af8c 100644 --- a/py/packages/genkit/src/genkit/core/action.py +++ b/py/packages/genkit/src/genkit/core/action.py @@ -1,19 +1,20 @@ # Copyright 2025 Google LLC # SPDX-License-Identifier: Apache-2. 
+ import inspect import json - -from typing import Dict, Optional, Callable, Any -from pydantic import ConfigDict, BaseModel, TypeAdapter +from collections.abc import Callable +from typing import Any from genkit.core.tracing import tracer +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter class ActionResponse(BaseModel): model_config = ConfigDict(extra='forbid') response: Any - traceId: str + trace_id: str = Field(alias='traceId') class Action: @@ -26,9 +27,9 @@ def __init__( action_type: str, name: str, fn: Callable, - description: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - span_metadata: Optional[Dict[str, str]] = None, + description: str | None = None, + metadata: dict[str, Any] | None = None, + span_metadata: dict[str, str] | None = None, ): # TODO(Tatsiana Havina): separate a long constructor into methods. self.type = action_type diff --git a/py/packages/genkit/src/genkit/core/reflection.py b/py/packages/genkit/src/genkit/core/reflection.py index 2aff3a239..9949002d1 100644 --- a/py/packages/genkit/src/genkit/core/reflection.py +++ b/py/packages/genkit/src/genkit/core/reflection.py @@ -5,11 +5,10 @@ """Exposes an API for inspecting and interacting with Genkit in development.""" import json - from http.server import BaseHTTPRequestHandler -from pydantic import BaseModel from genkit.core.registry import Registry +from pydantic import BaseModel def make_reflection_server(registry: Registry): @@ -20,7 +19,7 @@ class ReflectionServer(BaseHTTPRequestHandler): ENCODING = 'utf-8' - def do_GET(self): + def do_GET(self) -> None: # noqa: N802 """Handles GET requests.""" if self.path == '/api/__health': self.send_response(200) @@ -49,7 +48,7 @@ def do_GET(self): self.send_response(404) self.end_headers() - def do_POST(self): + def do_POST(self) -> None: # noqa: N802 """Handles POST requests.""" if self.path == '/api/notify': self.send_response(200) diff --git a/py/packages/genkit/src/genkit/core/registry.py 
b/py/packages/genkit/src/genkit/core/registry.py index 823622adf..4f09d36b1 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -4,14 +4,13 @@ """The registry is used to store and lookup resources.""" -from typing import Dict from genkit.core.action import Action class Registry: """Stores actions, trace stores, flow state stores, plugins, and schemas.""" - actions: Dict[str, Dict[str, Action]] = {} + actions: dict[str, dict[str, Action]] = {} def register_action(self, action_type: str, name: str, action: Action): if action_type not in self.actions: diff --git a/py/packages/genkit/src/genkit/core/schemas.py b/py/packages/genkit/src/genkit/core/schemas.py index c65f00d72..7c6f78d89 100644 --- a/py/packages/genkit/src/genkit/core/schemas.py +++ b/py/packages/genkit/src/genkit/core/schemas.py @@ -3,8 +3,10 @@ # DO NOT EDIT: Generated by `generate_schema_types` from `genkit-schemas.json`. from __future__ import annotations + from enum import Enum from typing import Any + from pydantic import BaseModel, ConfigDict, Field, RootModel @@ -16,15 +18,15 @@ class InstrumentationLibrary(BaseModel): model_config = ConfigDict(extra='forbid') name: str version: str | None = None - schemaUrl: str | None = None + schema_url: str | None = Field(None, alias='schemaUrl') class SpanContext(BaseModel): model_config = ConfigDict(extra='forbid') - traceId: str - spanId: str - isRemote: bool | None = None - traceFlags: float + trace_id: str = Field(..., alias='traceId') + span_id: str = Field(..., alias='spanId') + is_remote: bool | None = Field(None, alias='isRemote') + trace_flags: float = Field(..., alias='traceFlags') class SameProcessAsParentSpan(BaseModel): @@ -43,7 +45,7 @@ class SpanMetadata(BaseModel): state: State | None = None input: Any | None = None output: Any | None = None - isRoot: bool | None = None + is_root: bool | None = Field(None, alias='isRoot') metadata: dict[str, str] | None = None @@ -82,8 +84,8 @@ 
class DataPart(BaseModel): model_config = ConfigDict(extra='forbid') text: Any | None = None media: Any | None = None - toolRequest: Any | None = None - toolResponse: Any | None = None + tool_request: Any | None = Field(None, alias='toolRequest') + tool_response: Any | None = Field(None, alias='toolResponse') data: Any | None = None metadata: dict[str, Any] | None = None @@ -105,7 +107,7 @@ class Content(BaseModel): class Media(BaseModel): model_config = ConfigDict(extra='forbid') - contentType: str | None = None + content_type: str | None = Field(None, alias='contentType') url: str @@ -130,9 +132,9 @@ class ToolChoice(Enum): class Output(BaseModel): model_config = ConfigDict(extra='forbid') format: str | None = None - contentType: str | None = None + content_type: str | None = Field(None, alias='contentType') instructions: bool | str | None = None - jsonSchema: Any | None = None + json_schema: Any | None = Field(None, alias='jsonSchema') constrained: bool | None = None @@ -152,25 +154,25 @@ class GenerationCommonConfig(BaseModel): model_config = ConfigDict(extra='forbid') version: str | None = None temperature: float | None = None - maxOutputTokens: float | None = None - topK: float | None = None - topP: float | None = None - stopSequences: list[str] | None = None + max_output_tokens: float | None = Field(None, alias='maxOutputTokens') + top_k: float | None = Field(None, alias='topK') + top_p: float | None = Field(None, alias='topP') + stop_sequences: list[str] | None = Field(None, alias='stopSequences') class GenerationUsage(BaseModel): model_config = ConfigDict(extra='forbid') - inputTokens: float | None = None - outputTokens: float | None = None - totalTokens: float | None = None - inputCharacters: float | None = None - outputCharacters: float | None = None - inputImages: float | None = None - outputImages: float | None = None - inputVideos: float | None = None - outputVideos: float | None = None - inputAudioFiles: float | None = None - outputAudioFiles: float 
| None = None + input_tokens: float | None = Field(None, alias='inputTokens') + output_tokens: float | None = Field(None, alias='outputTokens') + total_tokens: float | None = Field(None, alias='totalTokens') + input_characters: float | None = Field(None, alias='inputCharacters') + output_characters: float | None = Field(None, alias='outputCharacters') + input_images: float | None = Field(None, alias='inputImages') + output_images: float | None = Field(None, alias='outputImages') + input_videos: float | None = Field(None, alias='inputVideos') + output_videos: float | None = Field(None, alias='outputVideos') + input_audio_files: float | None = Field(None, alias='inputAudioFiles') + output_audio_files: float | None = Field(None, alias='outputAudioFiles') custom: dict[str, float] | None = None @@ -185,12 +187,12 @@ class Supports(BaseModel): multiturn: bool | None = None media: bool | None = None tools: bool | None = None - systemRole: bool | None = None + system_role: bool | None = Field(None, alias='systemRole') output: list[str] | None = None - contentType: list[str] | None = None + content_type: list[str] | None = Field(None, alias='contentType') context: bool | None = None constrained: Constrained | None = None - toolChoice: bool | None = None + tool_choice: bool | None = Field(None, alias='toolChoice') class ModelInfo(BaseModel): @@ -211,11 +213,15 @@ class ToolDefinition(BaseModel): model_config = ConfigDict(extra='forbid') name: str description: str - inputSchema: dict[str, Any] = Field( - ..., description='Valid JSON Schema representing the input of the tool.' + input_schema: dict[str, Any] = Field( + ..., + alias='inputSchema', + description='Valid JSON Schema representing the input of the tool.', ) - outputSchema: dict[str, Any] | None = Field( - None, description='Valid JSON Schema describing the output of the tool.' 
+ output_schema: dict[str, Any] | None = Field( + None, + alias='outputSchema', + description='Valid JSON Schema describing the output of the tool.', ) metadata: dict[str, Any] | None = Field( None, description='additional metadata for this tool definition' @@ -264,7 +270,7 @@ class Content2(BaseModel): class Media2(BaseModel): model_config = ConfigDict(extra='forbid') - contentType: str | None = None + content_type: str | None = Field(None, alias='contentType') url: str @@ -326,38 +332,44 @@ class Link(BaseModel): model_config = ConfigDict(extra='forbid') context: SpanContext | None = None attributes: dict[str, Any] | None = None - droppedAttributesCount: float | None = None + dropped_attributes_count: float | None = Field( + None, alias='droppedAttributesCount' + ) class TimeEvents(BaseModel): model_config = ConfigDict(extra='forbid') - timeEvent: list[TimeEvent] | None = None + time_event: list[TimeEvent] | None = Field(None, alias='timeEvent') class SpanData(BaseModel): model_config = ConfigDict(extra='forbid') - spanId: str - traceId: str - parentSpanId: str | None = None - startTime: float - endTime: float + span_id: str = Field(..., alias='spanId') + trace_id: str = Field(..., alias='traceId') + parent_span_id: str | None = Field(None, alias='parentSpanId') + start_time: float = Field(..., alias='startTime') + end_time: float = Field(..., alias='endTime') attributes: dict[str, Any] - displayName: str + display_name: str = Field(..., alias='displayName') links: list[Link] | None = None - instrumentationLibrary: InstrumentationLibrary - spanKind: str - sameProcessAsParentSpan: SameProcessAsParentSpan | None = None + instrumentation_library: InstrumentationLibrary = Field( + ..., alias='instrumentationLibrary' + ) + span_kind: str = Field(..., alias='spanKind') + same_process_as_parent_span: SameProcessAsParentSpan | None = Field( + None, alias='sameProcessAsParentSpan' + ) status: SpanStatus | None = None - timeEvents: TimeEvents | None = None + time_events: 
TimeEvents | None = Field(None, alias='timeEvents') truncated: bool | None = None class TraceData(BaseModel): model_config = ConfigDict(extra='forbid') - traceId: str - displayName: str | None = None - startTime: float | None = None - endTime: float | None = None + trace_id: str = Field(..., alias='traceId') + display_name: str | None = Field(None, alias='displayName') + start_time: float | None = Field(None, alias='startTime') + end_time: float | None = Field(None, alias='endTime') spans: dict[str, SpanData] @@ -365,8 +377,8 @@ class MediaPart(BaseModel): model_config = ConfigDict(extra='forbid') text: Text | None = None media: Media - toolRequest: ToolRequest | None = None - toolResponse: ToolResponse | None = None + tool_request: ToolRequest | None = Field(None, alias='toolRequest') + tool_response: ToolResponse | None = Field(None, alias='toolResponse') data: Any | None = None metadata: Metadata | None = None @@ -375,8 +387,8 @@ class TextPart(BaseModel): model_config = ConfigDict(extra='forbid') text: str media: MediaModel | None = None - toolRequest: ToolRequest | None = None - toolResponse: ToolResponse | None = None + tool_request: ToolRequest | None = Field(None, alias='toolRequest') + tool_response: ToolResponse | None = Field(None, alias='toolResponse') data: Data | None = None metadata: Metadata | None = None @@ -385,8 +397,8 @@ class ToolRequestPart(BaseModel): model_config = ConfigDict(extra='forbid') text: Text | None = None media: MediaModel | None = None - toolRequest: ToolRequest1 - toolResponse: ToolResponse | None = None + tool_request: ToolRequest1 = Field(..., alias='toolRequest') + tool_response: ToolResponse | None = Field(None, alias='toolResponse') data: Data | None = None metadata: Metadata | None = None @@ -395,8 +407,8 @@ class ToolResponsePart(BaseModel): model_config = ConfigDict(extra='forbid') text: Text | None = None media: MediaModel | None = None - toolRequest: ToolRequest | None = None - toolResponse: ToolResponse1 + 
tool_request: ToolRequest | None = Field(None, alias='toolRequest') + tool_response: ToolResponse1 = Field(..., alias='toolResponse') data: Data | None = None metadata: Metadata | None = None @@ -453,8 +465,8 @@ class Candidate(BaseModel): index: float message: Message usage: GenerationUsage | None = None - finishReason: FinishReason - finishMessage: str | None = None + finish_reason: FinishReason = Field(..., alias='finishReason') + finish_message: str | None = Field(None, alias='finishMessage') custom: Any | None = None @@ -464,11 +476,11 @@ class GenerateActionOptions(BaseModel): docs: list[Doc] | None = None messages: list[Message] tools: list[str] | None = None - toolChoice: ToolChoice | None = None + tool_choice: ToolChoice | None = Field(None, alias='toolChoice') config: Any | None = None output: Output | None = None - returnToolRequests: bool | None = None - maxTurns: float | None = None + return_tool_requests: bool | None = Field(None, alias='returnToolRequests') + max_turns: float | None = Field(None, alias='maxTurns') class GenerateRequest(BaseModel): @@ -476,7 +488,7 @@ class GenerateRequest(BaseModel): messages: list[Message] config: Any | None = None tools: list[ToolDefinition] | None = None - toolChoice: ToolChoice | None = None + tool_choice: ToolChoice | None = Field(None, alias='toolChoice') output: Output1 | None = None context: list[Items] | None = None candidates: float | None = None @@ -485,9 +497,9 @@ class GenerateRequest(BaseModel): class GenerateResponse(BaseModel): model_config = ConfigDict(extra='forbid') message: Message | None = None - finishReason: FinishReason | None = None - finishMessage: str | None = None - latencyMs: float | None = None + finish_reason: FinishReason | None = Field(None, alias='finishReason') + finish_message: str | None = Field(None, alias='finishMessage') + latency_ms: float | None = Field(None, alias='latencyMs') usage: GenerationUsage | None = None custom: Any | None = None request: GenerateRequest | None = 
None @@ -499,7 +511,7 @@ class ModelRequest(BaseModel): messages: Messages config: Config | None = None tools: Tools | None = None - toolChoice: ToolChoice | None = None + tool_choice: ToolChoice | None = Field(None, alias='toolChoice') output: OutputModel | None = None context: list[Items] | None = None @@ -511,9 +523,9 @@ class Request(RootModel[GenerateRequest]): class ModelResponse(BaseModel): model_config = ConfigDict(extra='forbid') message: Message | None = None - finishReason: FinishReason - finishMessage: FinishMessage | None = None - latencyMs: LatencyMs | None = None + finish_reason: FinishReason = Field(..., alias='finishReason') + finish_message: FinishMessage | None = Field(None, alias='finishMessage') + latency_ms: LatencyMs | None = Field(None, alias='latencyMs') usage: Usage | None = None custom: Custom | None = None request: Request | None = None diff --git a/py/packages/genkit/src/genkit/core/tracing.py b/py/packages/genkit/src/genkit/core/tracing.py index 2895a75f9..29c547be8 100644 --- a/py/packages/genkit/src/genkit/core/tracing.py +++ b/py/packages/genkit/src/genkit/core/tracing.py @@ -7,17 +7,17 @@ import json import os import sys -from typing import Any, Dict, Sequence +from collections.abc import Sequence +from typing import Any import requests -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry import trace as trace_api +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export import ( + SimpleSpanProcessor, SpanExporter, SpanExportResult, - SimpleSpanProcessor, ) -from opentelemetry import trace as trace_api -from opentelemetry.sdk.trace import ReadableSpan class TelemetryServerSpanExporter(SpanExporter): @@ -71,7 +71,8 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: span_data['startTime'] = span.start_time span_data['endTime'] = span.end_time - # TODO: telemetry server URL must be dynamic, whatever tools notification says + # TODO: telemetry server 
URL must be dynamic, whatever tools + # notification says requests.post( 'http://localhost:4033/api/traces', data=json.dumps(span_data), @@ -88,8 +89,8 @@ def force_flush(self, timeout_millis: int = 30000) -> bool: return True -def convert_attributes(attributes: Dict[str, Any]) -> Dict[str, Any]: - attrs: Dict[str, Any] = {} +def convert_attributes(attributes: dict[str, Any]) -> dict[str, Any]: + attrs: dict[str, Any] = {} for key in attributes: attrs[key] = attributes[key] return attrs diff --git a/py/packages/genkit/src/genkit/veneer/veneer.py b/py/packages/genkit/src/genkit/veneer/veneer.py index b5c26d311..a2d2d4217 100644 --- a/py/packages/genkit/src/genkit/veneer/veneer.py +++ b/py/packages/genkit/src/genkit/veneer/veneer.py @@ -9,15 +9,15 @@ import json import os import threading - +from collections.abc import Callable from http.server import HTTPServer -from typing import Union, List, Dict, Optional, Callable, Any +from typing import Any from genkit.ai.model import ModelFn from genkit.ai.prompt import PromptFn +from genkit.core.action import Action from genkit.core.reflection import make_reflection_server from genkit.core.registry import Registry -from genkit.core.action import Action from genkit.core.schemas import GenerateRequest, GenerateResponse, Message Plugin = Callable[['Genkit'], None] @@ -33,8 +33,8 @@ class Genkit: def __init__( self, - plugins: Optional[List[Plugin]] = None, - model: Optional[str] = None, + plugins: list[Plugin] | None = None, + model: str | None = None, ) -> None: self.model = model if os.getenv('GENKIT_ENV') == 'dev': @@ -77,11 +77,11 @@ def start_server(self) -> None: def generate( self, - model: Optional[str] = None, - prompt: Optional[Union[str]] = None, - messages: Optional[List[Message]] = None, - system: Optional[Union[str]] = None, - tools: Optional[List[str]] = None, + model: str | None = None, + prompt: str | None = None, + messages: list[Message] | None = None, + system: str | None = None, + tools: list[str] | None = 
None, ) -> GenerateResponse: model = model if model is not None else self.model if model is None: @@ -91,9 +91,7 @@ def generate( return model_action.fn(GenerateRequest(messages=messages)).response - def flow( - self, name: Optional[str] = None - ) -> Callable[[Callable], Callable]: + def flow(self, name: str | None = None) -> Callable[[Callable], Callable]: def wrapper(func: Callable) -> Callable: flow_name = name if name is not None else func.__name__ action = Action( @@ -117,7 +115,7 @@ def define_model( self, name: str, fn: ModelFn, - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> None: action = Action( name=name, action_type=self.MODEL, fn=fn, metadata=metadata @@ -128,16 +126,16 @@ def define_prompt( self, name: str, fn: PromptFn, - model: Optional[str] = None, - ) -> Callable[[Optional[Any]], GenerateResponse]: - def prompt(input_prompt: Optional[Any] = None) -> GenerateResponse: + model: str | None = None, + ) -> Callable[[Any | None], GenerateResponse]: + def prompt(input_prompt: Any | None = None) -> GenerateResponse: req = fn(input_prompt) return self.generate(messages=req.messages, model=model) action = Action(self.MODEL, name, prompt) self.registry.register_action(self.MODEL, name, action) - def wrapper(input_prompt: Optional[Any] = None) -> GenerateResponse: + def wrapper(input_prompt: Any | None = None) -> GenerateResponse: return action.fn(input_prompt) return wrapper diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py index 5f3ba44a8..a6fc0e476 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py @@ -6,11 +6,9 @@ Google Cloud Vertex AI Plugin for Genkit. 
""" -import vertexai - -from typing import Callable, Optional -from vertexai.generative_models import GenerativeModel, Content, Part +from collections.abc import Callable +import vertexai from genkit.core.schemas import ( GenerateRequest, GenerateResponse, @@ -18,28 +16,31 @@ TextPart, ) from genkit.veneer.veneer import Genkit +from vertexai.generative_models import Content, GenerativeModel, Part def package_name() -> str: return 'genkit.plugins.vertex_ai' -def vertexAI(project_id: Optional[str] = None) -> Callable[[Genkit], None]: +def vertex_ai(project_id: str | None = None) -> Callable[[Genkit], None]: def plugin(ai: Genkit) -> None: vertexai.init(location='us-central1', project=project_id) def gemini(request: GenerateRequest) -> GenerateResponse: - geminiMsgs: list[Content] = [] + gemini_msgs: list[Content] = [] for m in request.messages: - geminiParts: list[Part] = [] + gemini_parts: list[Part] = [] for p in m.content: if p.root.text is not None: - geminiParts.append(Part.from_text(p.root.text)) + gemini_parts.append(Part.from_text(p.root.text)) else: raise Exception('unsupported part type') - geminiMsgs.append(Content(role=m.role.value, parts=geminiParts)) + gemini_msgs.append( + Content(role=m.role.value, parts=gemini_parts) + ) model = GenerativeModel('gemini-1.5-flash-002') - response = model.generate_content(contents=geminiMsgs) + response = model.generate_content(contents=gemini_msgs) return GenerateResponse( message=Message( role='model', content=[TextPart(text=response.text)] @@ -64,4 +65,4 @@ def gemini(name: str) -> str: return f'vertexai/{name}' -__all__ = ['package_name', 'vertexAI', 'gemini'] +__all__ = ['package_name', 'vertex_ai', 'gemini'] diff --git a/py/pyproject.toml b/py/pyproject.toml index 1a1f15b45..59e594a72 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -75,12 +75,56 @@ members = ["packages/*", "plugins/*", "samples/*"] # Ruff checks and formatting. 
[tool.ruff] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "bazel-*", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] indent-width = 4 -line-length = 80 +line-length = 80 +target-version = "py312" + +[tool.ruff.lint] +fixable = ["ALL"] +select = [ + "E", # pycodestyle (errors) + "W", # pycodestyle (warnings) + "F", # pyflakes + "I", # isort (import sorting) + "UP", # pyupgrade (Python version upgrades) + "B", # flake8-bugbear (common bugs) + "N", # pep8-naming (naming conventions) +] [tool.ruff.format] -line-ending = "lf" -quote-style = "single" +indent-style = "space" +line-ending = "lf" +quote-style = "single" +skip-magic-trailing-comma = false # Static type checking. [tool.mypy] @@ -89,8 +133,10 @@ disallow_untyped_defs = true warn_unused_configs = true [tool.datamodel-codegen] -#strict-types = ["str", "int", "float", "bool", "bytes"] # Don't use; produces StrictStr, StrictInt, etc. #collapse-root-models = true # Don't use; produces Any as types. +#strict-types = ["str", "int", "float", "bool", "bytes"] # Don't use; produces StrictStr, StrictInt, etc. 
+#use-subclass-enum = true +capitalize-enum-members = true disable-timestamp = true enable-version-header = true field-constraints = true @@ -98,8 +144,10 @@ input = "../genkit-tools/genkit-schema.json" input-file-type = "jsonschema" output = "packages/genkit/src/genkit/core/schemas.py" output-model-type = "pydantic_v2.BaseModel" +snake-case-field = true strict-nullable = true target-python-version = "3.12" use-schema-description = true use-standard-collections = true use-union-operator = true +use-unique-items-as-set = true diff --git a/py/samples/hello/hello.py b/py/samples/hello/hello.py index 9a74fa4c4..0ac640592 100644 --- a/py/samples/hello/hello.py +++ b/py/samples/hello/hello.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 -from genkit.core.schemas import Message, TextPart, GenerateRequest -from genkit.plugins.vertex_ai import vertexAI, gemini +from genkit.core.schemas import GenerateRequest, Message, TextPart +from genkit.plugins.vertex_ai import gemini, vertex_ai from genkit.veneer.veneer import Genkit from pydantic import BaseModel, Field -ai = Genkit(plugins=[vertexAI()], model=gemini('gemini-1.5-flash')) +ai = Genkit(plugins=[vertex_ai()], model=gemini('gemini-1.5-flash')) class MyInput(BaseModel):