Skip to content

Commit

Permalink
fix: formatting, linting, and schema code generation #1935
Browse files Browse the repository at this point in the history
ISSUE: #1935

CHANGELOG:
- [ ] Addresses error handling and parallelized formatting for toml
  files.
- [ ] Configures the ruff formatter to use requested format.
- [ ] Fixes snake_case for several Python variables.
- [ ] Configures the schemas.py code generator to generate snake_case
  field names to avoid lint reports and wraps the actual name in Field()
  metadata.
- [ ] Clean up some lint.
- [ ] Removes go-vulncheck from the pre-commit hook while retaining it
  during pre-push (easy to skip this using `-f --no-verify`).
  • Loading branch information
yesudeep committed Feb 12, 2025
1 parent 5033df0 commit 8c29807
Show file tree
Hide file tree
Showing 18 changed files with 343 additions and 198 deletions.
91 changes: 51 additions & 40 deletions py/bin/format_toml_files
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,63 @@ set -euo pipefail

GIT_ROOT=$(git rev-parse --show-toplevel)

if command -v rust-parallel >/dev/null 2>&1; then
if command -v taplo >/dev/null 2>&1; then
if [ ! -f "${GIT_ROOT}/py/taplo.toml" ]; then
echo "error: config file not found at ${GIT_ROOT}/py/taplo.toml"
exit 1
fi

FORMATTER_COMMAND="taplo format --config ${GIT_ROOT}/py/taplo.toml"
if command -v rust-parallel >/dev/null 2>&1; then
FORMATTER_COMMAND="rust-parallel -j4 ${FORMATTER_COMMAND}"
fi

if command -v fd >/dev/null 2>&1; then
echo "Using fd"
fd -e toml \
--exclude 'py/**/*.egg-info/**' \
--exclude 'py/**/.dist/**' \
--exclude 'py/**/.next/**' \
--exclude 'py/**/.output/**' \
--exclude 'py/**/.pytest_cache/**' \
--exclude 'py/**/.venv/**' \
--exclude 'py/**/__pycache__/**' \
--exclude 'py/**/build/**' \
--exclude 'py/**/develop-eggs/**' \
--exclude 'py/**/dist/**' \
--exclude 'py/**/eggs/**' \
--exclude 'py/**/node_modules/**' \
--exclude 'py/**/sdist/**' \
--exclude 'py/**/site/**' \
--exclude 'py/**/target/**' \
--exclude 'py/**/venv/**' \
--exclude 'py/**/wheels/**' |
rust-parallel -j4 \
taplo format --config "${GIT_ROOT}/py/taplo.toml"
--exclude '**/*.egg-info/**' \
--exclude '**/.dist/**' \
--exclude '**/.next/**' \
--exclude '**/.output/**' \
--exclude '**/.pytest_cache/**' \
--exclude '**/.venv/**' \
--exclude '**/__pycache__/**' \
--exclude '**/bazel-*/**' \
--exclude '**/build/**' \
--exclude '**/develop-eggs/**' \
--exclude '**/dist/**' \
--exclude '**/eggs/**' \
--exclude '**/node_modules/**' \
--exclude '**/sdist/**' \
--exclude '**/site/**' \
--exclude '**/target/**' \
--exclude '**/venv/**' \
--exclude '**/wheels/**' |
${FORMATTER_COMMAND}
else
echo "Using find"
find "${GIT_ROOT}" -name "*.toml" \
! -path 'py/**/*.egg-info/**' \
! -path 'py/**/.dist/**' \
! -path 'py/**/.next/**' \
! -path 'py/**/.output/**' \
! -path 'py/**/.pytest_cache/**' \
! -path 'py/**/.venv/**' \
! -path 'py/**/__pycache__/**' \
! -path 'py/**/build/**' \
! -path 'py/**/develop-eggs/**' \
! -path 'py/**/dist/**' \
! -path 'py/**/eggs/**' \
! -path 'py/**/node_modules/**' \
! -path 'py/**/sdist/**' \
! -path 'py/**/site/**' \
! -path 'py/**/target/**' \
! -path 'py/**/venv/**' \
! -path 'py/**/wheels/**' \
! -path '**/*.egg-info/**' \
! -path '**/.dist/**' \
! -path '**/.next/**' \
! -path '**/.output/**' \
! -path '**/.pytest_cache/**' \
! -path '**/.venv/**' \
! -path '**/__pycache__/**' \
! -path '**/bazel-*/**' \
! -path '**/build/**' \
! -path '**/develop-eggs/**' \
! -path '**/dist/**' \
! -path '**/eggs/**' \
! -path '**/node_modules/**' \
! -path '**/sdist/**' \
! -path '**/site/**' \
! -path '**/target/**' \
! -path '**/venv/**' \
! -path '**/wheels/**' \
-print0 |
rust-parallel -j4 \
taplo format --config "${GIT_ROOT}/py/taplo.toml"
${FORMATTER_COMMAND}
fi
else
echo "Please install GNU parallel to use this script"
echo "Please install taplo to use this script"
fi
69 changes: 69 additions & 0 deletions py/bin/format_toml_files_filtered_serial
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env bash
#
# Format all TOML files in the project.
#
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail

GIT_ROOT=$(git rev-parse --show-toplevel)

if command -v taplo >/dev/null 2>&1; then
if [ ! -f "${GIT_ROOT}/py/taplo.toml" ]; then
echo "error: config file not found at ${GIT_ROOT}/py/taplo.toml"
exit 1
fi

FORMATTER_COMMAND="taplo format --config ${GIT_ROOT}/py/taplo.toml"
#if command -v rust-parallel >/dev/null 2>&1; then
# FORMATTER_COMMAND="rust-parallel -j4 ${FORMATTER_COMMAND}"
#fi

if command -v fd >/dev/null 2>&1; then
echo "Using fd"
fd -e toml \
--exclude 'py/**/*.egg-info/**' \
--exclude 'py/**/.dist/**' \
--exclude 'py/**/.next/**' \
--exclude 'py/**/.output/**' \
--exclude 'py/**/.pytest_cache/**' \
--exclude 'py/**/.venv/**' \
--exclude 'py/**/__pycache__/**' \
--exclude 'py/**/build/**' \
--exclude 'py/**/develop-eggs/**' \
--exclude 'py/**/dist/**' \
--exclude 'py/**/eggs/**' \
--exclude 'py/**/node_modules/**' \
--exclude 'py/**/sdist/**' \
--exclude 'py/**/site/**' \
--exclude 'py/**/target/**' \
--exclude 'py/**/venv/**' \
--exclude 'py/**/wheels/**' |
${FORMATTER_COMMAND}
else
echo "Using find"
find "${GIT_ROOT}" -name "*.toml" \
! -path 'py/**/*.egg-info/**' \
! -path 'py/**/.dist/**' \
! -path 'py/**/.next/**' \
! -path 'py/**/.output/**' \
! -path 'py/**/.pytest_cache/**' \
! -path 'py/**/.venv/**' \
! -path 'py/**/__pycache__/**' \
! -path 'py/**/build/**' \
! -path 'py/**/develop-eggs/**' \
! -path 'py/**/dist/**' \
! -path 'py/**/eggs/**' \
! -path 'py/**/node_modules/**' \
! -path 'py/**/sdist/**' \
! -path 'py/**/site/**' \
! -path 'py/**/target/**' \
! -path 'py/**/venv/**' \
! -path 'py/**/wheels/**' \
-print0 |
${FORMATTER_COMMAND}
fi
else
echo "Please install taplo to use this script"
fi
14 changes: 14 additions & 0 deletions py/bin/format_toml_files_serial
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env bash
#
# Format all TOML files in the project.
#
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

set -euo pipefail

GIT_ROOT=$(git rev-parse --show-toplevel)

taplo format --config "${GIT_ROOT}/py/taplo.toml"

exit $?
20 changes: 10 additions & 10 deletions py/bin/generate_schema_types
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,30 @@
set -euo pipefail

TOP_DIR=$(git rev-parse --show-toplevel)
SCHEMA_FILE="$TOP_DIR/py/packages/genkit/src/genkit/core/schemas.py"
SCHEMA_FILE="${TOP_DIR}/py/packages/genkit/src/genkit/core/schemas.py"

# Generate types using configuration from pyproject.toml
uv run --directory "$TOP_DIR/py" datamodel-codegen
uv run --directory "${TOP_DIR}/py" datamodel-codegen

# This isn't causing runtime errors at the moment so letting it be.
#sed -i '' '/^class Model(RootModel\[Any\]):$/,/^ root: Any$/d' "$SCHEMA_FILE"
#sed -i '' '/^class Model(RootModel\[Any\]):$/,/^ root: Any$/d' "${SCHEMA_FILE}"

# Sanitize the generated schema.
python3 "${TOP_DIR}/py/bin/sanitize_schemas.py" "$SCHEMA_FILE"
python3 "${TOP_DIR}/py/bin/sanitize_schemas.py" "${SCHEMA_FILE}"

# Add a generated by `generate_schema_types` comment.
sed -i '' '1i\
# DO NOT EDIT: Generated by `generate_schema_types` from `genkit-schemas.json`.
' "$SCHEMA_FILE"
' "${SCHEMA_FILE}"

# Add license header.
addlicense \
-c "Google LLC" \
-s=only \
"$SCHEMA_FILE"
"${SCHEMA_FILE}"

# Checks and formatting.
uv run --directory "$TOP_DIR/py" \
ruff check --fix "$SCHEMA_FILE"
uv run --directory "$TOP_DIR/py" \
ruff format "$SCHEMA_FILE"
uv run --directory "${TOP_DIR}/py" \
ruff format "${TOP_DIR}"
uv run --directory "${TOP_DIR}/py" \
ruff check --fix "${SCHEMA_FILE}"
7 changes: 4 additions & 3 deletions py/bin/sanitize_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def is_rootmodel_class(self, node: ast.ClassDef) -> bool:
return True
return False

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Visit class definitions and remove model_config if class inherits from RootModel."""
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # noqa: N802
"""Visit class definitions and remove model_config if class
inherits from RootModel."""
if self.is_rootmodel_class(node):
# Filter out model_config assignments
new_body = []
Expand All @@ -48,7 +49,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:

def process_file(filename: str) -> None:
"""Process a Python file to remove model_config from RootModel classes."""
with open(filename, 'r') as f:
with open(filename) as f:
source = f.read()

tree = ast.parse(source)
Expand Down
2 changes: 1 addition & 1 deletion py/bin/setup
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ function genkit::install_google_cloud_sdk() {
# This depends on Python 3.11 and installs it for the user on some systems.
if command -v gcloud &>/dev/null; then
gcloud config set disable_usage_reporting true
gcloud components update
yes | gcloud components update
return 0
fi

Expand Down
13 changes: 1 addition & 12 deletions py/captainhook.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,6 @@
},
{
"run": "go test go/..."
},
{
"run": "govulncheck -C go ./...",
"conditions": [
{
"run": "CaptainHook::FileChanged.Any",
"options": {
"files": ["go/go.mod", "go.sum", "*.go"]
}
}
]
}
]
},
Expand Down Expand Up @@ -111,7 +100,7 @@
{
"run": "CaptainHook::FileChanged.Any",
"options": {
"files": ["go/go.mod", "go.sum", "*.go"]
"files": ["go/**/*.go", "go/**/go.mod", "go/**/go.sum"]
}
}
]
Expand Down
3 changes: 2 additions & 1 deletion py/packages/genkit/src/genkit/ai/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from typing import Callable
from collections.abc import Callable

from genkit.core.schemas import GenerateRequest, GenerateResponse

ModelFn = Callable[[GenerateRequest], GenerateResponse]
7 changes: 4 additions & 3 deletions py/packages/genkit/src/genkit/ai/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0


from typing import Callable, Optional, Any
from genkit.core.schemas import GenerateRequest
from collections.abc import Callable
from typing import Any

from genkit.core.schemas import GenerateRequest

PromptFn = Callable[[Optional[Any]], GenerateRequest]
PromptFn = Callable[[Any | None], GenerateRequest]
15 changes: 8 additions & 7 deletions py/packages/genkit/src/genkit/core/action.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.

import inspect
import json

from typing import Dict, Optional, Callable, Any
from pydantic import ConfigDict, BaseModel, TypeAdapter
from collections.abc import Callable
from typing import Any

from genkit.core.tracing import tracer
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter


class ActionResponse(BaseModel):
model_config = ConfigDict(extra='forbid')

response: Any
traceId: str
trace_id: str = Field(alias='traceId')


class Action:
Expand All @@ -26,9 +27,9 @@ def __init__(
action_type: str,
name: str,
fn: Callable,
description: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
span_metadata: Optional[Dict[str, str]] = None,
description: str | None = None,
metadata: dict[str, Any] | None = None,
span_metadata: dict[str, str] | None = None,
):
# TODO(Tatsiana Havina): separate a long constructor into methods.
self.type = action_type
Expand Down
7 changes: 3 additions & 4 deletions py/packages/genkit/src/genkit/core/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
"""Exposes an API for inspecting and interacting with Genkit in development."""

import json

from http.server import BaseHTTPRequestHandler
from pydantic import BaseModel

from genkit.core.registry import Registry
from pydantic import BaseModel


def make_reflection_server(registry: Registry):
Expand All @@ -20,7 +19,7 @@ class ReflectionServer(BaseHTTPRequestHandler):

ENCODING = 'utf-8'

def do_GET(self):
def do_GET(self) -> None: # noqa: N802
"""Handles GET requests."""
if self.path == '/api/__health':
self.send_response(200)
Expand Down Expand Up @@ -49,7 +48,7 @@ def do_GET(self):
self.send_response(404)
self.end_headers()

def do_POST(self):
def do_POST(self) -> None: # noqa: N802
"""Handles POST requests."""
if self.path == '/api/notify':
self.send_response(200)
Expand Down
3 changes: 1 addition & 2 deletions py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

"""The registry is used to store and lookup resources."""

from typing import Dict
from genkit.core.action import Action


class Registry:
"""Stores actions, trace stores, flow state stores, plugins, and schemas."""

actions: Dict[str, Dict[str, Action]] = {}
actions: dict[str, dict[str, Action]] = {}

def register_action(self, action_type: str, name: str, action: Action):
if action_type not in self.actions:
Expand Down
Loading

0 comments on commit 8c29807

Please sign in to comment.