From 811f0cf2111347f8cdb737f2a526a16b04a65f43 Mon Sep 17 00:00:00 2001 From: Mini256 Date: Sat, 18 Nov 2023 16:31:06 +0800 Subject: [PATCH] data-explorer: force to use only a single SQL (#1683) --- packages/api-server/src/app.ts | 4 +- packages/api-server/src/env.ts | 16 ++--- .../src/plugins/services/bot-service/index.ts | 61 +++++++++++++++++-- .../bot-service/prompt/prompt-manager.ts | 52 ++++++++-------- 4 files changed, 93 insertions(+), 40 deletions(-) diff --git a/packages/api-server/src/app.ts b/packages/api-server/src/app.ts index 168ea056192..5485545f815 100644 --- a/packages/api-server/src/app.ts +++ b/packages/api-server/src/app.ts @@ -28,17 +28,17 @@ export interface AppConfig { PLAYGROUND_SHADOW_DATABASE_URL: string; PLAYGROUND_DAILY_QUESTIONS_LIMIT: number; PLAYGROUND_TRUSTED_GITHUB_LOGINS: string[]; + EXPLORER_GENERATE_ANSWER_PROMPT_NAME: string; EXPLORER_USER_MAX_QUESTIONS_PER_HOUR: number; EXPLORER_USER_MAX_QUESTIONS_ON_GOING: number; EXPLORER_GENERATE_SQL_CACHE_TTL: number; EXPLORER_QUERY_SQL_CACHE_TTL: number; EXPLORER_OUTPUT_ANSWER_IN_STREAM: boolean; + EMBEDDING_SERVICE_ENDPOINT: string; GITHUB_ACCESS_TOKENS: string[]; OPENAI_API_KEY: string; AUTH0_DOMAIN: string; AUTH0_SECRET: string; - EMBEDDING_SERVICE_ENDPOINT: string; - PROMPT_TEMPLATE_NAME: string; TIDB_CLOUD_DATA_SERVICE_APP_ID: string; TIDB_CLOUD_DATA_SERVICE_PUBLIC_KEY: string; TIDB_CLOUD_DATA_SERVICE_PRIVATE_KEY: string; diff --git a/packages/api-server/src/env.ts b/packages/api-server/src/env.ts index 973b32c218c..ab44a9b87ec 100644 --- a/packages/api-server/src/env.ts +++ b/packages/api-server/src/env.ts @@ -1,6 +1,6 @@ import { resolve } from "path"; -export const DEFAULT_ANSWER_PROMPT_TEMPLATE = 'explorer-generate-answer'; +export const DEFAULT_EXPLORER_GENERATE_ANSWER_PROMPT_NAME = 'explorer-generate-answer'; export const APIServerEnvSchema = { type: 'object', @@ -50,6 +50,10 @@ export const APIServerEnvSchema = { separator: ',', default: '' }, + EXPLORER_GENERATE_ANSWER_PROMPT_NAME: { + type: 'string', + default: DEFAULT_EXPLORER_GENERATE_ANSWER_PROMPT_NAME, + }, EXPLORER_USER_MAX_QUESTIONS_PER_HOUR: { type: 'number', default: 15 @@ -70,6 +74,9 @@ export const APIServerEnvSchema = { type: 'boolean', default: false }, + EMBEDDING_SERVICE_ENDPOINT: { + type: 'string' + }, GITHUB_ACCESS_TOKENS: { type: 'string', separator: ',' @@ -83,13 +90,6 @@ export const APIServerEnvSchema = { AUTH0_SECRET: { type: 'string' }, - EMBEDDING_SERVICE_ENDPOINT: { - type: 'string' - }, - PROMPT_TEMPLATE_NAME: { - type: 'string', - default: DEFAULT_ANSWER_PROMPT_TEMPLATE, - }, TIDB_CLOUD_DATA_SERVICE_APP_ID: { type: 'string' }, diff --git a/packages/api-server/src/plugins/services/bot-service/index.ts b/packages/api-server/src/plugins/services/bot-service/index.ts index d9cf864d81b..af0cd32cb9a 100644 --- a/packages/api-server/src/plugins/services/bot-service/index.ts +++ b/packages/api-server/src/plugins/services/bot-service/index.ts @@ -1,4 +1,5 @@ import {FastifyBaseLogger} from "fastify"; +import {DEFAULT_EXPLORER_GENERATE_ANSWER_PROMPT_NAME} from "../../../env"; import {countAPIRequest, measureAPIRequest, openaiAPICounter, openaiAPITimer} from "../../../metrics"; import {ContextProvider} from "./prompt/context/context-provider"; import {Answer} from "./types"; @@ -25,7 +26,8 @@ export default fp(async (app) => { log, app.config.OPENAI_API_KEY, app.promptTemplateManager, - app.embeddingContextProvider + app.embeddingContextProvider, + app.config.EXPLORER_GENERATE_ANSWER_PROMPT_NAME )); }, { name: '@ossinsight/bot-service', @@ -44,6 +46,7 @@ export class BotService { private readonly apiKey: string, private readonly promptManager: PromptManager, private readonly contextProvider?: ContextProvider, + private readonly generateAnswerPromptName: string = DEFAULT_EXPLORER_GENERATE_ANSWER_PROMPT_NAME ) { const configuration = new Configuration({ apiKey: this.apiKey @@ -96,13 +99,11 @@ export class BotService { } private async loadGenerateAnswerPromptTemplate(question: string): Promise<[string, PromptConfig]> { - let promptName = 'explorer-generate-answer'; let context: Record = { question: question }; if (this.contextProvider) { - promptName = 'explorer-generate-answer-with-context'; context = await this.contextProvider.provide(context); } - return await this.promptManager.getPrompt(promptName, context); + return await this.promptManager.getPrompt(this.generateAnswerPromptName, context); } public async questionToAnswerInStream(question: string, callback: (answer: Answer, key: string, value: any) => void): Promise<[Answer | null, string | null]> { @@ -303,7 +304,12 @@ export class BotService { break; case 'sql': key = "querySQL"; - answer.querySQL = value; + const sqlArr = splitSqlStatements(value); + if (sqlArr.length > 1) { + this.log.warn({ sqlArr }, `Got multiple SQLs from OpenAI API: ${question}`); + } + // Notice: Avoid multiple SQL Error. + answer.querySQL = sqlArr[0]; break; case 'chart': value = value ? { @@ -333,3 +339,48 @@ export class BotService { } } + +function splitSqlStatements(sqlString: string): string[] { + let statements = []; + let currentStatement = ''; + let inSingleQuote = false; + let inDoubleQuote = false; + + for (let i = 0; i < sqlString.length; i++) { + const char = sqlString[i]; + const nextChar = i + 1 < sqlString.length ? sqlString[i + 1] : null; + + // Handle escape characters + if (char === '\\' && nextChar) { + currentStatement += char + nextChar; + i++; // Skip the next character + continue; + } + + // Toggle single quote state + if (char === "'" && !inDoubleQuote) { + inSingleQuote = !inSingleQuote; + } + + // Toggle double quote state + if (char === '"' && !inSingleQuote) { + inDoubleQuote = !inDoubleQuote; + } + + // If not within quotes and a semicolon is encountered, split the statement + if (char === ';' && !inSingleQuote && !inDoubleQuote) { + statements.push(currentStatement.trim()); + currentStatement = ''; + continue; + } + + currentStatement += char; + } + + // Add the last statement if it exists + if (currentStatement.trim()) { + statements.push(currentStatement.trim()); + } + + return statements; +} \ No newline at end of file diff --git a/packages/api-server/src/plugins/services/bot-service/prompt/prompt-manager.ts b/packages/api-server/src/plugins/services/bot-service/prompt/prompt-manager.ts index 74f9dfeffd8..31126efe539 100644 --- a/packages/api-server/src/plugins/services/bot-service/prompt/prompt-manager.ts +++ b/packages/api-server/src/plugins/services/bot-service/prompt/prompt-manager.ts @@ -24,7 +24,6 @@ export default fp(async (fastify) => { }); export interface PromptConfig { - name: string; model: string; stop: string[]; max_tokens: number; @@ -33,6 +32,15 @@ export interface PromptConfig { n: number; } +const defaultPromptConfig: PromptConfig = { + model: "gpt-3.5-turbo", + stop: [], + max_tokens: 200, + temperature: 0, + top_p: 1, + n: 1, +}; + export class PromptManager { private readonly logger: pino.Logger; private readonly promptConfigs: Map = new Map(); @@ -72,32 +80,26 @@ export class PromptManager { } private async loadConfigFromFile(name: string, promptConfigDir: string) { - let promptConfig: PromptConfig = { - name: name, - model: "gpt-3.5-turbo", - stop: [], - max_tokens: 200, - temperature: 0, - top_p: 1, - n: 1, - }; - const configPath = path.join(promptConfigDir, "config.json"); - this.logger.info(`Loading prompt <${name}> config from file ${configPath}.`); - if (!fs.existsSync(configPath)) { - throw new Error(`Prompt config file ${configPath} not found.`); - } - try { - promptConfig = { - ...JSON.parse(fs.readFileSync(configPath, "utf-8")) - }; - } catch (err: any) { - throw new Error(`Failed to parse prompt config file ${configPath}: ${err.message}`, { - cause: err - }); + if (fs.existsSync(configPath)) { + this.logger.info(`Loading prompt <${name}> config from file ${configPath}.`); + try { + const overrideConfig = JSON.parse(fs.readFileSync(configPath, "utf-8")); + const config = { + ...defaultPromptConfig, + ...overrideConfig + }; + this.promptConfigs.set(name, config); + this.logger.info({ config }, `Prompt <${name}> config loaded.`); + } catch (err: any) { + throw new Error(`Failed to parse prompt config file ${configPath}: ${err.message}`, { + cause: err + }); + } + } else { + this.logger.info(`Prompt <${name}> config file ${configPath} not found, using default config.`); + this.promptConfigs.set(name, defaultPromptConfig); } - - this.promptConfigs.set(name, promptConfig); } private async loadTemplateFromFile(name: string, promptConfigDir: string) {