Skip to content

Commit

Permalink
put prompt generation functions in a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
loubnabnl committed Sep 19, 2022
1 parent eb0708f commit 42f8ef6
Showing 1 changed file with 236 additions and 0 deletions.
236 changes: 236 additions & 0 deletions lm_eval/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""Prompt design for each benchmark"""

import json
import re
import torch


EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]
MBPP_EOF_STRINGS = ["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/"]
TRIPLE_QUOTE = '"""'
SINGLE_TRIPLE_QUOTE = "'''"
SPACES4 = " " * 4


def truncate_prompt_apps(prompt, tokenizer, max_length, call_format):
    """Truncate an over-long APPS prompt while keeping its closing phrases.

    The prompt is tokenized; if it already fits in ``max_length`` tokens it
    is returned unchanged. Otherwise the token sequence is cut so that the
    format hint (``call_format``) and the "ANSWER:" marker still appear at
    the end of the decoded prompt.
    """
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]
    if len(token_ids) <= max_length:
        return prompt

    # Re-append the end phrase so the model still sees the answer marker.
    tail_ids = tokenizer(
        call_format + "\nANSWER:\n", return_tensors="pt"
    ).input_ids[0]
    keep = max_length - len(tail_ids)
    truncated_ids = torch.cat((token_ids[:keep], tail_ids))
    return tokenizer.decode(
        truncated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )


def apps_few_shot_prompt(prompt):
    """Prepend two worked APPS examples (one per answer format) to ``prompt``."""
    with open("lm_eval/few_shot_examples/apps_few_shot_prompts.json", "r") as file:
        examples = json.load(file)

    # One Standard-Input example and one Call-Based example, then the task.
    pieces = [
        "Implement answers to the following problems:\nProblem:\n",
        examples["problem_type1"],
        "\nUse Standard Input format\nANSWER:\n",
        examples["solution_type1"],
        "\n\nProblem:\n",
        examples["problem_type2"],
        "\nUse Call-Based format\nANSWER:\n\n",
        examples["solution_type2"],
        "\n\nProblem:\n",
        prompt,
    ]
    return "".join(pieces)


def generate_prompt_apps(
    sample, tokenizer, max_length=1024, prefix="", setup="finetuning"
):
    """Generate a prompt for an APPS sample.

    Finetuning setup: prompt = question plus starter code and a format hint
    (Call-Based vs Standard Input), truncated to ``max_length`` tokens.
    Any other ``setup`` value: a 2-shot prompt built from stored examples.

    Args:
        sample: APPS dataset row with "question", "starter_code" and
            "input_output" fields.
        tokenizer: tokenizer used to measure and truncate the prompt.
        max_length: maximum prompt length in tokens (finetuning setup only).
        prefix: string prepended once to the final prompt.
        setup: "finetuning" or anything else for few-shot mode.

    Returns:
        The formatted prompt string, prefixed with ``prefix``.
    """
    if setup == "finetuning":
        starter_code = sample["starter_code"] if sample["starter_code"] else None
        try:
            input_output = json.loads(sample["input_output"])
            fn_name = input_output.get("fn_name") or None
        except ValueError:
            fn_name = None
        prompt = "\nQUESTION:\n" + sample["question"]
        if starter_code:
            prompt += starter_code
        # APPS convention: a function name means the solution is called
        # directly (Call-Based); otherwise it reads from standard input.
        # (The previous version had this condition inverted.)
        if fn_name:
            call_format = "\nUse Call-Based format"
        else:
            call_format = "\nUse Standard Input format"
        prompt += call_format
        prompt += "\nANSWER:\n"
        # Dead `if setup != "finetuning"` branch removed: inside this block
        # it could never run.
        prompt = truncate_prompt_apps(prompt, tokenizer, max_length, call_format)
    else:
        # few shot mode: this adds 270 tokens in avg to the prompt.
        # Reuse the helper instead of duplicating its string assembly; the
        # old code also read `prompt` before assignment here (NameError).
        prompt = apps_few_shot_prompt(sample["question"])

    return prefix + prompt


def mbpp_incoder_prompt(sample, include_solution_mbpp=False, prefix=""):
    """Build an InCoder-style MBPP prompt: a docstring holding the task
    description plus one example test, optionally followed by the solution."""
    header = sample["text"] + "\n" + sample["test_list"][0]
    prompt = '"""\n' + header + '\n"""\n'
    if include_solution_mbpp:
        prompt = prompt + sample["code"] + "\n"
    return prefix + prompt


def mbpp_google_prompt(sample, include_tests=True, prefix=""):
    """Build an MBPP prompt following the original Google paper: the task
    description, optionally followed by 'Your code should satisfy these
    tests:' and the sample's assert statements."""
    prompt = sample["text"]
    if include_tests:
        assertions = "".join("\n" + case for case in sample["test_list"])
        prompt = prompt + " Your code should satisfy these tests:\n" + assertions
    return prefix + prompt


def code_to_text_prompt(sample, language="python", prompt_type="left", prefix=""):
    """Generate prompts for the code-to-text (summarization) task.

    For Python, ``prompt_type="left"`` returns the code up to (and including
    a normalized triple-quote) docstring delimiter; any other value strips
    the docstring from the body and appends an explanation marker. For other
    languages the whole snippet is returned followed by a comment opener
    (``=begin`` for Ruby, ``/*`` otherwise).

    The language comparison is now case-insensitive: previously "Ruby" was
    matched with a capital R while "python" was lowercase, so "ruby" fell
    through to the C-style comment branch.
    """
    # TODO implement signature extraction for other languages?
    code = sample["code"]
    lang = language.lower()

    if lang == "python":
        # CodeXGLUE python code embeds the docstring inside the function.
        text = sample["docstring"]
        prompt_prefix = code[: code.index(text)]
        prompt_prefix = standardize_docstring_prompt(prompt_prefix)
        if prompt_type == "left":
            return prefix + prompt_prefix
        # Remove the docstring and its delimiters from the body.
        prompt_suffix = code[code.index(text) + len(text) :]
        prompt_suffix = prompt_suffix.replace(TRIPLE_QUOTE, "")
        prompt_suffix = prompt_suffix.replace(SINGLE_TRIPLE_QUOTE, "")
        prompt_prefix = prompt_prefix.strip().removesuffix(TRIPLE_QUOTE)
        prompt_prefix = prompt_prefix.strip().removesuffix(SINGLE_TRIPLE_QUOTE)
        prompt = prompt_prefix + prompt_suffix + '\n"""Explanation of the code above:\n'
        return prefix + prompt

    if lang == "ruby":
        return prefix + code + "\n=begin Explanation of the code above:\n"

    return prefix + code + "\n/* Explanation of the code above:\n"


# source: InCoder evaluation code https://github.com/dpfried/lm-evaluation-harness/
def standardize_docstring_prompt(prefix):
    """Cut ``prefix`` at any existing docstring delimiter and terminate it
    with a canonical triple double-quote.

    Handles the codexglue edge case where the docstring text itself
    inconsistently contains delimiters.
    """
    double_delim = '"""'
    single_delim = "'''"

    # Truncate at the first triple-quote delimiter, if one is present.
    for delim in (double_delim, single_delim):
        cut = prefix.find(delim)
        if cut != -1:
            prefix = prefix[:cut]
            break

    # Drop a dangling single ' or " (plus trailing whitespace) left at the
    # very end of the prefix.
    for pattern in (r'[^\'"][\']\s*$', r'[^\'"]["]\s*$'):
        match = re.search(pattern, prefix)
        if match:
            prefix = prefix[: match.start()]

    return prefix + double_delim


def two_shot_prompt(entry, text, examples):
    """Assemble a two-shot prompt: ``entry`` header, two worked
    instruction/solution pairs, then the new instruction awaiting its
    solution."""
    shots = []
    for idx in ("1", "2"):
        shots.append("\nInstruction:\n" + examples["instruction" + idx])
        shots.append("\nSolution:\n" + examples["solution" + idx])
    return entry + "".join(shots) + "\nInstruction:\n" + text + "\nSolution:\n"


def conala_prompt(sample, prefix=""):
    """Generate a 2-shot prompt for the CoNaLa text-to-code task.

    Prefers the curated "rewritten_intent" field, falling back to the raw
    "intent". ``prefix`` is prepended exactly once to the final prompt
    (the previous version applied it both to the intent text and to the
    return value, duplicating it).
    """
    with open("lm_eval/few_shot_examples/conala_few_shot_prompts.json", "r") as file:
        examples = json.load(file)
    text_column = "rewritten_intent" if sample["rewritten_intent"] else "intent"
    text = sample[text_column].strip()
    entry = "Answer the following instructions in one line of Python code:\n"
    prompt = two_shot_prompt(entry, text, examples)
    return prefix + prompt


def spider_prompt(sample, prefix=""):
    """Generate a 2-shot prompt for the Spider text-to-SQL task.

    ``prefix`` is prepended exactly once to the final prompt (the previous
    version applied it both to the question text and to the return value,
    duplicating it).
    """
    with open("lm_eval/few_shot_examples/spider_few_shot_prompts.json", "r") as file:
        examples = json.load(file)
    text = sample["question"].strip()
    entry = "Answer the following instructions in a one line SQL query:\n"
    prompt = two_shot_prompt(entry, text, examples)
    return prefix + prompt


def concode_prompt(sample, prefix=""):
    """Generate a 2-shot prompt for the CONCODE text-to-Java task.

    The natural-language field is cut at the "concode_field_sep" marker and
    any trailing period removed. ``prefix`` is prepended exactly once to the
    final prompt (the previous version applied it both to the text and to
    the return value, duplicating it). The docstring previously misnamed
    this task "Spider".
    """
    with open("lm_eval/few_shot_examples/concode_few_shot_prompts.json", "r") as file:
        examples = json.load(file)
    text = sample["nl"].split("concode_field_sep")[0].strip()
    if text.endswith("."):
        text = text[:-1].strip()
    entry = "Answer the following instructions in a one line of Java code:\n"
    prompt = two_shot_prompt(entry, text, examples)
    return prefix + prompt

0 comments on commit 42f8ef6

Please sign in to comment.