-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'refs/remotes/origin/main'
- Loading branch information
Showing
4 changed files
with
229 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
"""Evaluate spider on vLLM Llama models without any grammar restriction.""" | ||
|
||
import argparse | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
from openai import OpenAI | ||
from tqdm import tqdm | ||
|
||
from bench.spider.dialogue import load_spider_data | ||
from bench.spider.schema import load_schemas | ||
from bench.spider.prompt_formatter import SpiderPromptFormatter | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def get_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the unconstrained Spider evaluation script."""
    parser = argparse.ArgumentParser()

    # Number of dev-set questions to evaluate.
    parser.add_argument('--n-query', type=int, default=100)
    # Only one served model is currently supported.
    parser.add_argument(
        '--model-name',
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        choices=["meta-llama/Meta-Llama-3-8B-Instruct"],
    )
    # Experiment name; used to name the gold/predicted output files.
    parser.add_argument('--exp-name', type=str, default='llama3-8b-100')
    # OpenAI-compatible endpoint of the vLLM server.
    parser.add_argument("--api-base", type=str, default="http://localhost:9999/v1")
    parser.add_argument("--api-key", type=str, default="EMPTY")

    return parser
|
||
|
||
def main():
    """Evaluate Spider dev questions against a vLLM-served Llama model.

    Loads Spider schemas and data, sends 3-shot chat prompts to an
    OpenAI-compatible endpoint without any grammar restriction, and writes
    gold/predicted files in the tab-separated format expected by the
    official Spider evaluation script.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=os.environ.get('LOGLEVEL', 'INFO').upper(),
    )
    # disable unnecessary logs from httpx used by openai client
    logging.getLogger("httpx").setLevel(logging.WARNING)
    args = get_parser().parse_args()

    logger.info("loading spider data...")
    raw_spider_dir = Path('bench/spider/data/spider')
    spider_schemas = load_schemas(
        schemas_path=raw_spider_dir / 'tables.json', db_path=raw_spider_dir / 'database'
    )

    spider_dev_data = load_spider_data(raw_spider_dir / 'dev.json')
    spider_train_data = load_spider_data(raw_spider_dir / 'train_spider.json')
    logger.info("spider data loaded.")

    prompt_formatter = SpiderPromptFormatter(spider_train_data, spider_schemas)

    logger.info(f"Creating client for '{args.model_name}' served at {args.api_base}")
    client = OpenAI(
        api_key=args.api_key,
        base_url=args.api_base,
    )

    # Bug fix: the original appended each datum to `gold` inside the loop and
    # then overwrote the list with this exact slice afterwards — the slice is
    # the single source of truth.
    gold = spider_dev_data[:args.n_query]
    predicted = []

    for i, dev_datum in tqdm(enumerate(gold), total=len(gold)):
        messages = prompt_formatter.format_openai(dev_datum)

        if i == 0:  # print an example for demonstration
            print('=' * 30 + ' Example prompt ' + '=' * 30)
            for msg in messages:
                print(msg["role"] + ":")
                print("=" * (len(msg["role"]) + 1))
                print(msg["content"])
                print("-" * 100)
            print('=' * 30 + ' End of prompt ' + '=' * 30)

        chat_response = client.chat.completions.create(
            model=args.model_name,
            messages=messages,
            seed=0,  # fixed seed for reproducibility
        )
        predicted.append(chat_response.choices[0].message.content)

    # Ensure the output directory exists; the original crashed with
    # FileNotFoundError when bench/spider-eval/ was missing.
    out_dir = Path('bench/spider-eval')
    out_dir.mkdir(parents=True, exist_ok=True)
    gold_outfile = out_dir / f'gold-{args.exp_name}.txt'
    pred_outfile = out_dir / f'predicted-{args.exp_name}.txt'

    logger.info(f"saving output to {gold_outfile} and {pred_outfile}")

    # 'w' instead of 'w+': the files are only written, never read back here.
    with open(gold_outfile, 'w') as f:
        for datum in gold:
            print(f'{datum.query}\t{datum.schema_name}', file=f)

    with open(pred_outfile, 'w') as f:
        for datum in predicted:
            flattened = datum.replace('\n', ' ')
            # A tab would corrupt the TSV consumed by the Spider evaluator;
            # raise instead of `assert`, which is stripped under `python -O`.
            if '\t' in flattened:
                raise ValueError(f'prediction contains a tab character: {flattened!r}')
            print(flattened.strip(), file=f)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from typing import List | ||
|
||
import bench.spider.schema | ||
from bench.spider.dialogue import SpiderDatum | ||
|
||
|
||
def serialize_schema(db_schema: bench.spider.schema.DbSchema):
    """Render a database schema as a prompt-friendly string.

    Each table is rendered as its name followed by one
    '* <column> (<type>): <natural-language name>' line per column;
    tables are separated by blank lines.
    """
    rendered_tables = []
    for table in db_schema.tables:
        lines = [table.name]
        lines.extend(
            f'* {column.name} ({column.tpe.value}): {column.nl_name}'
            for column in table.columns
        )
        rendered_tables.append('\n'.join(lines))

    return '\n\n'.join(rendered_tables)
|
||
|
||
class SpiderPromptFormatter:
    """Format Spider data into 3-shot text-to-SQL prompts.

    Two output formats are supported: the raw Llama-2 chat template
    (`format_llama2`) and OpenAI-style chat messages (`format_openai`).
    Both use the same three fixed training examples as demonstrations.
    """

    # Training-set indices of the in-context examples, shared by both
    # formatters (previously duplicated as a magic list in each method).
    IN_CONTEXT_EXAMPLE_IDS = (10, 100, 1000)

    def __init__(self, spider_train_data: List[SpiderDatum], db_map):
        """
        Args:
            spider_train_data: Spider training examples; the items at
                IN_CONTEXT_EXAMPLE_IDS are used as demonstrations.
            db_map: mapping from schema name to its schema object, passed
                through to `serialize_schema`.
        """
        self.spider_train_data = spider_train_data
        self.db_map = db_map

        self.llama2_chat_prompt_template = (
            """<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message_1} [/INST] {model_answer_1} </s>"""
            + '<s>[INST] {user_message_2} [/INST] {model_answer_2} </s>'
            + '<s>[INST] {user_message_3} [/INST] {model_answer_3} </s>'
            + '<s>[INST] {user_message} [/INST]'
        )

        self.system_prompt = (
            'You are a coding assistant helping an analyst answer questions over business data in SQL. '
            'More specifically, the analyst provides you a database schema '
            '(tables in the database along with their column names and types) '
            'and asks a question about the data that can be solved by issuing a SQL query to the database. '
            'In response, you write the SQL statement that answers the question. '
            'You do not provide any commentary or explanation of what the code does, '
            'just the SQL statement ending in a semicolon.'
        )

        self.user_message_template = """Here is a database schema:
{schema_str}
Please write me a SQL statement that answers the following question: {utterance}
Remember, DO NOT provide any commentary or explanation of what the code does, just the SQL statement ending in a semicolon."""

    def format_llama2(self, datum):
        """Return a single Llama-2 chat-format prompt string for `datum`."""
        prompt_var_dict = {
            'system_prompt': self.system_prompt,
        }

        # in-context examples from training data
        for i, example_id in enumerate(self.IN_CONTEXT_EXAMPLE_IDS, 1):
            train_datum = self.spider_train_data[example_id]
            prompt_var_dict[f'user_message_{i}'] = self.user_message_template.format(
                schema_str=serialize_schema(self.db_map[train_datum.schema_name]),
                utterance=train_datum.utterance,
            )
            prompt_var_dict[f'model_answer_{i}'] = train_datum.query + ';'

        # the actual question
        prompt_var_dict['user_message'] = self.user_message_template.format(
            schema_str=serialize_schema(self.db_map[datum.schema_name]),
            utterance=datum.utterance,
        )

        return self.llama2_chat_prompt_template.format(**prompt_var_dict)

    def format_openai(self, datum):
        """Return OpenAI chat `messages` for `datum` with 3-shot examples."""
        messages = [
            {"role": "system", "content": self.system_prompt},
        ]
        for example_id in self.IN_CONTEXT_EXAMPLE_IDS:
            train_datum = self.spider_train_data[example_id]
            user_message = self.user_message_template.format(
                schema_str=serialize_schema(self.db_map[train_datum.schema_name]),
                utterance=train_datum.utterance,
            )
            messages.append({"role": "user", "content": user_message})
            # Bug fix: the example answers are model turns, so they must use
            # the "assistant" role; the original used "system", which
            # misrepresents the few-shot answers as instructions.
            messages.append({"role": "assistant", "content": train_datum.query + ";"})

        # the actual question
        user_message = self.user_message_template.format(
            schema_str=serialize_schema(self.db_map[datum.schema_name]),
            utterance=datum.utterance,
        )
        messages.append({"role": "user", "content": user_message})
        return messages