# remix_d5_part1_solution.py
# %%
"""
# REMIX Day 5, Part 1 - IOIDataset
There is a lot of bookkeeping involved in running the IOI experiments, which we've provided for you. The class `IOIDataset` handles all the dataset-related computation, including tokenization and computation of the relevant indices.
<!-- toc -->
## Learning Objectives
After going through this material, you should be able to:
- Use the IOIDataset API
- Understand the principles behind creating similar datasets for your future project
## Readings
* The [slides from the lecture](https://docs.google.com/presentation/d/13Bvmo8E6N5qhgj1yCXq5O7zNRzNNXZLzexlgdzdgZ_E/edit?usp=sharing)
* [A guide to language model interpretability](https://docs.google.com/document/d/1cSdLwC9mVaLxMDKaXbOsxrglwATOjc0NfMuUvxLNnNE/edit?usp=sharing) (most of the content covered in the lecture, here to refer to it if needed)
"""
# %%
# get_ipython().run_line_magic("load_ext", "autoreload")
# get_ipython().run_line_magic("autoreload", "2")
from copy import deepcopy
import torch
from remix_d5_utils import IOIDataset
MAIN = __name__ == "__main__"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Device: ", DEVICE)
# %%
"""
## Constructor
The simplest way to construct an `IOIDataset` involves two arguments: the number of prompts to generate, and a string representing the type of prompt.
Exercise: Try out the different types of prompt. You should be able to Ctrl+Click `IOIDataset` below and see the legal string values. Alternatively, if you pass an illegal prompt type, your IDE should complain and show you a message with the legal string values.
Exercise: Is `IOIDataset` deterministic, or are the prompts random each time?
"""
# %%
ioi_dataset = IOIDataset(3, prompt_type="mixed", seed=78, device=DEVICE)
print("Prompts: ", ioi_dataset.prompts_text)
print("Tokens: ", ioi_dataset.prompts_toks)
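# %%
"""
The determinism exercise can be explored with a toy sketch. This is not the `IOIDataset` implementation (the name pool and template below are made up); it only illustrates the principle suggested by the `seed=78` argument above: generation driven by a seeded local RNG is reproducible.
"""

```python
import random

def make_toy_prompts(n: int, seed: int) -> list:
    # A local RNG seeded once: the same seed always yields the same prompts.
    rng = random.Random(seed)
    names = ["Alice", "Bob", "Claire", "David", "Emma", "Frank", "Grace", "Henry"]
    prompts = []
    for _ in range(n):
        io, s = rng.sample(names, 2)
        prompts.append(f"Then, {io} and {s} went to the store. {s} gave a drink to")
    return prompts

# Same seed -> identical prompts; different seed -> (almost surely) different prompts.
assert make_toy_prompts(3, seed=78) == make_toy_prompts(3, seed=78)
assert make_toy_prompts(3, seed=78) != make_toy_prompts(3, seed=79)
```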
# %%
"""
## Tokenization
Notice that the sentence begins with `<|endoftext|>`. The tokenizer recognizes this special string and replaces it with the specific token 50256. During the training of GPT2, we believe there were no padding tokens, and two different articles could appear within one training example, separated by this token.
So the idea behind putting this at the start is to mimic GPT2's training, where the text after this token begins a new article.
Omitting the `<|endoftext|>` token leads to a big difference when the first token of the sentence is a name (e.g. "Alice and Bob ..." vs "<|endoftext|> Alice and Bob"). Results are similar when the first token is not important (e.g. "Then, Alice and Bob ..." vs "<|endoftext|>Then, Alice and Bob").
We use words for names, places and objects that are single tokens. This makes it easier to study the model: a single token position contains all the information about a given name, instead of it being split across two positions, for instance.
## Word_idx
This variable of type `Dict[str, torch.Tensor]` is a dictionary that maps the name of a word to its index in each of the prompts. For example, `word_idx["IO"]` gives you a tensor of ints of shape `(NB_PROMPTS,)`. Each entry is the index of the IO token in the corresponding prompt. The possible keys are "IO", "S1", "S2", "S1+1" and "END" (the index of the last token, " to").
"""
# %%
print(ioi_dataset.prompts_text[0])
for k, v in ioi_dataset.word_idx.items():
    print(f" The token {k} is at position {v[0]}")
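# %%
"""
A sketch of how indices like `word_idx` can be computed. This is not the real implementation (the toy prompts below are made up, and the real class stores the indices as `torch.Tensor`s rather than lists); it shows why single-token names help: each role maps to exactly one position per prompt.
"""

```python
# Tokenized prompts with names replaced by their semantic annotations,
# in the spirit of prompts_text_toks (toy data, made up for illustration).
toy_prompts_text_toks = [
    ["When", " IO", " and", " S1", " went", " to", " the", " store", ",", " S2", " gave", " a", " drink", " to"],
    ["Then", ",", " S1", " and", " IO", " had", " fun", ".", " S2", " gave", " a", " ring", " to"],
]

# One list per role: entry i is the role's token position in prompt i.
toy_word_idx = {
    role: [toks.index(" " + role) for toks in toy_prompts_text_toks]
    for role in ["IO", "S1", "S2"]
}
toy_word_idx["END"] = [len(toks) - 1 for toks in toy_prompts_text_toks]

print(toy_word_idx["IO"])   # [1, 4]
print(toy_word_idx["END"])  # [13, 12]
```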
# %%
"""
To check that the positions are correct, we can use `prompts_text_toks`, which stores each tokenized prompt as a list of strings. The names are replaced by their semantic annotations (IO, S1 or S2). This is also helpful for seeing how sentences are tokenized.
TBD: this should probably be a docstring on the actual class's field
"""
# %%
ioi_dataset.prompts_text_toks[0]
# %%
"""
`prompts_metadata`: a list of dictionaries containing metadata about the prompts. Each includes the strings filling the placeholders, the id of the template used, and the order of the names.
"""
# %%
print(f"Some metadata about '{ioi_dataset.prompts_text[0]}'")
print()
print(ioi_dataset.prompts_metadata[0])
# %%
"""
## Metadata
You can use metadata to create copies (or modifications) of the dataset. In this case, there is no randomness involved: all the information needed to create the dataset is contained in the metadata.
"""
# %%
new_metadata = deepcopy(ioi_dataset.prompts_metadata)
new_metadata[0]["S"] = "Robert"
new_ioi_dataset = IOIDataset(
    N=ioi_dataset.N, prompt_type=ioi_dataset.prompt_type, manual_metadata=new_metadata, device=DEVICE
)
print(f"Original prompt: {ioi_dataset.prompts_text[0]}")
print(f"New prompt: {new_ioi_dataset.prompts_text[0]}")
# %% [markdown]
"""
## Flips
By "flip", we mean replacing part of a prompt with a new random name. For instance:
"""
# %%
flipped_io_dataset = ioi_dataset.gen_flipped_prompts("IO")
print("Original: ", ioi_dataset.prompts_text[0])
print("Flipped: ", flipped_io_dataset.prompts_text[0])
# %%
"""
We can also flip the S1 token to a random name. By doing so, we change the name family! The new dataset is part of the ABC family, not IOI, as the sentences now contain three distinct names.
This means:
- The `word_idx` has different keys. "IO1" is the old "IO" and "IO2" is the newly created IO taking the place of "S1". "S" is the old "S2". In particular, the number of an IO doesn't refer to its position in the sequence: IO2 can appear before IO1.
- `prompts_metadata` now contains the extra key "IO2" holding the value of the newly created IO, and "IO" is removed.
Most of the keys differ between the ABC and IOI families, to force you to know what you are looking at. After composing several flips, it's easy to forget whether you are looking at an ABC or an IOI dataset.
"""
# %%
flipped_s1_dataset = ioi_dataset.gen_flipped_prompts("S1")
print(f"Original prompt: {ioi_dataset.prompts_text[0]}")
print(f"New prompt: {flipped_s1_dataset.prompts_text[0]}")
assert flipped_s1_dataset.prompt_family == "ABC"
print(flipped_s1_dataset.word_idx)
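# %%
"""
The key renaming can be sketched on a toy metadata dictionary. The schema below is simplified and the name pool is made up (the real `prompts_metadata` holds more fields, such as the template id), but it captures the IOI-to-ABC transition described above.
"""

```python
import random

def toy_flip_s1(meta: dict, name_pool=("Claire", "David", "Emma")) -> dict:
    # Replace S1 with a fresh name not already in the prompt, and rename
    # the keys to the ABC schema so the two families cannot be confused.
    rng = random.Random(0)  # seeded for reproducibility in this sketch
    new_name = rng.choice([n for n in name_pool if n not in meta.values()])
    return {
        "IO1": meta["IO"],  # old IO
        "IO2": new_name,    # new name in S1's position
        "S": meta["S"],     # the repeated name (old S2's occurrence)
    }

ioi_meta = {"IO": "Alice", "S": "Bob"}
abc_meta = toy_flip_s1(ioi_meta)
print(abc_meta)  # "IO" key is gone; "IO1", "IO2" and "S" remain
```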
# %%
"""
## Composing Flips
Naturally, you can do multiple flips and eventually reach a dataset that has nothing in common with the original.
"""
# %%
two_flip = flipped_s1_dataset.gen_flipped_prompts("IO1")
print(f"Original prompt: {ioi_dataset.prompts_text[0]}")
print(f"After flipping S1 then IO1: {two_flip.prompts_text[0]}")
# %%
three_flip = two_flip.gen_flipped_prompts("S")
print(f"Original prompt: {ioi_dataset.prompts_text[0]}")
print(f"After flipping S1, IO1 and S: {three_flip.prompts_text[0]}")
# %%
"""
## Generating Your Own Flip
We want to modify the sentences so that the first two names appearing in each sentence are swapped. For instance, "Alice and Bob ..." becomes "Bob and Alice ...".
Exercise: implement `order_flip` for the case where the dataset is from the ABC family. (The IOI case is a bit tricky and involves template manipulation.)
<details>
<summary>Click here for a hint</summary>
You can flip the values of IO1 and IO2 in the metadata and create a new dataset using manual_metadata.
</details>
"""
# %%
def order_flip(dataset: IOIDataset) -> IOIDataset:
    """
    - For a dataset from the ABC family, generate a new dataset where the first two names appear in flipped order. "Alice and Bob ..." becomes "Bob and Alice ...".
    """
    assert dataset.prompt_family == "ABC"
    new_prompts_metadata = deepcopy(dataset.prompts_metadata)
    "SOLUTION"
    for prompt in new_prompts_metadata:
        prompt["IO1"], prompt["IO2"] = prompt["IO2"], prompt["IO1"]
    return IOIDataset(
        N=dataset.N,
        prompt_type=dataset.prompt_type,
        manual_metadata=new_prompts_metadata,
        prompt_family=dataset.prompt_family,
        device=DEVICE,
    )
original = IOIDataset(3, prompt_type="mixed", device=DEVICE)
original = original.gen_flipped_prompts(
    "S2"
)  # An arbitrary flip to make it part of the ABC family. When created, `original` is from the IOI family.
# %% running the function
flipped_order = order_flip(original)
# %% test of the correctness
assert flipped_order.prompt_family == "ABC"
for i in range(len(flipped_order)):
    assert flipped_order.prompts_metadata[i]["IO2"] == original.prompts_metadata[i]["IO1"]
    assert flipped_order.prompts_metadata[i]["IO1"] == original.prompts_metadata[i]["IO2"]
    assert flipped_order.prompts_metadata[i]["S"] == original.prompts_metadata[i]["S"]
# %%
"""
Exercise: why do the following tests fail?
<details>
<summary>Answer</summary>
The _values_ of IO1 and IO2 were changed, but the positions of the words labeled IO1 and IO2 are the same. The `word_idx` is a mapping from a token labeled by its role (IO1, S, IO2, etc.) to its position, so `word_idx` is not changed.
</details>
"""
print(flipped_order.word_idx["IO1"], original.word_idx["IO2"])
print("Test 1: ", (flipped_order.word_idx["IO1"] == original.word_idx["IO2"]).all())
print("Test 2: ", (flipped_order.word_idx["IO2"] == original.word_idx["IO1"]).all())
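# %%
"""
A toy illustration of the answer (template and metadata made up for this sketch): positions are determined by the template alone, so swapping which name fills IO1 and IO2 changes the strings but not the positions.
"""

```python
# A toy template: the token positions of the placeholders never depend on
# which (single-token) names fill them.
template = ["Then", ",", "{IO1}", "and", "{IO2}", "went", "to", "the", "store"]

def placeholder_positions(template: list) -> dict:
    # Map each placeholder's role name to its token position in the template.
    return {tok.strip("{}"): i for i, tok in enumerate(template) if tok.startswith("{")}

def fill(template: list, meta: dict) -> list:
    # Substitute names for placeholders; other tokens pass through unchanged.
    return [meta.get(tok.strip("{}"), tok) for tok in template]

meta_original = {"IO1": "Alice", "IO2": "Bob"}
meta_flipped = {"IO1": "Bob", "IO2": "Alice"}  # order_flip swaps the values only

# The filled prompts differ...
assert fill(template, meta_original) != fill(template, meta_flipped)

# ...but the placeholder positions are identical for both datasets.
assert placeholder_positions(template) == {"IO1": 2, "IO2": 4}
```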
# %%