Skip to content

Commit

Permalink
4/n Gemini Prompt Schema
Browse files Browse the repository at this point in the history
Prompt Schema for Gemini Pro.

This prompt schema doesn't define a schema for Gemini's Safety Settings. As those are more complicated, they will come on top.

API Docs for Generation Config: https://ai.google.dev/api/python/google/ai/generativelanguage/GenerationConfig

Skipped these attributes as part of the prompt schema:
- skipped candidate count
- didn't add max_output_tokens. I was getting weird index-out-of-bounds exceptions when setting max_output_tokens, which occurred at the output-parsing step. I tried with 35, 400, and 200. Thought it would be better UX to just leave it out for now, before figuring out why.
- safety settings

See:
Code Ref: https://github.com/google/generative-ai-python/blob/main/google/generativeai/types/safety_types.py#L218-L221
API Doc Ref: https://ai.google.dev/docs/safety_setting_gemini#code-examples

## Testplan

<img width="1310" alt="Screenshot 2024-02-26 at 11 20 04 AM" src="https://github.com/lastmile-ai/aiconfig/assets/141073967/9572a4c0-2005-4e6c-9e93-c1eb0c3859dd">
  • Loading branch information
Ankush Pala [email protected] committed Feb 26, 2024
1 parent 20a144c commit 0f6475d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
16 changes: 12 additions & 4 deletions my_aiconfig.aiconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@
"prompts": [
{
"name": "prompt_1",
"input": "",
"input": "Tell me a joke\n",
"metadata": {
"model": "gpt-4",
"model": {
"name": "gemini-pro",
"settings": {
"generation_config": {
"top_p": 0.1,
"top_k": 1,
"temperature": 0.4
}
}
},
"parameters": {}
},
"outputs": []
}
}
],
"$schema": "https://json.schemastore.org/aiconfig-1.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { PromptSchema } from "../../utils/promptUtils";
// Prompt schema for the Gemini Pro model parser.
// Generation config reference:
// https://ai.google.dev/api/python/google/ai/generativelanguage/GenerationConfig
//
// This does not support Gemini Vision; the model parser does not support it.
//
// Intentionally not covered here yet (TODO):
// - safety_settings (more complicated; handled separately on top)
// - candidate_count
// - max_output_tokens (setting it caused index-out-of-bounds errors at the
//   output-parsing step; left out until that is understood)
export const GeminiParserPromptSchema: PromptSchema = {
  input: {
    type: "string",
  },
  model_settings: {
    type: "object",
    properties: {
      generation_config: {
        type: "object",
        properties: {
          temperature: {
            type: "number",
            description: "Controls the randomness of the output.",
            minimum: 0.0,
            maximum: 1.0,
          },
          top_p: {
            type: "number",
            description:
              "The maximum cumulative probability of tokens to consider when sampling.",
            // A cumulative probability, so the valid range is [0, 1] —
            // mirrors the bounds given to `temperature` above.
            minimum: 0.0,
            maximum: 1.0,
          },
          top_k: {
            type: "integer",
            description:
              "The maximum number of tokens to consider when sampling.",
          },
          stop_sequences: {
            type: "array",
            description:
              "The set of character sequences (up to 5) that will stop output generation",
            items: {
              type: "string",
            },
          },
        },
      },
    },
  },
  prompt_metadata: {
    type: "object",
    properties: {
      // NOTE(review): presumably controls whether prior conversation turns
      // are sent with the request — confirm against the Gemini model parser.
      remember_chat_context: {
        type: "boolean",
      },
      // Whether to stream the model response.
      stream: {
        type: "boolean",
      },
    },
  },
};
5 changes: 5 additions & 0 deletions python/src/aiconfig/editor/client/src/utils/promptUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
import { PaLMTextParserPromptSchema } from "../shared/prompt_schemas/PaLMTextParserPromptSchema";
import { PaLMChatParserPromptSchema } from "../shared/prompt_schemas/PaLMChatParserPromptSchema";
import { AnyscaleEndpointPromptSchema } from "../shared/prompt_schemas/AnyscaleEndpointPromptSchema";
import { GeminiParserPromptSchema } from "../shared/prompt_schemas/GeminiPromptSchema";
import { HuggingFaceAutomaticSpeechRecognitionPromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionPromptSchema";
import { HuggingFaceAutomaticSpeechRecognitionRemoteInferencePromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionRemoteInferencePromptSchema";
import { HuggingFaceImage2TextTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceImage2TextTransformerPromptSchema";
Expand All @@ -24,6 +25,7 @@ import { HuggingFaceImage2TextRemoteInferencePromptSchema } from "../shared/prom
import { ClaudeBedrockPromptSchema } from "../shared/prompt_schemas/ClaudeBedrockPromptSchema";
import { HuggingFaceConversationalRemoteInferencePromptSchema } from "../shared/prompt_schemas/HuggingFaceConversationalRemoteInferencePromptSchema";


/**
* Get the name of the model for the specified prompt. The name will either be specified in the prompt's
* model metadata, or as the default_model in the aiconfig metadata
Expand Down Expand Up @@ -121,6 +123,9 @@ export const PROMPT_SCHEMAS: Record<string, PromptSchema> = {
// PaLMChatParser
"models/chat-bison-001": PaLMChatParserPromptSchema,

// Gemini
"gemini-pro": GeminiParserPromptSchema,

// AnyscaleEndpoint
AnyscaleEndpoint: AnyscaleEndpointPromptSchema,

Expand Down

0 comments on commit 0f6475d

Please sign in to comment.