Commit
Merge remote-tracking branch 'refs/remotes/origin/main'
timvieira committed Jun 19, 2024
2 parents 7c6dd48 + adcac77 commit dacc1e8
Showing 4 changed files with 229 additions and 81 deletions.
16 changes: 16 additions & 0 deletions bench/README.md
@@ -33,6 +33,22 @@ before running any evaluation, `spider-eval` depends on the `punkt` package of `nltk`
>>> nltk.download('punkt')
```

***vLLM***

To run the evaluation script against vLLM, first serve the model:

```bash
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --port 9999
```

The server will then be available at `http://localhost:9999`.
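
You can sanity-check the endpoint before launching the full evaluation. Here is a minimal sketch, assuming the `openai` Python package is installed and the serve command above is running (vLLM exposes an OpenAI-compatible API and does not validate the key):

```python
from openai import OpenAI

# point the client at the local vLLM server started above
client = OpenAI(api_key="EMPTY", base_url="http://localhost:9999/v1")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Write a SQL statement that selects the number 1."}],
)
print(response.choices[0].message.content)
```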

If vLLM complains that the model is gated, follow the printed instructions to request access to the model,
and then set your Hugging Face token:
```bash
export HF_TOKEN=xxx
```

### Generation

To generate SQL for the first 100 examples of the Spider dev set with the llama-2-chat-13b model, run from the root directory:
84 changes: 3 additions & 81 deletions bench/run_spider_llama2_chat.py
@@ -11,91 +11,13 @@
from tqdm import tqdm
from transformers import AutoTokenizer, pipeline

import bench
from bench.spider.dialogue import load_spider_data
from bench.spider.schema import load_schemas
from bench.spider.prompt_formatter import SpiderPromptFormatter

logger = logging.getLogger(__name__)


def serialize_schema(db_schema: bench.spider.schema.DbSchema):
table_strs = []
for table in db_schema.tables:
column_strs = []
for column in table.columns:
column_strs.append(f'* {column.name} ({column.tpe.value}): {column.nl_name}')
table_str = '\n'.join([table.name] + column_strs)
table_strs.append(table_str)

return '\n\n'.join(table_strs)


class PromptFormatter:
def __init__(self, spider_train_data, db_map):
self.spider_train_data = spider_train_data
self.db_map = db_map

self.llama2_chat_prompt_template = (
"""<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message_1} [/INST] {model_answer_1} </s>"""
+ '<s>[INST] {user_message_2} [/INST] {model_answer_2} </s>'
+ '<s>[INST] {user_message_3} [/INST] {model_answer_3} </s>'
+ '<s>[INST] {user_message} [/INST]'
)

self.system_prompt = (
'You are a coding assistant helping an analyst answer questions over business data in SQL. '
'More specifically, the analyst provides you a database schema '
'(tables in the database along with their column names and types) '
'and asks a question about the data that can be solved by issuing a SQL query to the database. '
'In response, you write the SQL statement that answers the question. '
'You do not provide any commentary or explanation of what the code does, '
'just the SQL statement ending in a semicolon.'
)

self.user_message_template = """Here is a database schema:
{schema_str}
Please write me a SQL statement that answers the following question: {utterance}
Remember, DO NOT provide any commentary or explanation of what the code does, just the SQL statement ending in a semicolon."""

def format(self, datum):
spider_train_data = self.spider_train_data
db_map = self.db_map

llama2_chat_prompt_template = self.llama2_chat_prompt_template
system_prompt = self.system_prompt
user_message_template = self.user_message_template

prompt_var_dict = {
'system_prompt': system_prompt,
}

# in-context examples from training data
for i, example_id in enumerate([10, 100, 1000], 1):
train_datum = spider_train_data[example_id]
user_message = user_message_template.format(
schema_str=serialize_schema(db_map[train_datum.schema_name]),
utterance=train_datum.utterance,
)
prompt_var_dict[f'user_message_{i}'] = user_message
prompt_var_dict[f'model_answer_{i}'] = train_datum.query + ';'

# the actual question
user_message = user_message_template.format(
schema_str=serialize_schema(db_map[datum.schema_name]),
utterance=datum.utterance,
)
prompt_var_dict['user_message'] = user_message

return llama2_chat_prompt_template.format(**prompt_var_dict)


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

@@ -123,7 +45,7 @@ def main():

spider_dev_data = load_spider_data(raw_spider_dir / 'dev.json')
spider_train_data = load_spider_data(raw_spider_dir / 'train_spider.json')
prompt_formatter = PromptFormatter(spider_train_data, spider_schemas)
prompt_formatter = SpiderPromptFormatter(spider_train_data, spider_schemas)

model = 'meta-llama/Llama-2-7b-chat-hf'
model = model.replace('7b', args.model_size)
@@ -146,7 +68,7 @@ def main():
predicted = []

for i, dev_datum in tqdm(enumerate(spider_dev_data[:n_query]), total=n_query):
prompt = prompt_formatter.format(dev_datum)
prompt = prompt_formatter.format_llama2(dev_datum)
if i == 0:
print('=' * 30 + ' Example prompt ' + '=' * 30)
print(prompt)
107 changes: 107 additions & 0 deletions bench/run_spider_vllm.py
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Evaluate spider on vLLM Llama models without any grammar restriction."""

import argparse
import logging
import os
from pathlib import Path

from openai import OpenAI
from tqdm import tqdm

from bench.spider.dialogue import load_spider_data
from bench.spider.schema import load_schemas
from bench.spider.prompt_formatter import SpiderPromptFormatter

logger = logging.getLogger(__name__)


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

parser.add_argument('--n-query', type=int, default=100)
parser.add_argument('--model-name', type=str,
default="meta-llama/Meta-Llama-3-8B-Instruct",
choices=["meta-llama/Meta-Llama-3-8B-Instruct"])
parser.add_argument('--exp-name', type=str, default='llama3-8b-100')
parser.add_argument("--api-base", type=str, default="http://localhost:9999/v1")
parser.add_argument("--api-key", type=str, default="EMPTY")

return parser


def main():
logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=os.environ.get('LOGLEVEL', 'INFO').upper(),
)
# disable unnecessary logs from httpx used by openai client
logging.getLogger("httpx").setLevel(logging.WARNING)
parser = get_parser()
args = parser.parse_args()

logger.info("loading spider data...")
raw_spider_dir = Path('bench/spider/data/spider')
spider_schemas = load_schemas(
schemas_path=raw_spider_dir / 'tables.json', db_path=raw_spider_dir / 'database'
)

spider_dev_data = load_spider_data(raw_spider_dir / 'dev.json')
spider_train_data = load_spider_data(raw_spider_dir / 'train_spider.json')
logger.info("spider data loaded.")

prompt_formatter = SpiderPromptFormatter(spider_train_data, spider_schemas)

logger.info(f"Creating client for '{args.model_name}' served at {args.api_base}")
client = OpenAI(
api_key=args.api_key,
base_url=args.api_base,
)
n_query = args.n_query

gold = []
predicted = []

for i, dev_datum in tqdm(enumerate(spider_dev_data[:n_query]), total=n_query):
messages = prompt_formatter.format_openai(dev_datum)

if i == 0: # print an example for demonstration
print('=' * 30 + ' Example prompt ' + '=' * 30)
for msg in messages:
print(msg["role"] + ":")
print("=" * (len(msg["role"])+1))
print(msg["content"])
print("-" * 100)
print('=' * 30 + ' End of prompt ' + '=' * 30)

chat_response = client.chat.completions.create(
model=args.model_name,
# model="mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
            seed=0,  # fixed sampling seed so reruns are comparable
)
gold.append(dev_datum)
predicted.append(chat_response.choices[0].message.content)

    gold = spider_dev_data[:n_query]  # overwrites the list built in the loop with the same slice, keeping gold aligned with predicted

gold_outfile = f'bench/spider-eval/gold-{args.exp_name}.txt'
pred_outfile = f'bench/spider-eval/predicted-{args.exp_name}.txt'

logger.info(f"saving output to {gold_outfile} and {pred_outfile}")

with open(gold_outfile, 'w+') as f:
for datum in gold:
print(f'{datum.query}\t{datum.schema_name}', file=f)

with open(pred_outfile, 'w+') as f:
for datum in predicted:
datum = datum.replace('\n', ' ')
assert '\t' not in datum
print(datum.strip(), file=f)


if __name__ == '__main__':
main()
103 changes: 103 additions & 0 deletions bench/spider/prompt_formatter.py
@@ -0,0 +1,103 @@
from typing import List

import bench.spider.schema
from bench.spider.dialogue import SpiderDatum


def serialize_schema(db_schema: bench.spider.schema.DbSchema):
    """Render a schema as plain text: one block per table, one
    '* column_name (type): natural-language name' line per column,
    with blank lines between tables."""
    table_strs = []
for table in db_schema.tables:
column_strs = []
for column in table.columns:
column_strs.append(f'* {column.name} ({column.tpe.value}): {column.nl_name}')
table_str = '\n'.join([table.name] + column_strs)
table_strs.append(table_str)

return '\n\n'.join(table_strs)


class SpiderPromptFormatter:
def __init__(self, spider_train_data: List[SpiderDatum], db_map):
self.spider_train_data = spider_train_data
self.db_map = db_map

        # Llama-2 chat format: each round is '<s>[INST] user [/INST] answer </s>',
        # with the system prompt wrapped in <<SYS>> tags inside the first round;
        # three answered in-context rounds precede the final unanswered [INST] block.
        self.llama2_chat_prompt_template = (
"""<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message_1} [/INST] {model_answer_1} </s>"""
+ '<s>[INST] {user_message_2} [/INST] {model_answer_2} </s>'
+ '<s>[INST] {user_message_3} [/INST] {model_answer_3} </s>'
+ '<s>[INST] {user_message} [/INST]'
)

self.system_prompt = (
'You are a coding assistant helping an analyst answer questions over business data in SQL. '
'More specifically, the analyst provides you a database schema '
'(tables in the database along with their column names and types) '
'and asks a question about the data that can be solved by issuing a SQL query to the database. '
'In response, you write the SQL statement that answers the question. '
'You do not provide any commentary or explanation of what the code does, '
'just the SQL statement ending in a semicolon.'
)

self.user_message_template = """Here is a database schema:
{schema_str}
Please write me a SQL statement that answers the following question: {utterance}
Remember, DO NOT provide any commentary or explanation of what the code does, just the SQL statement ending in a semicolon."""

def format_llama2(self, datum):
spider_train_data = self.spider_train_data
db_map = self.db_map

llama2_chat_prompt_template = self.llama2_chat_prompt_template
system_prompt = self.system_prompt
user_message_template = self.user_message_template

prompt_var_dict = {
'system_prompt': system_prompt,
}

# in-context examples from training data
for i, example_id in enumerate([10, 100, 1000], 1):
train_datum = spider_train_data[example_id]
user_message = user_message_template.format(
schema_str=serialize_schema(db_map[train_datum.schema_name]),
utterance=train_datum.utterance,
)
prompt_var_dict[f'user_message_{i}'] = user_message
prompt_var_dict[f'model_answer_{i}'] = train_datum.query + ';'

# the actual question
user_message = user_message_template.format(
schema_str=serialize_schema(db_map[datum.schema_name]),
utterance=datum.utterance,
)
prompt_var_dict['user_message'] = user_message

return llama2_chat_prompt_template.format(**prompt_var_dict)

def format_openai(self, datum):
messages = [
{"role": "system", "content": self.system_prompt},
]
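        # three in-context rounds from the training split: each user turn carries a
        # serialized schema plus a question, and the assistant turn carries the gold SQL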
for example_id in [10, 100, 1000]:
train_datum = self.spider_train_data[example_id]
user_message = self.user_message_template.format(
schema_str=serialize_schema(self.db_map[train_datum.schema_name]),
utterance=train_datum.utterance,
)
messages.append({"role": "user", "content": user_message})
messages.append({"role": "system", "content": train_datum.query + ";"})

# the actual question
user_message = self.user_message_template.format(
schema_str=serialize_schema(self.db_map[datum.schema_name]),
utterance=datum.utterance,
)
messages.append({"role": "user", "content": user_message})
return messages
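
# Example usage (illustrative):
#   formatter = SpiderPromptFormatter(spider_train_data, spider_schemas)
#   prompt = formatter.format_llama2(dev_datum)     # raw Llama-2 prompt string
#   messages = formatter.format_openai(dev_datum)   # OpenAI-style chat messages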
