From f2d9cb06ee4dafaa07e9966c8da9db95bfdb54ed Mon Sep 17 00:00:00 2001
From: Pavel Jbanov
Date: Tue, 18 Feb 2025 17:08:10 -0500
Subject: [PATCH] fix(js/plugins): fixed gemini tool schema converter to correctly handle descriptions and nullable fields

convertSchemaProperty in the googleai and vertexai plugins previously
emitted only `type` (plus `properties`/`required`/`items`), so
descriptions, enums and nullability were dropped from the converted
schema.

* copy `description`, `enum` and `nullable` onto every converted schema
* treat JSON schema type arrays such as ['string', 'null'] as nullable
  and convert the first non-null member
* return undefined instead of null for missing or untyped properties
* throw GenkitError(INVALID_ARGUMENT) for types that are missing from
  SchemaType / FunctionDeclarationSchemaType, including a type array
  that contains only 'null'
* export toGeminiTool (marked @hidden) so it can be unit tested
* flow-simple-ai: add the characterGenerator tool and the
  toolCallerCharacterGenerator flow, run forcedToolCaller on the
  Vertex AI gemini15Flash model, and set the model explicitly on the
  formatJson flows (formatJson now uses constrained output)
---
 js/plugins/googleai/src/gemini.ts        | 44 +++++++++++++++--
 js/plugins/googleai/tests/gemini_test.ts | 62 +++++++++++++++++++++++-
 js/plugins/vertexai/src/gemini.ts        | 58 +++++++++++++++++++---
 js/plugins/vertexai/tests/gemini_test.ts | 62 +++++++++++++++++++++++-
 js/testapps/flow-simple-ai/src/index.ts  | 60 +++++++++++++++++++++--
 5 files changed, 267 insertions(+), 19 deletions(-)

diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts
index b758ed46e..68179a765 100644
--- a/js/plugins/googleai/src/gemini.ts
+++ b/js/plugins/googleai/src/gemini.ts
@@ -29,6 +29,7 @@ import {
   GoogleGenerativeAI,
   InlineDataPart,
   RequestOptions,
+  Schema,
   SchemaType,
   StartChatParams,
   Tool,
@@ -330,31 +331,64 @@ function toGeminiRole(
 
 function convertSchemaProperty(property) {
   if (!property || !property.type) {
-    return null;
+    return undefined;
+  }
+  const baseSchema = {} as Schema;
+  if (property.description) {
+    baseSchema.description = property.description;
+  }
+  if (property.enum) {
+    baseSchema.enum = property.enum;
+  }
+  if (property.nullable) {
+    baseSchema.nullable = property.nullable;
   }
-  if (property.type === 'object') {
+  let propertyType;
+  // nullable schema can ALSO be defined as, for example, type=['string','null']
+  if (Array.isArray(property.type)) {
+    const types = property.type as string[];
+    if (types.includes('null')) {
+      baseSchema.nullable = true;
+    }
+    // grab the type that's not `null`
+    propertyType = types.find((t) => t !== 'null');
+  } else {
+    propertyType = property.type;
+  }
+  if (propertyType === 'object') {
     const nestedProperties = {};
     Object.keys(property.properties).forEach((key) => {
       nestedProperties[key] = convertSchemaProperty(property.properties[key]);
     });
     return {
+      ...baseSchema,
       type: SchemaType.OBJECT,
       properties: nestedProperties,
       required: property.required,
     };
-  } else if (property.type === 'array') {
+  } else if (propertyType === 'array') {
     return {
+      ...baseSchema,
       type: SchemaType.ARRAY,
       items: convertSchemaProperty(property.items),
     };
   } else {
+    const schemaType = SchemaType[propertyType?.toUpperCase()] as SchemaType;
+    if (!schemaType) {
+      throw new GenkitError({
+        status: 'INVALID_ARGUMENT',
+        message: `Unsupported property type ${propertyType?.toUpperCase()}`,
+      });
+    }
     return {
-      type: SchemaType[property.type.toUpperCase()],
+      ...baseSchema,
+      type: schemaType,
     };
   }
 }
 
-function toGeminiTool(
+/** @hidden */
+export function toGeminiTool(
   tool: z.infer
 ): FunctionDeclaration {
   const declaration: FunctionDeclaration = {
diff --git a/js/plugins/googleai/tests/gemini_test.ts b/js/plugins/googleai/tests/gemini_test.ts
index 0ce663c97..88b794435 100644
--- a/js/plugins/googleai/tests/gemini_test.ts
+++ b/js/plugins/googleai/tests/gemini_test.ts
@@ -16,8 +16,9 @@
 
 import { GenerateContentCandidate } from '@google/generative-ai';
 import * as assert from 'assert';
-import { genkit } from 'genkit';
+import { genkit, z } from 'genkit';
 import { MessageData, ModelInfo } from 'genkit/model';
+import { toJsonSchema } from 'genkit/schema';
 import { afterEach, beforeEach, describe, it } from 'node:test';
 import {
   GENERIC_GEMINI_MODEL,
@@ -28,6 +29,7 @@ import {
   gemini15Pro,
   toGeminiMessage,
   toGeminiSystemInstruction,
+  toGeminiTool,
 } from '../src/gemini.js';
 import { googleAI } from '../src/index.js';
 
@@ -501,6 +503,64 @@ describe('plugin', () => {
   });
 });
 
+describe('toGeminiTool', () => {
+  it('preserves descriptions, enums and nullable fields', async () => {
+    const got = toGeminiTool({
+      name: 'foo',
+      description: 'tool foo',
+      inputSchema: toJsonSchema({
+        schema: z.object({
+          simpleString: z.string().describe('a string').nullable(),
+          simpleNumber: z.number().describe('a number'),
+          simpleBoolean: z.boolean().describe('a boolean').optional(),
+          simpleArray: z.array(z.string()).describe('an array').optional(),
+          simpleEnum: z
+            .enum(['choice_a', 'choice_b'])
+            .describe('an enum')
+            .optional(),
+        }),
+      }),
+    });
+
+    const want = {
+      description: 'tool foo',
+      name: 'foo',
+      parameters: {
+        properties: {
+          simpleArray: {
+            description: 'an array',
+            items: {
+              type: 'string',
+            },
+            type: 'array',
+          },
+          simpleBoolean: {
+            description: 'a boolean',
+            type: 'boolean',
+          },
+          simpleEnum: {
+            description: 'an enum',
+            enum: ['choice_a', 'choice_b'],
+            type: 'string',
+          },
+          simpleNumber: {
+            description: 'a number',
+            type: 'number',
+          },
+          simpleString: {
+            description: 'a string',
+            nullable: true,
+            type: 'string',
+          },
+        },
+        required: ['simpleString', 'simpleNumber'],
+        type: 'object',
+      },
+    };
+    assert.deepStrictEqual(got, want);
+  });
+});
+
 function assertEqualModelInfo(
   modelAction: ModelInfo,
   expectedLabel: string,
diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts
index 3a5938329..00a605491 100644
--- a/js/plugins/vertexai/src/gemini.ts
+++ b/js/plugins/vertexai/src/gemini.ts
@@ -25,13 +25,20 @@ import {
   GenerativeModelPreview,
   HarmBlockThreshold,
   HarmCategory,
+  Schema,
   StartChatParams,
   ToolConfig,
   VertexAI,
   type GoogleSearchRetrieval,
 } from '@google-cloud/vertexai';
 import { ApiClient } from '@google-cloud/vertexai/build/src/resources/index.js';
-import { GENKIT_CLIENT_HEADER, Genkit, JSONSchema, z } from 'genkit';
+import {
+  GENKIT_CLIENT_HEADER,
+  Genkit,
+  GenkitError,
+  JSONSchema,
+  z,
+} from 'genkit';
 import {
   CandidateData,
   GenerateRequest,
@@ -424,7 +431,8 @@ function toGeminiRole(
   }
 }
 
-const toGeminiTool = (
+/** @hidden */
+export const toGeminiTool = (
   tool: z.infer
 ): FunctionDeclaration => {
   const declaration: FunctionDeclaration = {
@@ -645,31 +653,65 @@ export function fromGeminiCandidate(
 // Translate JSON schema to Vertex AI's format. Specifically, the type field needs be mapped.
 // Since JSON schemas can include nested arrays/objects, we have to recursively map the type field
 // in all nested fields.
-const convertSchemaProperty = (property) => {
+function convertSchemaProperty(property) {
   if (!property || !property.type) {
-    return null;
+    return undefined;
+  }
+  const baseSchema = {} as Schema;
+  if (property.description) {
+    baseSchema.description = property.description;
   }
-  if (property.type === 'object') {
+  if (property.enum) {
+    baseSchema.enum = property.enum;
+  }
+  if (property.nullable) {
+    baseSchema.nullable = property.nullable;
+  }
+  let propertyType;
+  // nullable schema can ALSO be defined as, for example, type=['string','null']
+  if (Array.isArray(property.type)) {
+    const types = property.type as string[];
+    if (types.includes('null')) {
+      baseSchema.nullable = true;
+    }
+    // grab the type that's not `null`
+    propertyType = types.find((t) => t !== 'null');
+  } else {
+    propertyType = property.type;
+  }
+  if (propertyType === 'object') {
     const nestedProperties = {};
     Object.keys(property.properties).forEach((key) => {
       nestedProperties[key] = convertSchemaProperty(property.properties[key]);
     });
     return {
+      ...baseSchema,
       type: FunctionDeclarationSchemaType.OBJECT,
       properties: nestedProperties,
       required: property.required,
     };
-  } else if (property.type === 'array') {
+  } else if (propertyType === 'array') {
     return {
+      ...baseSchema,
       type: FunctionDeclarationSchemaType.ARRAY,
       items: convertSchemaProperty(property.items),
     };
   } else {
+    const schemaType = FunctionDeclarationSchemaType[
+      propertyType?.toUpperCase()
+    ] as FunctionDeclarationSchemaType;
+    if (!schemaType) {
+      throw new GenkitError({
+        status: 'INVALID_ARGUMENT',
+        message: `Unsupported property type ${propertyType?.toUpperCase()}`,
+      });
+    }
     return {
-      type: FunctionDeclarationSchemaType[property.type.toUpperCase()],
+      ...baseSchema,
+      type: schemaType,
     };
   }
-};
+}
 
 export function cleanSchema(schema: JSONSchema): JSONSchema {
   const out = structuredClone(schema);
diff --git a/js/plugins/vertexai/tests/gemini_test.ts b/js/plugins/vertexai/tests/gemini_test.ts
index aeb388d37..4df3a7131 100644
--- a/js/plugins/vertexai/tests/gemini_test.ts
+++ b/js/plugins/vertexai/tests/gemini_test.ts
@@ -16,13 +16,15 @@
 
 import { GenerateContentCandidate } from '@google-cloud/vertexai';
 import * as assert from 'assert';
-import { MessageData } from 'genkit';
+import { MessageData, z } from 'genkit';
+import { toJsonSchema } from 'genkit/schema';
 import { describe, it } from 'node:test';
 import {
   cleanSchema,
   fromGeminiCandidate,
   toGeminiMessage,
   toGeminiSystemInstruction,
+  toGeminiTool,
 } from '../src/gemini.js';
 
 describe('toGeminiMessages', () => {
@@ -381,3 +383,61 @@ describe('cleanSchema', () => {
     });
   });
 });
+
+describe('toGeminiTool', () => {
+  it('preserves descriptions, enums and nullable fields', async () => {
+    const got = toGeminiTool({
+      name: 'foo',
+      description: 'tool foo',
+      inputSchema: toJsonSchema({
+        schema: z.object({
+          simpleString: z.string().describe('a string').nullable(),
+          simpleNumber: z.number().describe('a number'),
+          simpleBoolean: z.boolean().describe('a boolean').optional(),
+          simpleArray: z.array(z.string()).describe('an array').optional(),
+          simpleEnum: z
+            .enum(['choice_a', 'choice_b'])
+            .describe('an enum')
+            .optional(),
+        }),
+      }),
+    });
+
+    const want = {
+      description: 'tool foo',
+      name: 'foo',
+      parameters: {
+        properties: {
+          simpleArray: {
+            description: 'an array',
+            items: {
+              type: 'STRING',
+            },
+            type: 'ARRAY',
+          },
+          simpleBoolean: {
+            description: 'a boolean',
+            type: 'BOOLEAN',
+          },
+          simpleEnum: {
+            description: 'an enum',
+            enum: ['choice_a', 'choice_b'],
+            type: 'STRING',
+          },
+          simpleNumber: {
+            description: 'a number',
+            type: 'NUMBER',
+          },
+          simpleString: {
+            description: 'a string',
+            nullable: true,
+            type: 'STRING',
+          },
+        },
+        required: ['simpleString', 'simpleNumber'],
+        type: 'OBJECT',
+      },
+    };
+    assert.deepStrictEqual(got, want);
+  });
+});
diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts
index fea7b141f..38bc0f153 100644
--- a/js/testapps/flow-simple-ai/src/index.ts
+++ b/js/testapps/flow-simple-ai/src/index.ts
@@ -21,7 +21,11 @@ import {
   googleAI,
   gemini10Pro as googleGemini10Pro,
 } from '@genkit-ai/googleai';
-import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
+import {
+  textEmbedding004,
+  vertexAI,
+  gemini15Flash as vertexGemini15Flash,
+} from '@genkit-ai/vertexai';
 import { GoogleAIFileManager } from '@google/generative-ai/server';
 import { AlwaysOnSampler } from '@opentelemetry/sdk-trace-base';
 import { initializeApp } from 'firebase-admin/app';
@@ -386,7 +390,11 @@ const gablorkenTool = ai.defineTool(
   {
     name: 'gablorkenTool',
     inputSchema: z.object({
-      value: z.number(),
+      value: z
+        .number()
+        .describe(
+          'always add 1 to the value (it is 1 based, but upstream it is zero based)'
+        ),
     }),
     description: 'can be used to calculate gablorken value',
   },
@@ -395,6 +403,23 @@ const gablorkenTool = ai.defineTool(
   }
 );
 
+const characterGenerator = ai.defineTool(
+  {
+    name: 'characterGenerator',
+    inputSchema: z.object({
+      age: z.number().describe('must be between 23 and 27'),
+      type: z.enum(['archer', 'banana']),
+      name: z.string().describe('can only be Bob or John'),
+      surname: z.string(),
+    }),
+    description:
+      'can be used to generate a character. Seed it with some input.',
+  },
+  async (input) => {
+    return input;
+  }
+);
+
 export const toolCaller = ai.defineFlow(
   {
     name: 'toolCaller',
@@ -441,7 +466,7 @@ export const forcedToolCaller = ai.defineFlow(
   },
   async (input, { sendChunk }) => {
     const { response, stream } = ai.generateStream({
-      model: gemini15Flash,
+      model: vertexGemini15Flash,
       config: {
         temperature: 1,
       },
@@ -458,6 +483,31 @@ export const forcedToolCaller = ai.defineFlow(
   }
 );
 
+export const toolCallerCharacterGenerator = ai.defineFlow(
+  {
+    name: 'toolCallerCharacterGenerator',
+    inputSchema: z.number(),
+    streamSchema: z.any(),
+  },
+  async (input, { sendChunk }) => {
+    const { response, stream } = ai.generateStream({
+      model: vertexGemini15Flash,
+      config: {
+        temperature: 1,
+      },
+      tools: [characterGenerator, exitTool],
+      toolChoice: 'required',
+      prompt: `generate an archer character`,
+    });
+
+    for await (const chunk of stream) {
+      sendChunk(chunk);
+    }
+
+    return await response;
+  }
+);
+
 export const invalidOutput = ai.defineFlow(
   {
     name: 'invalidOutput',
@@ -657,9 +707,10 @@ ai.defineFlow('blockingMiddleware', async () => {
 
 ai.defineFlow('formatJson', async (input, { sendChunk }) => {
   const { output, text } = await ai.generate({
+    model: gemini15Flash,
     prompt: `generate an RPG game character of type ${input || 'archer'}`,
     output: {
-      constrained: false,
+      constrained: true,
       instructions: true,
       schema: z
         .object({
@@ -675,6 +726,7 @@ ai.defineFlow('formatJson', async (input, { sendChunk }) => {
 
 ai.defineFlow('formatJsonManualSchema', async (input, { sendChunk }) => {
   const { output, text } = await ai.generate({
+    model: gemini15Flash,
     prompt: `generate one RPG game character of type ${input || 'archer'} and generated JSON must match this interface
 
     \`\`\`typescript