-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathprepare_prompts.py
122 lines (102 loc) · 3.38 KB
/
prepare_prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import ast
import json
import random
import yaml
from muchomusic.utils import get_all_audio_paths, load_questions_from_csv
def _shuffle_answers(questions, answers):
    """Shuffle the order of the answer options of every question.

    Args:
        questions: Sequence of questions. Only its length is used: one
            permutation is drawn for each of the first ``len(questions)``
            rows of ``answers``.
        answers: Sequence of per-question option lists, aligned with
            ``questions``.

    Returns:
        A tuple ``(shuffled_answers, answer_orders)`` where
        ``answer_orders[i]`` is the permutation applied to question ``i``,
        i.e. ``shuffled_answers[i][k] == answers[i][answer_orders[i][k]]``.
    """
    answer_orders = []
    shuffled_answers = []
    for i in range(len(questions)):
        options = answers[i]
        # Size the permutation from this row's own option list, not from
        # answers[0]: otherwise longer rows lose options silently and shorter
        # rows raise IndexError. Only random.sample consumes randomness, so
        # rows of equal length give the same orders as a two-pass version.
        order = random.sample(range(len(options)), k=len(options))
        answer_orders.append(order)
        shuffled_answers.append([options[j] for j in order])
    return shuffled_answers, answer_orders
def get_prompts(questions, answers, in_context_expamples: list[str]):
    """Format every question and its options as a multiple-choice prompt.

    Each prompt has the shape::

        [<examples joined by " ">\\n ]Question: <q>\\n Options: (A) <o1> (B) <o2> ... \\n The correct answer is: 

    Args:
        questions: Sequence of question strings.
        answers: Sequence of per-question option lists, aligned with
            ``questions``. Options are labelled A, B, C, ... in list order.
        in_context_expamples: Optional list of example strings. When
            non-empty, they are joined with single spaces and prepended to
            every prompt. ``None`` or an empty list disables the prefix.
            (The parameter's spelling is kept for keyword callers.)

    Returns:
        A list with one prompt string per question.

    Raises:
        ValueError: If a question has more options than there are letters
            (26) to label them.
    """
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    # The examples prefix is identical for every question, so build it once.
    if in_context_expamples:
        prefix = " ".join(in_context_expamples) + "\n Question: "
    else:
        prefix = "Question: "

    questions_with_options = []
    for i, question in enumerate(questions):
        options = answers[i]
        if len(options) > len(letters):
            raise ValueError(
                f"Question {i} has {len(options)} options; at most "
                f"{len(letters)} can be labelled with letters."
            )
        option_text = "".join(
            f"({letters[j]}) {option} " for j, option in enumerate(options)
        )
        questions_with_options.append(
            prefix
            + question
            + "\n Options: "
            + option_text
            + "\n The correct answer is: "
        )
    return questions_with_options
def prepare_questions(
    questions_path,
    in_context_examples=None,
    option_subset=None,
    distractors=None,
):
    """Load questions from a CSV file and build shuffled multiple-choice prompts.

    Args:
        questions_path: Path to the questions CSV, forwarded to
            ``load_questions_from_csv``.
        in_context_examples: Optional list of example strings prepended to
            every prompt (see ``get_prompts``).
        option_subset: Optional iterable of option indices to keep for every
            question. It must contain ``0``, presumably the index of the
            correct answer as returned by ``load_questions_from_csv`` (TODO
            confirm). After subsetting, the entries of ``answer_orders``
            refer to positions within the subset.
        distractors: Distractor column names forwarded to
            ``load_questions_from_csv``. Defaults to
            ``["incorrect_but_related", "correct_but_unrelated",
            "incorrect_and_unrelated"]``.

    Returns:
        A dict of parallel lists with keys ``"id"``, ``"prompt"``,
        ``"answers"``, ``"answer_orders"``, ``"dataset"``, ``"genre"``,
        ``"reasoning"`` and ``"knowledge"``.

    Raises:
        ValueError: If ``option_subset`` is given and does not contain 0.
    """
    # Use a None sentinel rather than a shared mutable default list.
    if distractors is None:
        distractors = [
            "incorrect_but_related",
            "correct_but_unrelated",
            "incorrect_and_unrelated",
        ]
    (
        ids,
        questions,
        answers,
        datasets,
        genres,
        reasoning,
        knowledge,
    ) = load_questions_from_csv(questions_path, distractors=distractors)
    if option_subset is not None:
        # Materialize first, so a generator is not consumed by the check below.
        option_subset = list(option_subset)
        if 0 not in option_subset:
            raise ValueError("option_subset must include option index 0")
        # Keep the selected option columns of every question. The previous
        # code indexed the question rows themselves.
        answers = [[options[i] for i in option_subset] for options in answers]
    answers, answer_orders = _shuffle_answers(questions, answers)
    prompts = get_prompts(questions, answers, in_context_examples)
    # The reasoning/knowledge cells hold Python-literal strings;
    # literal_eval parses them without executing arbitrary code.
    reasoning = [ast.literal_eval(item) for item in reasoning]
    knowledge = [ast.literal_eval(item) for item in knowledge]
    return {
        "id": ids,
        "prompt": prompts,
        "answers": answers,
        "answer_orders": answer_orders,
        "dataset": datasets,
        "genre": genres,
        "reasoning": reasoning,
        "knowledge": knowledge,
    }
if __name__ == "__main__":
    # Script entry point: build one JSON record per question and write the
    # list of records to --output_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="default")
    parser.add_argument("--output_path", type=str, default="example_file.json")
    parser.add_argument(
        "--questions_path",
        type=str,
        default="data/muchomusic.csv",
        help="CSV file with the questions.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for the answer-option shuffle; omit for a random order.",
    )
    args = parser.parse_args()

    # Seeding makes the shuffled answer orders (and so the prompts) reproducible.
    if args.seed is not None:
        random.seed(args.seed)

    # Explicit encoding, so reading the config does not depend on the
    # platform's default locale.
    with open(
        f"muchomusic/configs/{args.config}.yaml",
        "r",
        encoding="utf-8",
    ) as file:
        exp_config = yaml.safe_load(file)

    question_dict = prepare_questions(
        questions_path=args.questions_path,
        in_context_examples=exp_config["in_context_examples"],
        distractors=exp_config["distractors"],
    )
    audio_file_paths = get_all_audio_paths(
        question_dict["id"],
        question_dict["dataset"],
    )
    question_dict["audio_path"] = audio_file_paths
    # NOTE(review): "A" looks like a placeholder model output for every
    # record — confirm how downstream consumers use it.
    question_dict["model_output"] = ["A"] * len(audio_file_paths)

    # Transpose the dict of parallel columns into a list of per-question
    # records. zip stops at the shortest column, so all columns are
    # expected to have the same length.
    inputs_dict = [
        dict(zip(question_dict, row)) for row in zip(*question_dict.values())
    ]
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(inputs_dict, f)