From 0f6475d14e82324ac7bc0a35e632894a7d1578f4 Mon Sep 17 00:00:00 2001 From: "Ankush Pala ankush@lastmileai.dev" <> Date: Mon, 26 Feb 2024 11:22:15 -0500 Subject: [PATCH] 4/n Gemini Prompt Schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prompt Schema for Gemini Pro. This prompt schema doesn't define a schema for Gemini's Safety Settings As those are more complicated, they will come on top. API Docs for Generation Config: https://ai.google.dev/api/python/google/ai/generativelanguage/GenerationConfig Skipped these attributes as part of the prompt schema. - skipped candidate count - didn't add max_output_tokens. I was getting weird index out of bounds exceptions when setting max_output_tokens which occured at the output parsing step. I tried with 35, 400, 200. Thought it would be better UX to just leave it out for now before figuring out why. - safety settings See: Code Ref: https://github.com/google/generative-ai-python/blob/main/google/generativeai/types/safety_types.py#L218-L221 API Doc Ref: https://ai.google.dev/docs/safety_setting_gemini#code-examples ## Testplan Screenshot 2024-02-26 at 11 20 04 AM --- my_aiconfig.aiconfig.json | 16 ++++-- .../prompt_schemas/GeminiPromptSchema.ts | 55 +++++++++++++++++++ .../editor/client/src/utils/promptUtils.ts | 5 ++ 3 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 python/src/aiconfig/editor/client/src/shared/prompt_schemas/GeminiPromptSchema.ts diff --git a/my_aiconfig.aiconfig.json b/my_aiconfig.aiconfig.json index d79f1800a..2819457e3 100644 --- a/my_aiconfig.aiconfig.json +++ b/my_aiconfig.aiconfig.json @@ -9,12 +9,20 @@ "prompts": [ { "name": "prompt_1", - "input": "", + "input": "Tell me a joke\n", "metadata": { - "model": "gpt-4", + "model": { + "name": "gemini-pro", + "settings": { + "generation_config": { + "top_p": 0.1, + "top_k": 1, + "temperature": 0.4 + } + } + }, "parameters": {} - }, - "outputs": [] + } } ], "$schema": "https://json.schemastore.org/aiconfig-1.0" diff --git a/python/src/aiconfig/editor/client/src/shared/prompt_schemas/GeminiPromptSchema.ts b/python/src/aiconfig/editor/client/src/shared/prompt_schemas/GeminiPromptSchema.ts new file mode 100644 index 000000000..dc04f07b4 --- /dev/null +++ b/python/src/aiconfig/editor/client/src/shared/prompt_schemas/GeminiPromptSchema.ts @@ -0,0 +1,55 @@ +import { PromptSchema } from "../../utils/promptUtils"; +// This does not support Gemini Vision. Model parser does not support it. +// TODO: Safety Settings, Candidate Count, max_output_tokens +export const GeminiParserPromptSchema: PromptSchema = { + // https://ai.google.dev/api/python/google/ai/generativelanguage/GenerationConfig + input: { + type: "string", + }, + model_settings: { + type: "object", + properties: { + generation_config: { + type: "object", + properties: { + candidate_count: {}, + temperature: { + type: "number", + description: "Controls the randomness of the output.", + minimum: 0.0, + maximum: 1.0, + }, + top_p: { + type: "number", + description: + "The maximum cumulative probability of tokens to consider when sampling.", + }, + top_k: { + type: "integer", + description: + "The maximum number of tokens to consider when sampling.", + }, + stop_sequences: { + type: "array", + description: + "The set of character sequences (up to 5) that will stop output generation", + items: { + type: "string", + }, + }, + }, + }, + }, + }, + prompt_metadata: { + type: "object", + properties: { + remember_chat_context: { + type: "boolean", + }, + stream: { + type: "boolean", + }, + }, + }, +}; diff --git a/python/src/aiconfig/editor/client/src/utils/promptUtils.ts b/python/src/aiconfig/editor/client/src/utils/promptUtils.ts index 1b6c09484..3c0fdd520 100644 --- a/python/src/aiconfig/editor/client/src/utils/promptUtils.ts +++ b/python/src/aiconfig/editor/client/src/utils/promptUtils.ts @@ -8,6 +8,7 @@ import { import { PaLMTextParserPromptSchema } from "../shared/prompt_schemas/PaLMTextParserPromptSchema"; import { PaLMChatParserPromptSchema } from "../shared/prompt_schemas/PaLMChatParserPromptSchema"; import { AnyscaleEndpointPromptSchema } from "../shared/prompt_schemas/AnyscaleEndpointPromptSchema"; +import { GeminiParserPromptSchema } from "../shared/prompt_schemas/GeminiPromptSchema"; import { HuggingFaceAutomaticSpeechRecognitionPromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionPromptSchema"; import { HuggingFaceAutomaticSpeechRecognitionRemoteInferencePromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionRemoteInferencePromptSchema"; import { HuggingFaceImage2TextTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceImage2TextTransformerPromptSchema"; @@ -24,6 +25,7 @@ import { HuggingFaceImage2TextRemoteInferencePromptSchema } from "../shared/prom import { ClaudeBedrockPromptSchema } from "../shared/prompt_schemas/ClaudeBedrockPromptSchema"; import { HuggingFaceConversationalRemoteInferencePromptSchema } from "../shared/prompt_schemas/HuggingFaceConversationalRemoteInferencePromptSchema"; + /** * Get the name of the model for the specified prompt. The name will either be specified in the prompt's * model metadata, or as the default_model in the aiconfig metadata @@ -121,6 +123,9 @@ export const PROMPT_SCHEMAS: Record = { // PaLMChatParser "models/chat-bison-001": PaLMChatParserPromptSchema, + // Gemini + "gemini-pro": GeminiParserPromptSchema, + // AnyscaleEndpoint AnyscaleEndpoint: AnyscaleEndpointPromptSchema,