Skip to content

Commit

Permalink
(improvement)(headless)Commit new impl of SqlGenStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjzhang committed Jun 5, 2024
1 parent 91e27bc commit 008c1c3
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 324 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;

import com.tencent.supersonic.common.util.JsonUtil;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
Expand All @@ -15,12 +15,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
import java.util.concurrent.ConcurrentHashMap;


@Service
Expand All @@ -29,46 +24,75 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {

@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);

int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
//1.recall exemplars
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
List<List<Map<String, String>>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);

List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
fewShotNumber, selfConsistencyNumber);
//2.generate sql generation prompt for each self-consistency inference
Map<Prompt, List<Map<String, String>>> prompt2Exemplar = new HashMap<>();
for (List<Map<String, String>> exemplars : exemplarsList) {
Prompt prompt = generatePrompt(llmReq, exemplars);
prompt2Exemplar.put(prompt, exemplars);
}

//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);
List<String> llmResults = new CopyOnWriteArrayList<>();
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
//3.perform multiple self-consistency inferences in parallel
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
prompt2Output.put(prompt, result);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
}
);
//3.format response.
List<String> sqlList = llmResults.stream()
.map(OutputFormat::getSql).collect(Collectors.toList());

Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
//4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
Lists.newArrayList(prompt2Output.values()));
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
//TODO: should use the same few-shot exemplars as the ones chosen by the self-consistency vote
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));

return llmResp;
}

private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
String instruction = ""
+ "#Role: You are a data analyst experienced in SQL languages.\n"
+ "#Task: You will be provided a natural language query asked by business users,"
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
+ "by executing the SQL query against underlying database.\n"
+ "#Rules:\n"
+ "1.Always use `数据日期` as the date field.\n"
+ "2.Always use `datediff` function to calculate date range.\n"
+ "3.Only output SQL statement.\n"
+ "#Exemplars:\n%s"
+ "#UserQuery: %s "
+ "#DatabaseMetadata: %s "
+ "#SQL: ";

StringBuilder exemplarsStr = new StringBuilder();
for (Map<String, String> example : fewshotExampleList) {
String metadata = example.get("dbSchema");
String question = example.get("questionAugmented");
String sql = example.get("sql");
String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n",
question, metadata, sql);
exemplarsStr.append(exemplarStr);
}

Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
String dbSchema = questionPrompt.getLeft();
String questionAugmented = questionPrompt.getRight();
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dbSchema);

LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
result.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMapPair.getRight()));
return result;
return PromptTemplate.from(promptStr).apply(new HashMap<>());
}

@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;

import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;

/**
 * Builds the reusable pieces of an LLM SQL-generation prompt: the few-shot exemplar
 * groups used for self-consistency inference and the schema / augmented-question
 * strings derived from an {@link LLMReq}.
 */
@Component
@Slf4j
public class PromptHelper {

    @Autowired
    private ParserConfig parserConfig;

    @Autowired
    private ExemplarManager exemplarManager;

    /**
     * Recalls exemplars for the request's query text and produces one randomly sampled
     * group of exemplars per self-consistency inference.
     *
     * <p>Each group holds at most the configured few-shot number of exemplars. When fewer
     * exemplars are recalled than that number, every group simply contains all of them
     * (in shuffled order) instead of failing with an {@link IndexOutOfBoundsException}.
     *
     * @param llmReq the request whose query text drives exemplar recall
     * @return one exemplar list per configured self-consistency inference; each inner list
     *         is an independent copy and may be empty when nothing was recalled
     * @throws NumberFormatException if one of the parser parameters is not a valid integer
     */
    public List<List<Map<String, String>>> getFewShotExemplars(LLMReq llmReq) {
        int exemplarRecallNumber = Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
        int fewShotNumber = Integer.parseInt(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
        int selfConsistencyNumber = Integer.parseInt(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));

        List<Map<String, String>> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(),
                exemplarRecallNumber);
        if (exemplars == null) {
            // Treat a missing recall result the same as "nothing recalled".
            exemplars = Collections.emptyList();
        }

        // Never ask subList for more elements than were recalled (or for a negative count).
        int sampleSize = Math.max(0, Math.min(fewShotNumber, exemplars.size()));

        List<List<Map<String, String>>> results = new ArrayList<>();
        // use random collection of exemplars for each self-consistency inference
        for (int i = 0; i < selfConsistencyNumber; i++) {
            List<Map<String, String>> shuffledList = new ArrayList<>(exemplars);
            Collections.shuffle(shuffledList);
            // Copy the sampled prefix so the result is not a view pinned to the full shuffled list.
            results.add(new ArrayList<>(shuffledList.subList(0, sampleSize)));
        }

        return results;
    }

    /**
     * Derives the database-schema description and the augmented question text for a request.
     *
     * <p>The augmented question is the original query text followed by the linked field
     * values, the current date, the business-term description and the prior extensions.
     *
     * @param llmReq the request to transform
     * @return a pair whose left element is the schema description and whose right element
     *         is the augmented question
     */
    public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
        String tableName = llmReq.getSchema().getDataSetName();
        List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
        List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
        String currentDate = llmReq.getCurrentDate();
        String priorExts = llmReq.getPriorExts();

        String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList;

        List<String> priorLinkingList = new ArrayList<>();
        // A request without linked values yields an empty linking section instead of an NPE.
        if (linkedValues != null) {
            for (ElementValue value : linkedValues) {
                String fieldName = value.getFieldName();
                String fieldValue = value.getFieldValue();
                priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
            }
        }
        String currentDataStr = "current date is " + currentDate;
        String linkingListStr = String.join(",", priorLinkingList);
        String termStr = getTermStr(llmReq);
        String questionAugmented = String.format("%s (补充信息:%s . %s . %s) (备注: %s)", llmReq.getQueryText(),
                linkingListStr, currentDataStr, termStr, priorExts);

        return Pair.of(dbSchema, questionAugmented);
    }

    /**
     * Renders the schema's business terms as a numbered, semicolon-separated description.
     *
     * @param llmReq the request whose schema terms are described
     * @return the term description, or an empty string when the schema has no terms
     */
    private String getTermStr(LLMReq llmReq) {
        List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
        if (CollectionUtils.isEmpty(terms)) {
            return "";
        }

        StringBuilder termsDesc = new StringBuilder("相关业务术语:");
        for (int idx = 0; idx < terms.size(); idx++) {
            LLMReq.Term term = terms.get(idx);
            String name = term.getName();
            String description = term.getDescription();
            List<String> alias = term.getAlias();
            String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
            String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
            termsDesc.append(String.format("%d.<%s>是业务术语%s%s;", idx + 1, name, descPart, aliasPart));
        }
        // Every entry ends with a separator; drop the one trailing after the last term.
        termsDesc.setLength(termsDesc.length() - 1);

        return termsDesc.toString();
    }

}
Loading

0 comments on commit 008c1c3

Please sign in to comment.