Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(js/plugins): fixed gemini tool schema converter to correctly handle descriptions, enums and nullable fields #2027

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
GoogleGenerativeAI,
InlineDataPart,
RequestOptions,
Schema,
SchemaType,
StartChatParams,
Tool,
Expand Down Expand Up @@ -330,31 +331,64 @@ function toGeminiRole(

function convertSchemaProperty(property) {
if (!property || !property.type) {
return null;
return undefined;
}
const baseSchema = {} as Schema;
if (property.description) {
baseSchema.description = property.description;
}
if (property.enum) {
baseSchema.enum = property.enum;
}
if (property.nullable) {
baseSchema.nullable = property.nullable;
}
if (property.type === 'object') {
let propertyType;
// nullable schema can ALSO be defined as, for example, type=['string','null']
if (Array.isArray(property.type)) {
const types = property.type as string[];
if (types.includes('null')) {
baseSchema.nullable = true;
}
// grab the type that's not `null`
propertyType = types.find((t) => t !== 'null');
} else {
propertyType = property.type;
}
if (propertyType === 'object') {
const nestedProperties = {};
Object.keys(property.properties).forEach((key) => {
nestedProperties[key] = convertSchemaProperty(property.properties[key]);
});
return {
...baseSchema,
type: SchemaType.OBJECT,
properties: nestedProperties,
required: property.required,
};
} else if (property.type === 'array') {
} else if (propertyType === 'array') {
return {
...baseSchema,
type: SchemaType.ARRAY,
items: convertSchemaProperty(property.items),
};
} else {
const schemaType = SchemaType[propertyType.toUpperCase()] as SchemaType;
if (!schemaType) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Unsupported property type ${propertyType.toUpperCase()}`,
});
}
return {
type: SchemaType[property.type.toUpperCase()],
...baseSchema,
type: schemaType,
};
}
}

function toGeminiTool(
/** @hidden */
export function toGeminiTool(
tool: z.infer<typeof ToolDefinitionSchema>
): FunctionDeclaration {
const declaration: FunctionDeclaration = {
Expand Down
62 changes: 61 additions & 1 deletion js/plugins/googleai/tests/gemini_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

import { GenerateContentCandidate } from '@google/generative-ai';
import * as assert from 'assert';
import { genkit } from 'genkit';
import { genkit, z } from 'genkit';
import { MessageData, ModelInfo } from 'genkit/model';
import { toJsonSchema } from 'genkit/schema';
import { afterEach, beforeEach, describe, it } from 'node:test';
import {
GENERIC_GEMINI_MODEL,
Expand All @@ -28,6 +29,7 @@ import {
gemini15Pro,
toGeminiMessage,
toGeminiSystemInstruction,
toGeminiTool,
} from '../src/gemini.js';
import { googleAI } from '../src/index.js';

Expand Down Expand Up @@ -501,6 +503,64 @@ describe('plugin', () => {
});
});

describe('toGeminiTool', () => {
it('', async () => {
const got = toGeminiTool({
name: 'foo',
description: 'tool foo',
inputSchema: toJsonSchema({
schema: z.object({
simpleString: z.string().describe('a string').nullable(),
simpleNumber: z.number().describe('a number'),
simpleBoolean: z.boolean().describe('a boolean').optional(),
simpleArray: z.array(z.string()).describe('an array').optional(),
simpleEnum: z
.enum(['choice_a', 'choice_b'])
.describe('an enum')
.optional(),
}),
}),
});

const want = {
description: 'tool foo',
name: 'foo',
parameters: {
properties: {
simpleArray: {
description: 'an array',
items: {
type: 'string',
},
type: 'array',
},
simpleBoolean: {
description: 'a boolean',
type: 'boolean',
},
simpleEnum: {
description: 'an enum',
enum: ['choice_a', 'choice_b'],
type: 'string',
},
simpleNumber: {
description: 'a number',
type: 'number',
},
simpleString: {
description: 'a string',
nullable: true,
type: 'string',
},
},
required: ['simpleString', 'simpleNumber'],
type: 'object',
},
};
assert.deepStrictEqual(got, want);
});
});

function assertEqualModelInfo(
modelAction: ModelInfo,
expectedLabel: string,
Expand Down
58 changes: 50 additions & 8 deletions js/plugins/vertexai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,20 @@ import {
GenerativeModelPreview,
HarmBlockThreshold,
HarmCategory,
Schema,
StartChatParams,
ToolConfig,
VertexAI,
type GoogleSearchRetrieval,
} from '@google-cloud/vertexai';
import { ApiClient } from '@google-cloud/vertexai/build/src/resources/index.js';
import { GENKIT_CLIENT_HEADER, Genkit, JSONSchema, z } from 'genkit';
import {
GENKIT_CLIENT_HEADER,
Genkit,
GenkitError,
JSONSchema,
z,
} from 'genkit';
import {
CandidateData,
GenerateRequest,
Expand Down Expand Up @@ -424,7 +431,8 @@ function toGeminiRole(
}
}

const toGeminiTool = (
/** @hidden */
export const toGeminiTool = (
tool: z.infer<typeof ToolDefinitionSchema>
): FunctionDeclaration => {
const declaration: FunctionDeclaration = {
Expand Down Expand Up @@ -645,31 +653,65 @@ export function fromGeminiCandidate(
// Translate JSON schema to Vertex AI's format. Specifically, the type field needs be mapped.
// Since JSON schemas can include nested arrays/objects, we have to recursively map the type field
// in all nested fields.
const convertSchemaProperty = (property) => {
function convertSchemaProperty(property) {
if (!property || !property.type) {
return null;
return undefined;
}
const baseSchema = {} as Schema;
if (property.description) {
baseSchema.description = property.description;
}
if (property.type === 'object') {
if (property.enum) {
baseSchema.enum = property.enum;
}
if (property.nullable) {
baseSchema.nullable = property.nullable;
}
let propertyType;
// nullable schema can ALSO be defined as, for example, type=['string','null']
if (Array.isArray(property.type)) {
const types = property.type as string[];
if (types.includes('null')) {
baseSchema.nullable = true;
}
// grab the type that's not `null`
propertyType = types.find((t) => t !== 'null');
} else {
propertyType = property.type;
}
if (propertyType === 'object') {
const nestedProperties = {};
Object.keys(property.properties).forEach((key) => {
nestedProperties[key] = convertSchemaProperty(property.properties[key]);
});
return {
...baseSchema,
type: FunctionDeclarationSchemaType.OBJECT,
properties: nestedProperties,
required: property.required,
};
} else if (property.type === 'array') {
} else if (propertyType === 'array') {
return {
...baseSchema,
type: FunctionDeclarationSchemaType.ARRAY,
items: convertSchemaProperty(property.items),
};
} else {
const schemaType = FunctionDeclarationSchemaType[
propertyType.toUpperCase()
] as FunctionDeclarationSchemaType;
if (!schemaType) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: `Unsupported property type ${propertyType.toUpperCase()}`,
});
}
return {
type: FunctionDeclarationSchemaType[property.type.toUpperCase()],
...baseSchema,
type: schemaType,
};
}
};
}

export function cleanSchema(schema: JSONSchema): JSONSchema {
const out = structuredClone(schema);
Expand Down
62 changes: 61 additions & 1 deletion js/plugins/vertexai/tests/gemini_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

import { GenerateContentCandidate } from '@google-cloud/vertexai';
import * as assert from 'assert';
import { MessageData } from 'genkit';
import { MessageData, z } from 'genkit';
import { toJsonSchema } from 'genkit/schema';
import { describe, it } from 'node:test';
import {
cleanSchema,
fromGeminiCandidate,
toGeminiMessage,
toGeminiSystemInstruction,
toGeminiTool,
} from '../src/gemini.js';

describe('toGeminiMessages', () => {
Expand Down Expand Up @@ -381,3 +383,61 @@ describe('cleanSchema', () => {
});
});
});

describe('toGeminiTool', () => {
it('', async () => {
const got = toGeminiTool({
name: 'foo',
description: 'tool foo',
inputSchema: toJsonSchema({
schema: z.object({
simpleString: z.string().describe('a string').nullable(),
simpleNumber: z.number().describe('a number'),
simpleBoolean: z.boolean().describe('a boolean').optional(),
simpleArray: z.array(z.string()).describe('an array').optional(),
simpleEnum: z
.enum(['choice_a', 'choice_b'])
.describe('an enum')
.optional(),
}),
}),
});

const want = {
description: 'tool foo',
name: 'foo',
parameters: {
properties: {
simpleArray: {
description: 'an array',
items: {
type: 'STRING',
},
type: 'ARRAY',
},
simpleBoolean: {
description: 'a boolean',
type: 'BOOLEAN',
},
simpleEnum: {
description: 'an enum',
enum: ['choice_a', 'choice_b'],
type: 'STRING',
},
simpleNumber: {
description: 'a number',
type: 'NUMBER',
},
simpleString: {
description: 'a string',
nullable: true,
type: 'STRING',
},
},
required: ['simpleString', 'simpleNumber'],
type: 'OBJECT',
},
};
assert.deepStrictEqual(got, want);
});
});
Loading
Loading