# %%
"""
# REMIX Day 5, Part 2 - Model Loading
In this notebook, we'll go step by step through the process of loading GPT2-small. We'll apply various modifications to make it easier to write our experiments.
<!-- toc -->
## Learning Objectives
After today's material, you should be able to:
- Rewrite and rename the model to make it easier to write experiments
## Readings
None
"""
# %%
import os
import sys

import rust_circuit as rc
import torch
from rust_circuit.model_rewrites import To, configure_transformer
from rust_circuit.py_utils import I

import remix_utils
from remix_d5_utils import IOIDataset, print_max_min_by_tok_k_torch

MAIN = __name__ == "__main__"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Device: ", DEVICE)
if "SKIP":
# Skip CI for now - avoids downloading GPT2
IS_CI = os.getenv("IS_CI")
if IS_CI:
sys.exit(0)
# %%
"""
## Model Loading
We want our circuit to be split both by attention head and by token position, and as a metric we are only interested in the logit difference between IO and S. The basic plan is:
- Split by heads using `configure_transformer`
- Split by token position using `split_to_concat`
We can't split by token position until we know how many token positions there are. We'll make an `IOIDataset` for demo purposes that has a specific sequence length.
You've probably loaded this model in Day 2, but if not, this will take some time to download the weights from RRFS.
The `t.bind_w` circuit is short for "transformer bind weights" - it's a `Module` with the pretrained weights included, but no token or positional embeddings yet.
"""
# %%
ioi_dataset = IOIDataset(prompt_type="mixed", N=100, device=DEVICE)
MAX_LEN = ioi_dataset.prompts_toks.shape[1] # maximal length
circ_dict, tokenizer, model_info = remix_utils.load_gpt2_small_circuit()
unbound_circuit = circ_dict["t.bind_w"]
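# %%
"""
If you're curious what else ships with the model, you can list the keys of `circ_dict` (the exact set of names depends on the loader, but it includes the weight and embedding circuits used below):
"""
# %%
print(sorted(circ_dict.keys()))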
# %%
"""
### Token Embeddings
To get our token embeddings, we'll use a placeholder for the tokens themselves of the appropriate length. Recall that we don't include a batch dimension.
`bind_to_input` computes two Circuits: the attention mask and an `Add` of the positional and token embeddings. This is a plain Python function so you can look into it if you like.
"""
# %%
tokens_sym = rc.Symbol.new_with_random_uuid((MAX_LEN,), name="tokens")
token_embeds = rc.GeneralFunction.gen_index(circ_dict["t.w.tok_embeds"], tokens_sym, 0, name="tok_embeds")
bound_circuit = model_info.bind_to_input(unbound_circuit, token_embeds, circ_dict["t.w.pos_embeds"])
print(bound_circuit)
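# %%
"""
If `gen_index` feels opaque, here is the plain-PyTorch analogue of the indexing above (a throwaway sketch using GPT2-small's real vocab/hidden sizes, not part of the circuit): indexing an embedding matrix by token ids along dim 0 is just fancy indexing.
"""
# %%
_demo_embeds = torch.randn(50257, 768)  # (vocab_size, hidden_size), like t.w.tok_embeds
_demo_toks = torch.randint(0, 50257, (MAX_LEN,))
print(_demo_embeds[_demo_toks].shape)  # (MAX_LEN, 768): one embedding row per token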
# %%
"""
## configure_transformer
Here, `split_by_head_config="full"` means that each head will have its own set of parameters.
`use_flatten_res=True` means that the input to each module is the sum of all previous modules. Instead of having nested `Add`s for each block, we have one `Add` with a long list of all the components (MLPs and attention layers) coming before.
Exercise: in `transformed_circuit`, examine the first block's attention layer 'b0.a' and make sure you understand how the heads are split.
Exercise: in `transformed_circuit`, examine the third block `b2` and make sure you understand how the flattening works. Compare to the previous structure in `bound_circuit`.
"""
# %%
transformed_circuit = bound_circuit.update(
    "t.bind_w",
    lambda c: configure_transformer(
        c,
        To.ATTN_HEAD_MLP_NORM,
        split_by_head_config="full",
        use_pull_up_head_split=True,
        use_flatten_res=True,
    ),
)
if "SOLUTION":
print("b0.a:\n", transformed_circuit.get_unique("b0.a"))
print("\nb2: ")
rc.PrintOptions.print_depth(transformed_circuit.get_unique("b2"), end_depth=6)
# %%
"""
## conform_all_modules
Our network still has some symbolic sizes in it, but we're now using concrete inputs with known sizes, so there's no need for symbolic sizes anymore. In fact, we won't be able to split by position if the sequence dimension is still symbolic.
Calling `rc.conform_all_modules` walks the tree and for each `Module`, replaces symbolic sizes with known ones wherever possible.
"""
# %%
print("Before conforming: ", transformed_circuit.get_unique("b0.a"))
transformed_circuit = rc.conform_all_modules(transformed_circuit)
print("After conforming: ", transformed_circuit.get_unique("b0.a"))
# %%
"""
## t.call substitution
The outer `t.call` `Module`'s only purpose is to have placeholders for the attention mask and the input embeddings. We don't need it anymore, so we can substitute it out to make our circuit a bit more readable.
"""
# %%
subbed_circuit = transformed_circuit.cast_module().substitute()
subbed_circuit = subbed_circuit.rename("logits")
subbed_circuit.print_html()
# %%
"""
## Removing all modules but layer norm
`Module`s are helpful when building the model to avoid copy-pasting code, but they are not very helpful when we want to specify precise paths through the model: we cannot chain through `Symbol` instances!
Exercise: make a new `subbed_circuit` where you substitute away all `Module`s except for layer norms.
"""
# %%
# Note: a single `substitute` call is not enough. Substituting a `Module` can
# expose new `Module` nodes that were nested inside it, so we repeat the update
# until we reach a fixed point (no more modules change).
if "SOLUTION":

    def module_but_norm(circuit: rc.Circuit) -> bool:
        """Match all Module nodes except layer norms (and the final norm/call)"""
        return isinstance(circuit, rc.Module) and not (
            "norm" in circuit.name or "ln" in circuit.name or "final" in circuit.name
        )

    while True:
        print("Modules remaining: ", len(subbed_circuit.get(rc.Module)))
        prev = subbed_circuit
        subbed_circuit = subbed_circuit.update(module_but_norm, lambda c: c.cast_module().substitute())
        if prev == subbed_circuit:
            break

else:
    """TODO: update subbed_circuit so the test passes"""
expected = [
    "a0.norm",
    "a1.norm",
    "a10.norm",
    "a11.norm",
    "a2.norm",
    "a3.norm",
    "a4.norm",
    "a5.norm",
    "a6.norm",
    "a7.norm",
    "a8.norm",
    "a9.norm",
    "final.call",
    "final.norm",
    "m0.norm",
    "m1.norm",
    "m10.norm",
    "m11.norm",
    "m2.norm",
    "m3.norm",
    "m4.norm",
    "m5.norm",
    "m6.norm",
    "m7.norm",
    "m8.norm",
    "m9.norm",
]
actual = sorted([c.name for c in subbed_circuit.get(rc.Module)])
assert actual == expected
# %%
"""
## Renaming of blocks
We will use these names a lot in our future experiments. To make them easier to work with, we shorten them to remove redundant information, such as the 'b' for 'block': we reason about attention heads and MLPs, not about blocks as a whole.
"""
# %%
renamed_circuit = subbed_circuit.update(rc.Regex(r"[am]\d+(\.h\d+)?$"), lambda c: c.rename(c.name + ".inner"))
renamed_circuit = renamed_circuit.update("t.inp_tok_pos", lambda c: c.rename("embeds"))
for l in range(model_info.params.num_layers):
    """b0 -> a1.input, ... b11 -> final.input"""
    next_name = "final" if l == model_info.params.num_layers - 1 else f"a{l+1}"
    renamed_circuit = renamed_circuit.update(f"b{l}", lambda c: c.rename(f"{next_name}.input"))
    """b0.m -> m0, etc."""
    renamed_circuit = renamed_circuit.update(f"b{l}.m", lambda c: c.rename(f"m{l}"))
    renamed_circuit = renamed_circuit.update(f"b{l}.m.p_bias", lambda c: c.rename(f"m{l}.p_bias"))
    renamed_circuit = renamed_circuit.update(f"b{l}.a", lambda c: c.rename(f"a{l}"))
    renamed_circuit = renamed_circuit.update(f"b{l}.a.p_bias", lambda c: c.rename(f"a{l}.p_bias"))
    # GPT2-small has as many heads per layer (12) as it has layers, so reusing num_layers works here
    for h in range(model_info.params.num_layers):
        """b0.a.h0 -> a0.h0, etc."""
        renamed_circuit = renamed_circuit.update(f"b{l}.a.h{h}", lambda c: c.rename(f"a{l}.h{h}"))
renamed_circuit.print_html()
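# %%
"""
An optional sanity cell: check that a few of the new names resolve (`get_unique` raises if a name is missing or ambiguous).
"""
# %%
for name in ["embeds", "a0.h0", "m11", "final.input"]:
    renamed_circuit.get_unique(name)
print("All renamed nodes found.")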
# %%
"""
## Split By Position
We'll want the ability to target specific sequence positions for interventions such as "the output of head 0.0 at sequence position 5". To do this, we'll create intermediate `Index` nodes like "a0.h0_at_idx_5", and then concatenate these back together to get an "a0.h0_by_pos" node that computes the same thing as the original "a0.h0".
Exercise: implement `split_to_concat_axis_0`.
"""
# %%
def split_to_concat_axis_0(c: rc.Circuit) -> rc.Concat:
    """Turns `c` into `Concat(c[0:1], c[1:2], ...)`.
    Each index should be named {c.name}_at_idx_{i}.
    The output name should be {c.name}_by_pos.
    Simplified version of rc.split_to_concat.
    """
    "SOLUTION"
    n = c.shape[0]
    inps = [rc.Index(c, I[i : i + 1], name=f"{c.name}_at_idx_{i}") for i in range(n)]
    return rc.Concat(*inps, axis=0, name=f"{c.name}_by_pos")
# matches a#.h# and m#
head_and_mlp_matcher = rc.IterativeMatcher(rc.Regex(r"^(a\d+\.h\d+|m\d+)$"))
split_circuit = renamed_circuit.update(head_and_mlp_matcher, split_to_concat_axis_0)
a0h0 = split_circuit.get_unique("a0.h0_by_pos")
for i in range(MAX_LEN):
    idx = a0h0.children[i]
    assert isinstance(idx, rc.Index)
    assert idx.name == f"a0.h0_at_idx_{i}"
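# %%
"""
To see the shape bookkeeping concretely, here is the same split applied to a tiny standalone `Array` (a throwaway example, not used later). Concatenating the length-1 slices along axis 0 reproduces the original value exactly:
"""
# %%
_demo_arr = rc.Array(torch.arange(6.0).reshape(3, 2), name="demo")
_demo_split = split_to_concat_axis_0(_demo_arr)
assert torch.equal(_demo_split.evaluate(), _demo_arr.evaluate())
print(_demo_split.get_unique("demo_at_idx_1").evaluate())  # the middle row, shape (1, 2)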
# %%
"""
## More Renames
Again, we rename nodes to make the circuit easier to read. We use a trick to make renaming fast: we build a dictionary from old names to new names, and then use a single `update` call to rename all the nodes at once.
"""
# %%
new_names_dict = {}
for l in range(model_info.params.num_layers):
    for i in range(MAX_LEN):
        # again, heads per layer equals num_layers for GPT2-small
        for h in range(model_info.params.num_layers):
            # a0.h0_at_idx_0 -> a0_h0_t0, etc.
            new_names_dict[f"a{l}.h{h}_at_idx_{i}"] = f"a{l}_h{h}_t{i}"
        # m0_at_idx_0 -> m0_t0, etc.
        new_names_dict[f"m{l}_at_idx_{i}"] = f"m{l}_t{i}"
split_circuit = split_circuit.update(
    rc.Matcher(*list(new_names_dict.keys())), lambda c: c.rename(new_names_dict[c.name])
)
split_circuit.print_html()
# %%
"""
## Running the Circuit!
You may also be wondering about how to actually run your circuit!
Here we are expanding the model to have a batch dimension using `rc.Sampler`.
We also replace the tokens with a `DiscreteVar`, but then tell our `Sampler` to run on all datums in the `DiscreteVar`'s input dataset in order (so there is no actual randomness).
We print the 5 tokens with the highest and lowest logits. Among the top 5, IO appears first, but you can also find S. IO is assigned a much higher probability than S, but the probability of S is still much higher than that of a random name.
"""
# %%
def evaluate_on_dataset(c: rc.Circuit, tokens: torch.Tensor):
    """Run the circuit on all elements of tokens. Assumes the 'tokens' node exists in the circuit."""
    arr = rc.Array(tokens, name="tokens")
    var = rc.DiscreteVar(arr)
    c2 = c.update("tokens", lambda _: var)
    transform = rc.Sampler(rc.RunDiscreteVarAllSpec([var.probs_and_group]))
    return transform.sample(c2).evaluate()
all_logits = evaluate_on_dataset(split_circuit, ioi_dataset.prompts_toks[:5, :])
next_token_logits = all_logits[torch.arange(5), MAX_LEN - 1]
for i in range(5):
    print(f'\n\nExpected completions for prompt: "{ioi_dataset.prompts_text[i]}"')
    print_max_min_by_tok_k_torch(next_token_logits[i], k=5)
# %%
"""
## Logit Differences
We'll now create a new circuit that computes only the logit difference between IO and S (the metric we're interested in). This circuit will also contain the labels of our dataset. We will use it to run path patching experiments where we only change the inputs (and so don't have to deal with the labels after that).
"""
# %%
io_s_labels = torch.cat([ioi_dataset.io_tokenIDs.unsqueeze(1), ioi_dataset.s_tokenIDs.unsqueeze(1)], dim=1)
device_dtype = rc.TorchDeviceDtype(dtype="float32", device="cpu")
tokens_device_dtype = rc.TorchDeviceDtype(device_dtype.device, "int64")
labels = rc.cast_circuit(rc.Array(torch.zeros(2), name="labels"), tokens_device_dtype.op()).cast_array()
labels1 = rc.Index(labels, I[0], name="labels1")
labels2 = rc.Index(labels, I[1], name="labels2")
logit1 = rc.GeneralFunction.gen_index(
    split_circuit.index((-1,)),
    labels1,
    index_dim=0,
    batch_x=True,
    name="logit1",
)
logit2 = rc.GeneralFunction.gen_index(
    split_circuit.index((-1,)),
    labels2,
    index_dim=0,
    batch_x=True,
    name="logit2",
)
logit_diff_circuit = rc.Add.minus(logit1, logit2)
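# %%
"""
As a plain-PyTorch cross-check of what this circuit computes (a sketch, assuming `io_tokenIDs` and `s_tokenIDs` are 1-D tensors of token ids, as they are used above): for each prompt, the value is logits[last position][IO] - logits[last position][S].
"""
# %%
_ntl = next_token_logits.detach().cpu()
_io_ids = ioi_dataset.io_tokenIDs[:5].cpu()
_s_ids = ioi_dataset.s_tokenIDs[:5].cpu()
print(_ntl[torch.arange(5), _io_ids] - _ntl[torch.arange(5), _s_ids])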
# %%
"""
## Labels
Let's add the labels to our circuit. The labels are `DiscreteVar`s inserted into the circuit, and the `group` variable stores the order in which they are sampled. As long as the same `group` is used for both the sentences and the labels, they'll stay in the same order.
"""
# %%
def add_labels_to_circuit(c: rc.Circuit, tokens: torch.Tensor, labels: torch.Tensor):
    """Insert the labels as a DiscreteVar and return the circuit together with the sampling group."""
    assert tokens.ndim == 2 and tokens.shape[1] == MAX_LEN
    batch_size = tokens.shape[0]
    group = rc.DiscreteVar.uniform_probs_and_group(batch_size)
    c = c.update(
        "labels",
        lambda _: rc.DiscreteVar(rc.Array(labels, name="labels"), probs_and_group=group),
    )
    return c, group
logit_diff_circuit, group = add_labels_to_circuit(logit_diff_circuit, ioi_dataset.prompts_toks, io_s_labels)
"""
## Sanity Check
As a sanity check, let's run the logit diff circuit on our dataset
"""
# %%
c = logit_diff_circuit.update(
    "tokens",
    lambda _: rc.DiscreteVar(rc.Array(ioi_dataset.prompts_toks, name="tokens"), probs_and_group=group),
)
transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group]))
results = transform.sample(c).evaluate()
print(f"Logit difference for the first 5 prompts: {results[:5]}")
print(f"Average logit difference: {results.mean()} +/- {results.std()}")
print("Testing that results are in a usual range: ")
assert results.mean() > 2.5 and results.mean() < 4
# %%