Vllm integration #14

Merged
merged 37 commits on Jun 24, 2024
Changes from 34 commits
Commits (37)
e4dda7e
vllm inference
lyutyuh Jun 19, 2024
e5ff3de
vllm particle
lyutyuh Jun 19, 2024
4e059a0
vllm base
lyutyuh Jun 20, 2024
d98994b
more makefile cleanup
benlipkin Jun 19, 2024
0eafb43
Fix: Use the `draw` argument in proposal distributions and LM samplin…
timvieira Jun 19, 2024
c92cec8
make smc-steer print out particle info when verbosity>0
postylem Jun 19, 2024
5e71fed
make smc-steer print out particle info when verbosity>0
postylem Jun 19, 2024
b283e08
tidying up
timvieira Jun 19, 2024
ee1ff33
misc cleanups
timvieira Jun 19, 2024
5bda0d8
fix: hf_tokenizers should use genparse.tokenization methods.
timvieira Jun 19, 2024
e0a8467
tidy
timvieira Jun 19, 2024
95d07db
refactor prompt formatter to be compatible with hf and vllm
leo-du Jun 19, 2024
873b62f
add script to evaluate spider on vllm
leo-du Jun 19, 2024
2855191
add vLLM setup instructions
leo-du Jun 19, 2024
0e294df
tidy
timvieira Jun 19, 2024
db7849b
whoops; forget to add file
timvieira Jun 19, 2024
9c62b33
()
timvieira Jun 19, 2024
385c9cf
new utils to simplify simplify instantiation of `HFPPLSampler`, loadi…
timvieira Jun 20, 2024
b5e115f
vllm inference
lyutyuh Jun 19, 2024
988af54
vllm base
lyutyuh Jun 20, 2024
01939d0
fix params
lyutyuh Jun 20, 2024
505ff9f
ignore
lyutyuh Jun 20, 2024
7dd732f
Merge branch 'main' into vllm
lyutyuh Jun 20, 2024
b1cddb3
fixing p_next for vllm
lyutyuh Jun 20, 2024
ae08a64
add vllm dependency
lyutyuh Jun 20, 2024
eee0efc
mock llm pnext kwargs
lyutyuh Jun 20, 2024
bca561c
cleaning code and commenting
lyutyuh Jun 20, 2024
dc48360
Merge branch 'main' into vllm
benlebrun Jun 20, 2024
8bf257c
vllm version
lyutyuh Jun 20, 2024
ff5cf6c
fix sample_next_token in token proposal duplication and return value …
benlebrun Jun 20, 2024
44a9c9f
fix multi particle
lyutyuh Jun 23, 2024
7f4c9f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2024
bb2766f
formatting
lyutyuh Jun 23, 2024
fe1bccb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2024
f7759b3
benchmarking
lyutyuh Jun 23, 2024
db092e7
commenting
lyutyuh Jun 23, 2024
51e5133
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -24,3 +24,6 @@ build/
*.pyx
*.so
*.html
*.pkl
*.pdf

31 changes: 17 additions & 14 deletions bench/run_spider_vllm.py
@@ -21,12 +21,15 @@ def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

parser.add_argument('--n-query', type=int, default=100)
parser.add_argument('--model-name', type=str,
default="meta-llama/Meta-Llama-3-8B-Instruct",
choices=["meta-llama/Meta-Llama-3-8B-Instruct"])
parser.add_argument(
'--model-name',
type=str,
default='meta-llama/Meta-Llama-3-8B-Instruct',
choices=['meta-llama/Meta-Llama-3-8B-Instruct'],
)
parser.add_argument('--exp-name', type=str, default='llama3-8b-100')
parser.add_argument("--api-base", type=str, default="http://localhost:9999/v1")
parser.add_argument("--api-key", type=str, default="EMPTY")
parser.add_argument('--api-base', type=str, default='http://localhost:9999/v1')
parser.add_argument('--api-key', type=str, default='EMPTY')

return parser

@@ -38,19 +41,19 @@ def main():
level=os.environ.get('LOGLEVEL', 'INFO').upper(),
)
# disable unnecessary logs from httpx used by openai client
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger('httpx').setLevel(logging.WARNING)
parser = get_parser()
args = parser.parse_args()

logger.info("loading spider data...")
logger.info('loading spider data...')
raw_spider_dir = Path('bench/spider/data/spider')
spider_schemas = load_schemas(
schemas_path=raw_spider_dir / 'tables.json', db_path=raw_spider_dir / 'database'
)

spider_dev_data = load_spider_data(raw_spider_dir / 'dev.json')
spider_train_data = load_spider_data(raw_spider_dir / 'train_spider.json')
logger.info("spider data loaded.")
logger.info('spider data loaded.')

prompt_formatter = SpiderPromptFormatter(spider_train_data, spider_schemas)

@@ -70,17 +73,17 @@ def main():
if i == 0: # print an example for demonstration
print('=' * 30 + ' Example prompt ' + '=' * 30)
for msg in messages:
print(msg["role"] + ":")
print("=" * (len(msg["role"])+1))
print(msg["content"])
print("-" * 100)
print(msg['role'] + ':')
print('=' * (len(msg['role']) + 1))
print(msg['content'])
print('-' * 100)
print('=' * 30 + ' End of prompt ' + '=' * 30)

chat_response = client.chat.completions.create(
model=args.model_name,
# model="mistralai/Mistral-7B-Instruct-v0.1",
messages=messages,
seed=0
seed=0,
)
gold.append(dev_datum)
predicted.append(chat_response.choices[0].message.content)
@@ -90,7 +93,7 @@ def main():
gold_outfile = f'bench/spider-eval/gold-{args.exp_name}.txt'
pred_outfile = f'bench/spider-eval/predicted-{args.exp_name}.txt'

logger.info(f"saving output to {gold_outfile} and {pred_outfile}")
logger.info(f'saving output to {gold_outfile} and {pred_outfile}')

with open(gold_outfile, 'w+') as f:
for datum in gold:
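For context, a minimal sketch of the client-side pattern the script above relies on, assuming a vLLM OpenAI-compatible server is already serving the model at the script's `--api-base` default (http://localhost:9999/v1); the prompt contents below are illustrative, not taken from the Spider benchmark:

# Hedged sketch (not part of the PR): querying a local vLLM server through the
# OpenAI-compatible API, mirroring bench/run_spider_vllm.py's defaults.
from openai import OpenAI

client = OpenAI(
    base_url='http://localhost:9999/v1',  # --api-base default above
    api_key='EMPTY',                      # --api-key default above; the key is not checked
)

messages = [
    {'role': 'system', 'content': 'You translate questions into SQL.'},  # illustrative
    {'role': 'user', 'content': 'How many singers do we have?'},         # illustrative
]

response = client.chat.completions.create(
    model='meta-llama/Meta-Llama-3-8B-Instruct',
    messages=messages,
    seed=0,
)
print(response.choices[0].message.content)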
8 changes: 4 additions & 4 deletions bench/spider/prompt_formatter.py
@@ -83,21 +83,21 @@ def format_llama2(self, datum):

def format_openai(self, datum):
messages = [
{"role": "system", "content": self.system_prompt},
{'role': 'system', 'content': self.system_prompt},
]
for example_id in [10, 100, 1000]:
train_datum = self.spider_train_data[example_id]
user_message = self.user_message_template.format(
schema_str=serialize_schema(self.db_map[train_datum.schema_name]),
utterance=train_datum.utterance,
)
messages.append({"role": "user", "content": user_message})
messages.append({"role": "system", "content": train_datum.query + ";"})
messages.append({'role': 'user', 'content': user_message})
messages.append({'role': 'system', 'content': train_datum.query + ';'})

# the actual question
user_message = self.user_message_template.format(
schema_str=serialize_schema(self.db_map[datum.schema_name]),
utterance=datum.utterance,
)
messages.append({"role": "user", "content": user_message})
messages.append({'role': 'user', 'content': user_message})
return messages
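To make the structure concrete, here is a sketch of the list `format_openai` builds; only the role pattern is taken from the code above, and the message contents are placeholders. Note that the few-shot gold answers are carried in 'system' messages, followed by the actual question as the final 'user' message:

# Illustrative shape of format_openai's output (placeholder contents):
messages = [
    {'role': 'system', 'content': '<system prompt>'},
    # three few-shot examples (training items 10, 100, 1000):
    {'role': 'user', 'content': '<schema + question 1>'},
    {'role': 'system', 'content': '<gold SQL 1>;'},
    {'role': 'user', 'content': '<schema + question 2>'},
    {'role': 'system', 'content': '<gold SQL 2>;'},
    {'role': 'user', 'content': '<schema + question 3>'},
    {'role': 'system', 'content': '<gold SQL 3>;'},
    # the actual question:
    {'role': 'user', 'content': '<schema + question to answer>'},
]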
189 changes: 189 additions & 0 deletions benchmark/benchmark_vllm_inference.py
@@ -0,0 +1,189 @@
import pickle
from argparse import ArgumentParser
from random import seed

import numpy as np
from arsenal import colors
from hfppl import CachedCausalLM
from torch import manual_seed
from transformers import AutoTokenizer, set_seed

from genparse import Float
from genparse.cfglm import EarleyBoolMaskCFGLM
from genparse.lm import AsyncGreedilyTokenizedLLM
from genparse.vllm_compatibility import vllmpplLLM
from genparse.proposal import CharacterProposal, TokenProposal
from genparse.vllm_steer import VLLMSampler
from genparse.util import LarkStuff

import torch

p = ArgumentParser()
p.add_argument('--model', choices=['gpt2', 'codellama'], required=True)
p.add_argument('--proposal', choices=['token', 'character'], default='character')
p.add_argument('--particles', type=int, default=1)
p.add_argument('--n-beam', type=int, default=1)
p.add_argument('--reps', type=int, default=1)
p.add_argument('--max-tokens', type=int, default=100)
p.add_argument('--verbosity', type=int, default=0)
p.add_argument('--seed', type=int, default=0)
p.add_argument(
'--inference',
choices=['smc-standard', 'smc-steer', 'importance-sampling'],
default='smc-standard',
)
args = p.parse_args()


RANDOM_SEED = args.seed
set_seed(RANDOM_SEED)
seed(RANDOM_SEED)
manual_seed(RANDOM_SEED)


if args.model == 'gpt2':
import transformers

from genparse.lm import LLM

MODEL_ID = 'gpt2'
hfppl_llm = vllmpplLLM(MODEL_ID)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)

elif args.model == 'codellama':
import transformers
import torch

assert torch.cuda.is_available()

MODEL_ID = 'codellama/CodeLlama-7b-Instruct-hf'
hfppl_llm = vllmpplLLM(MODEL_ID, dtype=torch.float32, max_model_len=4096)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)

else:
raise ValueError(args.model)


prompt_template = """
You have access to a political survey data table named "data", which includes the following columns:
- "age" (integer)
- "gender" ("male" or "female"),
- "year" (integer)
- "state_color" ("blue" or "red")
- "zipcode" (integer)
- "vote" ("democrat" or "republican")
- "registered_party" ("democrat" or "republican")
- "race_ethnicity" ("white", "black", or "latino").

Q: Write a SQL query that shows individuals' age and gender, for people over 50 years old.
A: SELECT age, gender FROM data WHERE age>50 </s>
Q: Write a SQL query that shows individuals' vote and zipcode, ordered from lowest to highest age.
A: SELECT vote, zipcode, age FROM data ORDER BY age ASC </s>
Q: %s
A:"""

grammar = r"""

start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS
EOS: "</s>"
select_expr: STAR | select_list
bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")"
bool_expr: var "=" value | var ">" value | var "<" value
from_expr: "data"
orderby_expr: var_list WS "ASC" | var_list WS "DESC"
select_list: select_var ("," WS select_var)*
var_list: var ("," WS var)*
select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")"
var: "age" | "gender" | "year" | "state_color" | "zipcode" | "vote" | "race_ethnicity"
value: NUMBER | "'red'" | "'blue'" | "'white'" | "'black'" | "'latino'" | "'republican'" | "'democrat'" | "'male'" | "'female'"
STAR: "*"
NUMBER: /\d+/
WS: /[ \n\r\t]+/

"""


prompts = [
"Write a SQL query that returns white voters' average age for each state color and sort the results.",
'Write a SQL query that shows the young republicans.',
'Write a SQL query that shows the old democrats in Williamsburg.',
'Write a SQL query that shows the oldest democrat in each red state.',
'Write a SQL query that shows the average age of red states vs blue states.',
]


def main():
character_cfg = LarkStuff(grammar).char_cfg(0.99, ignore='[ ]?')

guide = EarleyBoolMaskCFGLM(character_cfg)

BATCH_SIZE = 80

hfppl_llm.batch_size = BATCH_SIZE
genparse_llm = AsyncGreedilyTokenizedLLM(
model=hfppl_llm, tokenizer=tokenizer, batch_size=BATCH_SIZE
)

guide = EarleyBoolMaskCFGLM(LarkStuff(grammar).char_cfg(0.99, ignore='[ ]?'))
sampler = VLLMSampler(llm=genparse_llm, guide=guide)
if args.proposal == 'character':
proposal = CharacterProposal(llm=genparse_llm, guide=guide)
elif args.proposal == 'token':
proposal = TokenProposal(llm=genparse_llm, guide=guide, K=5)
else:
raise ValueError(f'invalid proposal name {args.proposal!r}')

for _ in range(args.reps):
for sql_prompt in prompts:
prompt = prompt_template % sql_prompt
print(colors.cyan % colors.line(100))
print(colors.cyan % sql_prompt)

particles, record = sampler.run_inference(
prompt=prompt,
proposal=proposal,
method=args.inference,
n_particles=args.particles,
max_tokens=args.max_tokens,
n_beam=args.n_beam,
verbosity=args.verbosity,
return_record=False,
)

# if args.particles > 1 and record is not None:
# fig = record.plot_particles_trajectory()
# fig.write_html('viz.html')
# print('wrote to viz.html')

print(colors.yellow % 'character posterior')
posterior = Float.chart()
for p in particles:
posterior[''.join(p.context).strip()] += np.exp(p.weight)
print(posterior.normalize())

if 0:
print(colors.yellow % 'token posterior')
posterior = Float.chart()
for p in particles:
posterior[tuple(p.context)] += np.exp(p.weight)
print(posterior.normalize())

sampler.timer.plot_feature('t')
with open('vllm_runtime.pkl', 'wb') as f:
pickle.dump(sampler.timer, f)
print('wrote to vllm_runtime.pkl')

import pylab as pl

pl.title(args)
pl.xlabel('context size (characters)')
pl.savefig('vllm_runtime.pdf')
print('wrote to vllm_runtime.pdf')
pl.show()

# from arsenal.debug import ip
# ip()


if __name__ == '__main__':
main()
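As a side note, a self-contained toy (not genparse code) showing how the 'character posterior' printout in main() is formed with Float.chart() and normalize(): particle log-weights are exponentiated, accumulated per completion string, and normalized.

# Toy re-implementation of the posterior aggregation above, using only numpy
# and the standard library; the (completion, log-weight) pairs stand in for
# (''.join(p.context), p.weight).
import numpy as np
from collections import defaultdict

toy_particles = [
    ('SELECT age FROM data </s>', -1.2),
    ('SELECT age FROM data </s>', -1.5),
    ('SELECT vote FROM data </s>', -2.3),
]

posterior = defaultdict(float)
for text, log_w in toy_particles:
    posterior[text] += np.exp(log_w)

total = sum(posterior.values())
print({text: round(w / total, 3) for text, w in posterior.items()})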
20 changes: 14 additions & 6 deletions genparse/lm.py
@@ -10,6 +10,7 @@

from genparse.semiring import Float
from genparse.tokenization import decode_tokenizer_vocab
from genparse.vllm_compatibility import vllmpplLLM


class LM:
@@ -269,11 +270,18 @@ def __call__(self, xs):
async def next_token_logprobs(self, xs, top=None):
return self.p_next(xs, top=top).map_values(np.log)

async def p_next(self, xs, top=None):
assert isinstance(xs, str)
tokens = self.tokenizer.encode(xs)

_logp = await self._model.next_token_logprobs(tokens)
async def p_next(self, xs='', top=None, _logp=None, **kwargs):
if isinstance(self._model, vllmpplLLM):
# When the underlying model is vLLM, `_logp` is provided by the centralized
# vLLM step function (along with kwargs such as `execute_model_req`), so no
# per-particle forward pass is run here.
assert (
_logp is not None
), 'Please provide the log probabilities when using VLLM.'
else:
assert isinstance(xs, str)
tokens = self.tokenizer.encode(xs)
_logp = await self._model.next_token_logprobs(tokens)
_logp = _logp.cpu().numpy() if hasattr(_logp, 'cpu') else _logp
_p = np.exp(_logp)

@@ -342,7 +350,7 @@ def __init__(self, V, eos, _p=None):
self._encode = {x: i for i, x in enumerate(self._decode)}
super().__init__(eos=eos, V=V)

def p_next(self, _):
def p_next(self, _, **kwargs):
return LazyProb(self._p, self._encode, self._decode)

# def __call__(self, x):
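To isolate the control flow added to `p_next` above, here is a runnable toy (plain asyncio/numpy, not genparse's classes; `TinyAsyncLM` is hypothetical) that mirrors the two call paths: either tokenize and query the model, or accept a pre-computed `_logp` as the vLLM centralized step would supply it:

# Toy sketch of the p_next dispatch introduced above.
import asyncio
import numpy as np

class TinyAsyncLM:
    def __init__(self, vocab):
        self.vocab = vocab

    async def _next_token_logprobs(self, tokens):
        # stand-in for a real forward pass: uniform over the vocabulary
        return np.log(np.full(len(self.vocab), 1.0 / len(self.vocab)))

    async def p_next(self, xs='', _logp=None, **kwargs):
        if _logp is None:
            # HF-style path: encode the string and query the model here.
            tokens = list(xs)
            _logp = await self._next_token_logprobs(tokens)
        # vLLM-style path: _logp was computed centrally and is used as-is.
        return dict(zip(self.vocab, np.exp(_logp)))

async def demo():
    lm = TinyAsyncLM(['SELECT', 'FROM', '</s>'])
    print(await lm.p_next('SELECT age'))                    # model-computed
    print(await lm.p_next(_logp=np.log([0.7, 0.2, 0.1])))   # pre-computed

asyncio.run(demo())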
18 changes: 12 additions & 6 deletions genparse/proposal/character.py
@@ -94,6 +94,8 @@ async def sample_next_token(
verbosity=0,
correct_weights=True,
draw=sample_dict,
execute_model_req=None,
p_llm=None,
**kwargs,
):
"""
@@ -103,18 +105,22 @@ async def sample_next_token(
prompt : The LLM prompt.
context : The previous generated tokens.
verbosity : > 1 prints sampling process.
compare_time : true compares time spent in LLM to cfg+trie.
correct_weights : whether to correct the importance weights with RAVI.
false leads to probabilistically incorrect inference.
p_llm : pre-computed LLM next-token distribution (e.g., from the vLLM
centralized step); when provided, the proposal skips its own LLM call.
Returns:
token : Proposed LLM token.
weight_update : Incremental SMC weight update.
"""

if iscoroutinefunction(self.llm.p_next):
p_llm = await self.llm.p_next(prompt + context)
else:
p_llm = self.llm.p_next(prompt + context)
if p_llm is None:
if iscoroutinefunction(self.llm.p_next):
p_llm = await self.llm.p_next(
prompt + context, execute_model_req=execute_model_req
)
else:
p_llm = self.llm.p_next(
prompt + context, execute_model_req=execute_model_req
)

self._update_trie(p_llm)
