Skip to content

Commit

Permalink
Multimodal fix - multiple changes to chat agent module to support mul…
Browse files Browse the repository at this point in the history
…timodal I/O (#60)

* Fixed issues with basic system prompt + added default system prompts #43

* multiple changes to chat agent module to support multimodal I/O

* bumped package version for alpha release
  • Loading branch information
pranav-kural authored Jul 30, 2024
1 parent 5f5fa03 commit f4f024d
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 93 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@oconva/qvikchat",
"version": "2.0.0-alpha.2",
"version": "2.0.0-alpha.3",
"repository": {
"type": "git",
"url": "https://github.com/oconva/qvikchat.git"
Expand Down
82 changes: 42 additions & 40 deletions src/agents/chat-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {MessageData} from '@genkit-ai/ai/model';
import {ToolArgument} from '@genkit-ai/ai/tool';
import {Dotprompt} from '@genkit-ai/dotprompt';
import {PromptOutputSchema} from '../prompts/prompts';
import {dallE3} from 'genkitx-openai';

/**
* Represents the type of chat agent.
Expand Down Expand Up @@ -45,15 +46,13 @@ export type AgentTypeConfig =
* @property systemPrompt - The system prompt for the chat agent.
* @property chatPrompt - The chat prompt for the chat agent.
* @property tools - Tools for the chat agent.
* @property model - The supported model to use for chat completion.
* @property modelConfig - The model configuration.
* @property responseOutputSchema - The output schema for the response.
*/
export type ChatAgentConfig = {
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
tools?: ToolArgument[];
model?: SupportedModels;
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;
} & AgentTypeConfig;
Expand Down Expand Up @@ -108,7 +107,6 @@ export type GenerateResponseHistoryProps =
* @property enableChatHistory - Indicates whether to use chat history.
* @property chatHistoryStore - The chat history store.
* @property tools - The tool arguments.
* @property model - The supported model.
* @property modelConfig - The model configuration.
* @property systemPrompt - The system prompt.
* @property chatPrompt - The chat prompt.
Expand All @@ -118,7 +116,6 @@ export type GenerateResponseProps = {
context?: string;
chatId?: string;
tools?: ToolArgument[];
model?: SupportedModels;
modelConfig?: ModelConfig;
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
Expand Down Expand Up @@ -149,11 +146,6 @@ export interface ChatAgentMethods {
generateResponse: (
props: GenerateResponseProps
) => Promise<GenerateResponseReturnObj>;

/**
* Method to get model name that the chat agent is using.
*/
getModelName(): string;
}

/**
Expand All @@ -166,7 +158,6 @@ export interface ChatAgentInterface
export type GenerateSystemPromptResponseParams = {
agentType?: ChatAgentType;
prompt: Dotprompt;
model?: string;
modelConfig?: ModelConfig;
query?: string;
context?: string;
Expand All @@ -183,7 +174,6 @@ export class ChatAgent implements ChatAgentInterface {
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
tools?: ToolArgument[];
private modelName: string;
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;

Expand All @@ -196,8 +186,7 @@ export class ChatAgent implements ChatAgentInterface {
* @param enableChatHistory - Indicates whether to use chat history.
* @param chatHistoryStore - The chat history store.
* @param tools - Tools for the chat agent.
* @param model - The supported model. If not provided, will use the default model (e.g. Gemini 1.5 Flash).
* @param modelConfig - The model configuration.
* @param modelConfig - The model configuration. If not provided, will use the default model (e.g. Gemini 1.5 Flash).
*/
constructor(config: ChatAgentConfig = {}) {
this.agentType = config.agentType ?? defaultChatAgentConfig.agentType;
Expand All @@ -207,9 +196,6 @@ export class ChatAgent implements ChatAgentInterface {
this.systemPrompt = config.systemPrompt;
this.chatPrompt = config.chatPrompt;
this.tools = config.tools;
this.modelName = config.model
? SupportedModelNames[config.model]
: SupportedModelNames[defaultChatAgentConfig.model];
this.modelConfig = config.modelConfig;
this.responseOutputSchema = config.responseOutputSchema;
}
Expand Down Expand Up @@ -323,7 +309,6 @@ export class ChatAgent implements ChatAgentInterface {
private static generateSystemPromptResponse({
agentType,
prompt,
model,
modelConfig,
query,
context,
Expand All @@ -333,8 +318,9 @@ export class ChatAgent implements ChatAgentInterface {
// generate the response
const res = prompt.generate({
// if undefined, will use model defined in the dotprompt
model: model,
config: modelConfig,
model:
SupportedModelNames[modelConfig?.name ?? defaultChatAgentConfig.model],
config: {...modelConfig},
input: ChatAgent.getFormattedInput({agentType, query, context, topic}),
tools: tools,
});
Expand All @@ -350,20 +336,17 @@ export class ChatAgent implements ChatAgentInterface {
static getPromptOutputSchema(
responseOutputSchema?: OutputSchemaType
): PromptOutputSchema {
if (!responseOutputSchema || responseOutputSchema.responseType === 'text') {
if (!responseOutputSchema || responseOutputSchema.format === 'text') {
return {format: 'text'};
} else if (responseOutputSchema.responseType === 'json') {
} else if (responseOutputSchema.format === 'json') {
return {
format: 'json',
schema: responseOutputSchema.schema,
jsonSchema: responseOutputSchema.jsonSchema,
};
} else if (responseOutputSchema.responseType === 'media') {
} else if (responseOutputSchema.format === 'media') {
return {format: 'media'};
} else {
throw new Error(
`Invalid response type ${responseOutputSchema.responseType}`
);
throw new Error(`Invalid response type ${responseOutputSchema.format}`);
}
}

Expand All @@ -382,6 +365,29 @@ export class ChatAgent implements ChatAgentInterface {
async generateResponse(
params: GenerateResponseProps
): Promise<GenerateResponseReturnObj> {
// if the model being used is Dall-E3 (e.g., for image generation)
// simply return the response
if (
params.modelConfig?.name === 'dallE3' || // if model provided in params is Dall-E3
(!params.modelConfig?.name && this.modelConfig?.name === 'dallE3') // if model not provided in params and default model is Dall-E3
) {
// configurations for Dall-E3 model
const dallEConfig = params.modelConfig ?? this.modelConfig;
// return response
return {
res: await generate({
model: dallE3,
config: {
...dallEConfig,
},
prompt: params.query,
tools: params.tools,
output: {
format: 'media',
},
}),
};
}
// System prompt to use
// In order of priority: systemPrompt provided as argument to generateResponse, this.systemPrompt, default system prompt
const prompt =
Expand All @@ -400,9 +406,6 @@ export class ChatAgent implements ChatAgentInterface {
res: await ChatAgent.generateSystemPromptResponse({
agentType: this.agentType,
prompt,
model: params.model
? SupportedModelNames[params.model]
: this.modelName,
modelConfig: params.modelConfig ?? this.modelConfig,
query: params.query,
context: params.context,
Expand All @@ -420,9 +423,6 @@ export class ChatAgent implements ChatAgentInterface {
const res = await ChatAgent.generateSystemPromptResponse({
agentType: this.agentType,
prompt,
model: params.model
? SupportedModelNames[params.model]
: this.modelName,
modelConfig: params.modelConfig ?? this.modelConfig,
query: params.query,
context: params.context,
Expand All @@ -448,7 +448,16 @@ export class ChatAgent implements ChatAgentInterface {
throw new Error(`No data found for chat ID ${params.chatId}.`);
// generate response for given query (will use chat prompt and any provided chat history, context and tools)
const res = await generate({
model: params.model ?? this.modelName,
model:
SupportedModelNames[
params.modelConfig?.name ??
this.modelConfig?.name ??
defaultChatAgentConfig.model
],
config: {
...this.modelConfig,
...params.modelConfig,
},
prompt: params.query,
history: chatHistory,
context: params.context
Expand All @@ -467,11 +476,4 @@ export class ChatAgent implements ChatAgentInterface {
res,
};
}

/**
* Method to get model name that the chat agent is using.
*/
getModelName() {
return this.modelName;
}
}
Loading

0 comments on commit f4f024d

Please sign in to comment.