-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathremix_d3_solution.py
1447 lines (1162 loc) · 60.4 KB
/
remix_d3_solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# %% [markdown]
"""
# REMIX Day 3 - Replicating Results on Induction Heads
Today you'll be replicating the results on induction heads from our [writeup](https://www.lesswrong.com/posts/j6s9H9SHrEhEfuJnq/causal-scrubbing-results-on-induction-heads). By the end, you'll have a more nuanced understanding of induction and be equipped to formulate and test your own hypotheses!
The second half of this notebook closely follows the writeup, and I recommend having the writeup open to the corresponding section (where applicable) to look at the diagrams.
<!-- toc -->
## Learning Objectives
After today's material, you should be able to:
- Sample from `Circuit`s containing random variables represented by `DiscreteVar`
- Customize `PrintOptions` to control color and expansion of appropriate nodes
- Write scrubbing code manually to test hypotheses
## Readings
- [Induction Head Writeup on Less Wrong](https://www.lesswrong.com/posts/j6s9H9SHrEhEfuJnq/causal-scrubbing-on-induction-heads-part-4-of-5)
## Setup
"""
# %%
import os
import sys
from pprint import pprint
import rust_circuit as rc
import torch
from rust_circuit.model_rewrites import To, configure_transformer
from rust_circuit.module_library import negative_log_likelyhood
from rust_circuit.py_utils import S, I
from torch.testing import assert_close
import remix_d3_test as tests
import remix_utils
# build instructions and commit with something like:
# ~/unity/interp/demos/causal_scrubbing$ python ~/mlab2/build_instructions.py induction.py induction_instructions.md && git commit -am "d3 wip" --no-verify && git push

# True when this file is executed directly rather than imported; display code below is
# wrapped in `if MAIN:` so it only runs in that case.
MAIN = __name__ == "__main__"

# Use the first GPU when available; the CPU fallback works but is slow.
# On my Mac this took 44 minutes to run on CPU
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# `if "SKIP":` is always true at runtime; the string presumably marks this section for
# build_instructions.py (see the command above) — TODO confirm.
if "SKIP":
    # CI takes longer than 20s timeout right now, just skip CI
    IS_CI = os.getenv("IS_CI")
    if IS_CI:
        sys.exit(0)
# %% [markdown]
"""
## The dataset
Before we can actually run the experiments, we have a lot of preparation to get through.
We'll start by loading and examining the dataset, which is text from the validation set of OpenWebText.
### Data loading
As discussed in Day 2, when we parse a string representation of a `Circuit`, `rc.Parser` will automatically download referenced tensors from RRFS. The first run of the below cell might take a few seconds, but later runs should be nearly instant.
"""
# %%
# Number of positions the model is run on. Each dataset row holds seq_len + 1 = 301 tokens, so that
# inputs (all but the last token) and next-token targets (all but the first) both have length seq_len.
seq_len = 300  # longer seq len is better, but short makes stuff a bit easier...
# n_files = 12
# reload_dataset = False
# This doesn't work without Unity but you shouldn't have to do it
# if reload_dataset:
#     from interp.tools.data_loading import get_val_seqs
#     dataset_toks = torch.tensor(get_val_seqs(n_files=n_files, files_start=0, max_size=seq_len + 1), device=DEVICE)
#     n_samples, _ = dataset_toks.shape
#     toks_int_values = rc.Array(dataset_toks.float(), name="toks_int_vals")
#     print(f'new dataset "{toks_int_values.repr()}"')

# Parse the serialized reference to the pre-built [104091, 301] token Array. The parser fetches the
# referenced tensor (identified by the hash at the end), so the first run can take a while.
P = rc.Parser()
toks_int_values = P("'toks_int_vals' [104091,301] Array 3f36c4ca661798003df14994").cast_array()
# %%
"""
### Data Inspection
Machine learning is "garbage in, garbage out" so it's important to understand your data.
Here the data consists of 104091 examples, each consisting of 301 tokens.
To inspect the data we need the `tokenizer` that was used to convert text to tokens. This is typically stored with the model. We'll examine the model in a bit, but we need to load it here to get access to the tokenizer.
Again, this will take some time on the first run.
"""
# %%
# Load the attention-only model: a mapping of named Circuits (indexed below, e.g. loaded["t.bind_w"]),
# the tokenizer used to build the dataset, and extra model config (e.g. extra_args.causal_mask,
# extra_args.pos_enc_type, both checked further down).
loaded, tokenizer, extra_args = remix_utils.load_attention_only_2()
# %% [markdown]
"""
Exercise: convert the first two training examples back to text using [`tokenizer.batch_decode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.batch_decode) and manually inspect the data. Can you spot opportunities for induction or induction-like behaviors to help with predictions?
"""
# %%
if "SOLUTION":
    if MAIN:
        # Decode the first two dataset rows back to text; `.int()` turns the stored values into integer ids.
        text = tokenizer.batch_decode(toks_int_values.value[0:2].int())
# %% [markdown]
"""
### Tokens Where Induction is Likely
Out of around 50,000 tokens in the vocabulary, we're narrowing our investigation to a smaller list of around 10,000 tokens. We tried to choose tokens "A" where "hard" induction ("AB...AB") is particularly helpful. (Hard induction is induction with exactly repeated tokens, as contrasted with "soft" induction which may copy over sentence structure or some other feature of the previous token).
Below, `good_induction_candidate[token_id]` is 1 if the token at that index is part of the short list and 0 otherwise.
In this work, we're not interested in explaining everything the induction heads do, only what they do on these examples (where we can reasonably expect them to be largely doing induction rather than other things).
See the appendix of the writeup for information on how these were chosen.
Exercise: Convert all the tokens in the short list back to text (again using [`tokenizer.batch_decode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.batch_decode)) and print them out to get a sense of what sorts of tokens exist and are likely to benefit from induction.
Exercise: What is the longest token in this set of good induction candidates?
Optional exercise: Find a non-English token. What does it mean? Does it kinda make sense that this token would benefit from induction?
"""
# %%
# 0/1 mask over the vocabulary: good_induction_candidate[token_id] is 1 iff the token is on the short list.
# map_location lets the file load even if it was saved on a device that is unavailable here
# (e.g. saved from CUDA, loaded on a CPU-only machine).
good_induction_candidate = torch.load("remix_d3_data/induction_candidates.pt", map_location=DEVICE).to(
    device=DEVICE, dtype=torch.float32
)
if "SOLUTION":
    if MAIN:
        # Reshaping the candidate ids to (n, 1) makes batch_decode decode each token as its own sequence.
        toks = tokenizer.batch_decode(good_induction_candidate.nonzero().flatten().view(-1, 1))
        # Compare (length, text) tuples: the longest token wins, ties broken by the text itself.
        maxlen_tok = max((len(tok), tok) for tok in toks)
        """機 means machine"""
# %% [markdown]
"""
## The Model
### Working on GPU
`rc.cast_circuit` is a general function to cast an entire `Circuit` to a different device and/or dtype. In the case of arrays like `toks_int_values`, it's just the same as doing `rc.Array(toks_int_values.value.to(DEVICE, dtype=torch.int64))`, but it's good to know this function for more complicated cases.
Note that unlike the PyTorch method `nn.Module.to` which modifies the module in-place, `rc.cast_circuit` returns a new `Circuit` and doesn't modify the input `Circuit`. This aligns with the general philosophy in the circuits library to never modify in-place.
Move the dataset and model to the GPU:
"""
# %%
# Put the dataset on DEVICE as int64 token ids, and move every loaded model circuit to DEVICE.
# cast_circuit returns new circuits; the originals are left unmodified.
toks_int_values = rc.cast_circuit(toks_int_values, rc.TorchDeviceDtypeOp(device=DEVICE, dtype="int64")).cast_array()
loaded = {s: rc.cast_circuit(c, rc.TorchDeviceDtypeOp(device=DEVICE)) for s, c in loaded.items()}
# %% [markdown]
"""
### Examining the model
At this stage, print out a representation of the model using the code below and take a minute to appreciate the mighty two-layer attention only model in all its glory. Other than some chunks being missing, there's only one substantial difference in architecture from the GPT-2-small model you implemented yesterday.
"""
# %%
# TBD: are names finalized here? We should give them a printout of the canonical model
# The model circuit (its inputs are still unbound symbols at this point) and the token / positional
# embedding weights, which are used to build its inputs further down.
orig_circuit = loaded["t.bind_w"]
tok_embeds = loaded["t.w.tok_embeds"]
pos_embeds = loaded["t.w.pos_embeds"]
if MAIN:
    orig_circuit.print_html()
# %%
# TODO say how long should they spend on this?
"""
Exercise: inspect the circuit. When compared to GPT-2, this model is missing blocks (e.g., the MLP layers). In terms of the blocks that it does contain, how do they differ from the corresponding blocks in GPT-2?
We suggest getting practice looking at the printout. Note that in the printout, `orig_circuit.print()` with default options will omit repeated modules, so you only see the attention module expanded under `'b1'` and not under `'b0'`. You can also try `orig_circuit.print_html()`.
<details>
<summary>Can you give me a super specific hint about where to look?</summary>
Focus on the embeddings. Try this: `rc.Matcher("a.qk_input").get_unique(orig_circuit).print()`
</details>
<details>
<summary>Solution</summary>
The way positional embeddings work is different - the "shortformer" encoding is used. It's important to understand that in GPT, the position embedding is added to the token embedding at the very beginning, meaning positional information is mixed in with token information within the residual stream.
In shortformer, instead of multiplying the Q weight with the attention input, we multiply the Q weight with (attention input + positional embedding). The same thing goes for K weight (but not V), and this applies at every layer. For more, see [the Shortformer paper](https://aclanthology.org/2021.acl-long.427.pdf).
</details>
"""
# %% [markdown]
"""
## Random variables and sampling
The circuits that we built yesterday were able to represent deterministic computational graphs. The circuits library also needs ways to deal with sampling from random distributions (and especially sampling over the dataset or some subset of it). In this context, we only need to deal with sampling from finite sets, so we will specialize to discrete random variables.
Here we take a brief look at how the circuits library deals with discrete random variables. In terms of what you need for the research phase, you need familiarity with how things are structured, but most of the code dealing with random variables and sampling will be provided as a library (which we'll see in tomorrow's material).
### Discrete variables
Discrete random variables are created in `rust_circuit` using the `rc.DiscreteVar` class. If you hover over `rc.DiscreteVar`, you'll see that it takes two arguments: a Circuit `values` specifying the values the variable can take, and another Circuit called `probs_and_group`. To specify a random variable we need to specify the values it can take and the probabilities of each of those outcomes. These correspond to the two arguments of `rc.DiscreteVar`, except that `probs_and_group` serves a second purpose, which is to track which random variables should be sampled together, as we'll see below.
Here are some simple examples. Follow along with this code:
"""
# %%
# A six-sided die: outcomes 1..6 stored as floats. With no probs_and_group given, the distribution is uniform.
dice_values = rc.Array(torch.arange(1, 7).to(dtype=torch.float))
dice1 = rc.DiscreteVar(dice_values)
# The default probs_and_group is a Tag wrapping the probabilities; unwrap it to inspect them.
probs = dice1.probs_and_group.cast_tag().node
if MAIN:
    print(f"The values the dice can take are: {dice_values.value}")
    # The argument is called probs_and_group (the old message named it "group_and_probs").
    print("Not specifying probs_and_group gives you the uniform distribution.")
    print(
        "The default probs_and_group object is wrapped in a Tag (to add a UUID) "
        f"but we can unwrap that to see the probabilities, which are {probs}."
    )
# %% [markdown]
"""
We mentioned above that `probs_and_group` served a second purpose, which is to track which random variables should be sampled together. The idea here is that if you set the `probs_and_group` attribute to the same object in two different random variables, those variables will be sampled together, i.e., the samples will be perfectly correlated.
"""
# %%
# dice2 reuses dice1's probs_and_group, so the two are sampled together (perfectly correlated).
# dice3 gets its own fresh default group, so it is sampled independently of dice1.
dice2 = rc.DiscreteVar(dice_values, dice1.probs_and_group)
dice3 = rc.DiscreteVar(dice_values)
if MAIN:
    # Refer to the variables by their actual names (there is no variable called "dice").
    print("dice2 will be perfectly correlated with dice1")
    print(
        "By default, not specifying probs_and_group will give you a new uniform "
        "distribution, so dice1 and dice3 are uncorrelated."
    )
# %% [markdown]
"""
### Sampling discrete variables
To test these claims about correlation, we need to be able to sample these random variables.
To sample a random variable, the `rust_circuit` library uses an `rc.Sampler` object, which needs to be initialized with an `rc.SampleSpec` object.
Today you'll need `rc.RandomSampleSpec`, for when we want to sample randomly, and `rc.RunDiscreteVarAllSpec`, which ignores the probabilities and samples every input (which can be useful when trying to figure out what's going on with a circuit).
Follow along again:
"""
# %%
# Draws 10 random samples of whatever it is asked to sample.
random_sampler = rc.Sampler(rc.RandomSampleSpec((10,)))
# Ignores the probabilities and runs every value of dice1's group once.
all_values_spec = rc.RunDiscreteVarAllSpec.create_full_from_circuits(dice1)
all_values_sampler = rc.Sampler(all_values_spec)
if MAIN:
    # Iterate over the variables directly instead of looking them up by name through locals(),
    # which only works by accident at module scope.
    for i, die in enumerate((dice1, dice2, dice3), start=1):
        print(f"Dice {i}: ", random_sampler.sample(die).evaluate())
    # Every value of dice1 added to itself: 2, 4, ..., 12.
    print("All values: ", all_values_sampler.sample(dice1.add(dice1)).evaluate())
# %% [markdown]
"""
Optional exercise: using these classes, estimate the expectation of (i) the value of `dice1` multiplied by the value of `dice2` and (ii) the value of `dice1` multiplied by the value of `dice3`.
"""
# %%
if "SOLUTION":
    if MAIN:
        # Monte Carlo estimates from 1000 samples (this rebinds the module-level `random_sampler`).
        # Exact values for fair dice: dice2 always equals dice1, so E[dice1 * dice2] = E[X^2] = 91/6 ≈ 15.17;
        # dice3 is independent, so E[dice1 * dice3] = E[X]^2 = 3.5^2 = 12.25.
        random_sampler = rc.Sampler(rc.RandomSampleSpec((1000,)))
        expectation_1_times_2 = random_sampler.sample(dice1.mul(dice2)).mean().evaluate()
        expectation_1_times_3 = random_sampler.sample(dice1.mul(dice3)).mean().evaluate()
# %% [markdown]
"""
### Sampling the input dataset
Let's now apply all this to our dataset.
Run the code below and note that `toks_int_var` has a shape that represents a single sampled datum. It is not, however, explicitly computable. Intuitively, it represents the discrete random variable over the input dataset.
We can, however, sample from it using a `Sampler` object! This will add a batch dimension that we can evaluate it over.
You can also 'sample' every possible value of the discrete variable. The `group` argument tells it which set of random variables to sample (as mentioned above, all variables that share an identical `group` attribute are sampled together).
"""
# Random variable over dataset rows: each outcome is one row of toks_int_values (uniform by default).
toks_int_var = rc.DiscreteVar(toks_int_values, name="toks_int_var")
if MAIN:
    print("Variable")
    print(" Shape: ", toks_int_var.shape)
    print(" Computable: ", toks_int_var.is_explicitly_computable)
# Random sampling adds a sample (batch) dimension of size 200, making the result evaluable.
# simplify=False is workaround for https://github.com/redwoodresearch/unity/pull/1973
sampled_var = rc.Sampler(rc.RandomSampleSpec((200,), simplify=False)).sample(toks_int_var)
if MAIN:
    print("\nRandom samples:")
    print(" Shape: ", sampled_var.shape)
    print(" Computable: ", sampled_var.is_explicitly_computable)
# Running "all" for this group enumerates every dataset row instead of drawing random ones.
group = toks_int_var.probs_and_group
on_all_var = rc.Sampler(rc.RunDiscreteVarAllSpec([group])).sample(toks_int_var)
if MAIN:
    print("\n All samples:")
    print(" Shape: ", on_all_var.shape)
    print(" Computable: ", on_all_var.is_explicitly_computable)
# %%
# TBD later: move to day 1 for next batch?
"""
## Revisiting some `rust_circuit` concepts
### Slicing and indexing
What about if we want to get at individual tokens? Well for that we need `Indexer`s. Let's review.
In Day 1 we introduced the `slice` object as a way to represent indexing into a single dimension. Recall that you can use the slice constructor with two integers like `slice(start, stop)`. What does it do when you pass just one argument?
One might guess that you get `slice(start, None, None)`, but this is wrong - actually it is the same as `slice(None, stop, None)`. This is consistent with the way that the `range()` builtin works, but is still easy to trip over. If you want to avoid remembering this and also save a couple keystrokes, we have a helper object usually imported as `S`. You can use it like so:
"""
# %%
# `S` turns slice syntax into a single `slice` object, so omitted bounds become None as written.
assert S[1:] == slice(1, None, None)
assert S[:1] == slice(None, 1, None)
assert S[2:3:-1] == slice(2, 3, -1)
if MAIN:
    # S accepts exactly one slice; both misuses below are expected to raise. The exception object
    # itself is not needed, so it is not bound.
    try:
        S[2]  # type: ignore
    except Exception:
        print("It is an error to pass a plain int to SLICER.")
    try:
        S[2:3, 4:5]  # type: ignore
    except Exception:
        print("It is an error to pass multiple slices to SLICER.")
# %% [markdown]
"""
Analogously, a helper object usually imported as `I` provides an equivalent and succinct way to generate representations of (possibly multi-dimensional) indexing.
These objects are commonly used with the `rc.Index` class: in the last line of the following block we show how to use an Indexer object with `rc.Index` to specialize `toks_int_var` to the first position.
"""
# %%
# TBD lowpri: this and the slicer could be a short exercise - "write an equivalent"
# `I` turns (possibly multi-dimensional) index syntax into a tuple with one int/slice per indexed dimension.
assert I[5] == (5,)
assert I[5, 6] == (5, 6)
assert I[5:6] == (slice(5, 6, None),)
assert I[5:6, 7:8] == (slice(5, 6, None), slice(7, 8, None))
# The random variable for the token at position 0 of each sampled row.
first_tok_int_var = rc.Index(toks_int_var, I[0], name="first_tok_int_var")
# %% [markdown]
"""
### Modules
We saw Modules yesterday but they can be quite confusing because they are close in concept-space to a bunch of similar things, so it's worth recapping them. Modules are in the category of concepts that once you grok what's going on, it all becomes obvious, so if you notice that you are confused later in the day, please ask a TA to explain.
Yesterday we explained modules in the context of neural networks. Today we'll go a bit more abstract. Let's start with an analogy: When programming in Python, we can do things like add two tensors `t1` and `t2`, for which the syntax in Python is (obviously) `t1 + t2`. You know how to represent the same computation in `rust_circuit` land: we wrap the tensors in `rc.Array` and then combine them using an `rc.Add` circuit, i.e., `rc.Add(rc.Array(t1, name="t1"), rc.Array(t2, name="t2"), name="t1 + t2")`.
When programming in Python, it is also useful to use functions to encapsulate behaviour, e.g., `def plus(a, b): return a + b` (which you can also write in Python as `plus = lambda a, b: a + b`). The `rc.ModuleSpec` class in `rust_circuit` is a pretty direct mapping of the concept of a function in a programming language.
Let's get comfortable with this by creating a `rc.ModuleSpec` for the plus function defined in the previous paragraph.
We first need to make `rc.Symbol`s for the arguments. To keep things super-simple here, we'll make symbols with shape `(2,)`, i.e., they are vectors in 2-dimensions.
"""
# %%
# Argument placeholders for `plus`: two shape-(2,) symbols, each given a fresh random UUID.
a = rc.Symbol.new_with_random_uuid(shape=(2,), name="a")
b = rc.Symbol.new_with_random_uuid(shape=(2,), name="b")
"""
Then we need a `rc.Circuit` to represent the function body. For the `plus` function this is easy:
"""
# Function body written in terms of the argument symbols.
spec_circuit = rc.Add(a, b, name="a + b")
"""
Finally, the `rust_circuit` representation of our plus function `plus = lambda a, b: a + b` is
"""
# The "function": body plus argument specs, in the order a, b.
spec = rc.ModuleSpec(circuit=spec_circuit, arg_specs=[rc.ModuleArgSpec(a), rc.ModuleArgSpec(b)])
"""
Note that we needed to wrap the `Symbol`s in `rc.ModuleArgSpec`. Hover over that in VS Code to see the docstring and you'll see that this lets us specify whether we can run the function on batched inputs (we needed this yesterday), and whether we can pass arguments with different sizes (also needed yesterday).
So that's the abstract function. But when a function appears in a computational tree, it needs specific values for the arguments and that's essentially what an `rc.Module` is for: it's a function, i.e., an `rc.ModuleSpec`, together with a list of `rc.Circuits` that are bound to the arguments.
(Binding variables is a concept from logic. You might want to have a quick look at [wikipedia](https://en.wikipedia.org/wiki/Free_variables_and_bound_variables#Examples) if you haven't seen it before.)
Anyway, to see how this works, let's create some specific vectors and then bind them to the arguments.
"""
# Concrete vectors to bind to the symbols a and b.
t1 = rc.Array(torch.Tensor([1, 0]), name="t1")
t2 = rc.Array(torch.Tensor([0, 1]), name="t2")
# A Module is a ModuleSpec plus values bound to its arguments by symbol name (a := t1, b := t2).
module = rc.Module(spec=spec, name="t1 + t2", a=t1, b=t2)
# Display only when run directly, like the other display calls in this file.
if MAIN:
    module.print_html()
# %%
"""
Study the printout and make sure you understand it **completely**, as you'll be looking at a bunch of more complicated versions later today.
You'll see the `rc.Module` at top, which has three children: (i) its `rc.ModuleSpec`, (ii) the argument binding for `a`, and (iii) the argument binding for `b`.
An argument binding looks like `'t1' [2] Array ! 'a'`. You read this as: the value to the left of the '!' is bound to the symbol with the name to the right of the '!'.
Optional exercise: what happens if multiple Symbols have the same name `a`? What happens if there's a non-symbol with the name `a`?
<details>
<summary>
Solution
</summary>
The name on the right of the '!' refers to the argument with that name in the current module, i.e., the symbols with that name in the `rc.ModuleSpec` for that module. It's fine if there are symbols with the same name elsewhere, or non-symbols with that name.
The `rust_circuit` library requires that the names of arguments in an `rc.ModuleSpec` be unique, i.e., running the following will result in an exception:
```python
try:
a1 = rc.Symbol.new_with_random_uuid(shape=(2,), name="a")
a2 = rc.Symbol.new_with_random_uuid(shape=(2,), name="a")
spec_circuit = rc.Add(a1, a2, name="a + a")
spec = rc.ModuleSpec(circuit=spec_circuit, arg_specs=[rc.ModuleArgSpec(a1), rc.ModuleArgSpec(a2)])
except rc.ConstructModuleArgsDupNamesError as e:
print("Exception raised: ", e)
```
The bindings are also always listed in the same order as the symbols appear in the list `arg_specs` that we used to create the `rc.ModuleSpec`, a fact we'll need to exploit later.
</details>
<br />
Notice that our function is stored in unevaluated form. We can evaluate it, like any circuit:
"""
# %%
if MAIN:
    # Evaluate like any other circuit ([1, 0] + [0, 1]). Print the result explicitly, because a bare
    # expression statement shows nothing when the file is run as a script.
    print(module.evaluate())
# %%
"""
There's another operation we'll need extensively today: `rc.Module.substitute`. Calling `substitute()` will 'dissolve' the module, substituting the symbols with the objects bound to them. This is the equivalent of inlining our Python `plus` function, i.e., replacing `plus(t1, t2)` with `t1 + t2`.
Try that now and see how the printout changes.
"""
# %%
if MAIN:
    # Inline the module: the symbols a and b are replaced by the circuits bound to them.
    module.substitute().print_html()
# %%
"""
Like functions in Python, modules can be nested, and so let's get experience looking at the printouts of those, as you'll be doing that a bunch later today.
"""
# %%
# TBD: we really want an example here where a symbol is bound to another symbol that's itself bound in an outer module.
# Nested modules sharing one spec: plus1 = t1 + t2, plus2 = plus1 + t2, plus3 = plus1 + plus2.
plus1 = rc.Module(spec=spec, name="plus1", a=t1, b=t2)
plus2 = rc.Module(spec=spec, name="plus2", a=plus1, b=t2)
plus3 = rc.Module(spec=spec, name="plus3", a=plus1, b=plus2)
if MAIN:
    plus3.print_html()
# %%
"""
Again, study the printout. Some things to note:
* The copies of the spec are marked "(repeat)" and are not expanded by default.
* The second child of 'plus3' is printed as `'plus1' Module ! 'a'`, which means the output of the `plus1` module is bound
to the argument corresponding to the `rc.Symbol 'a'`, just as we set it up.
Now let's call substitute on the `plus2` module, and look at the printout. What changed? Does it look as you expect?
"""
# %%
# Same as plus3, except the circuit bound to b is plus2 with its module inlined.
plus3_with_partial_substitution = rc.Module(spec=spec, name="plus3", a=plus1, b=plus2.substitute())
if MAIN:
    plus3_with_partial_substitution.print_html()
# %%
# %% [markdown]
"""
## Evaluating model performance: constructing the loss
In this section, we'll construct the loss that our network was trained to minimize. Of course, we want to construct this as a circuit as well!
### Inputs and targets
Our model was trained to do next-token prediction.
Define `input_toks` to be all tokens except the last position, and `true_toks` to be the corresponding ground truth next tokens (for every sequence position, not just for the last position). This is a good time to practice using `I`.
"""
# %%
input_toks: rc.Index
true_toks: rc.Index
if "SOLUTION":
    # Inputs are positions 0..seq_len-1 of each row and targets are positions 1..seq_len,
    # so true_toks[i] is the token that follows input_toks[i].
    input_toks = toks_int_var.index(I[:-1], name="input_toks_int")
    true_toks = toks_int_var.index(I[1:], name="true_toks_int")
assert input_toks.shape == (seq_len,)  # type: ignore
assert true_toks.shape == (seq_len,)  # type: ignore
# %% [markdown]
"""
### Model Binding
The method `rc.get_free_symbols` helpfully tells us which `Symbol`s haven't been bound yet. This is useful for debugging purposes.
"""
if MAIN:
    # Symbols in orig_circuit that still need values bound before it can be evaluated.
    print("Free symbols: ")
    pprint(rc.get_free_symbols(orig_circuit))
# %% [markdown]
"""
It's time to bind these free symbols. The function `rc.module_new_bind` is just a more succinct way to create a `Module` instance than calling the `Module` constructor. You pass tuples containing symbol names and the values to bind and away you go! Note that this doesn't modify `orig_circuit`.
The node "t.input" represents the embedded tokens just like GPT, "a.mask" is the causal mask just like GPT, and "a.pos_input" is computed the same way as in GPT, but again in shortformer it will be used differently by the model.
Exercise: explain to your partner in your own words how "a.pos_input" will be used.
"""
# %%
# Embed the inputs: for every position i, take row input_toks[i] of tok_embeds (indexing along dim 0).
idxed_embeds = rc.GeneralFunction.gen_index(tok_embeds, input_toks, index_dim=0, name="idxed_embeds")
assert extra_args.causal_mask, "Should not apply causal mask if the transformer doesn't expect it!"
# (seq_len, seq_len) lower-triangular mask: entry [i, j] is 1 exactly when j <= i.
# `.to(tensor)` gives it the same dtype and device as the token embedding weights.
causal_mask = rc.Array(
    (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).to(tok_embeds.cast_array().value),
    "t.a.c.causal_mask",  # plain string: the name has no placeholders, so no f-string is needed
)
assert extra_args.pos_enc_type == "shortformer"
# Positional embeddings for the first seq_len positions, bound to "a.pos_input" below.
pos_embeds = pos_embeds.index(I[:seq_len], name="t.w.pos_embeds_idxed")
# Bind every free symbol of orig_circuit. This returns a new circuit; orig_circuit is not modified.
model = rc.module_new_bind(
    orig_circuit,
    ("t.input", idxed_embeds),
    ("a.mask", causal_mask),
    ("a.pos_input", pos_embeds),
    name="t.call",
)
assert model.are_any_found(orig_circuit)
assert not rc.get_free_symbols(model)
# TBD: add a blurb about gen_index_0_c being the spec of the gen_index fn
# Negative log-likelihood of true_toks given the model output, one value per position.
# The argument names contain dots, so they are passed via ** unpacking.
loss = rc.Module(
    negative_log_likelyhood.spec,
    **{"ll.input": model, "ll.label": true_toks},
    name="t.loss",
)
# %% [markdown]
"""For today's work, we only want to compute loss on the good induction candidates:"""
# %%
# For each position, look up the 0/1 vocab mask at the input token id: 1.0 if that token is a good
# induction candidate, 0.0 otherwise.
is_good_induction_candidate = rc.GeneralFunction.gen_index(
    x=rc.Array(good_induction_candidate, name="tok_is_induct_candidate"),
    index=input_toks,
    index_dim=0,
    name="induct_candidate",
)
# Elementwise product over the position axis: the loss is kept at candidate positions and zeroed elsewhere.
loss = rc.Einsum(
    (loss, (0,)),
    (is_good_induction_candidate, (0,)),
    out_axes=(0,),
    name="loss_on_candidates",
)
# %% [markdown]
"""
### Cumulants
A cumulant is a concept in probability theory, but you don't need to know anything about cumulants right now. The one relevant fact for today is that the "first cumulant" of a distribution is just the regular old mean of a distribution that you already know about. (Higher order cumulants come up in [other research](https://arxiv.org/abs/2210.01892) done at Redwood.)
Right now, our `loss` node depends on the input `DiscreteVar`s. Since these are random variables, our loss will also be a random variable. By wrapping `loss` in an `rc.Cumulant`, we're saying that we will be interested in the mean loss over the input distribution.
This cumulant will have shape `(seq_len,)` since we're computing the loss at every position. We then take the mean to get the average loss per model prediction (just like regular LM loss).
"""
# %%
# TBD: say more about is_good_induction mask & mean over seqpos.
# Highlight that this is like LM loss but only on specific tokens
# I think instead of mean we should sum, and divide by the number of predictions made (because many positions are masked out).
# Expectation (first cumulant) over the input distribution of the masked per-position loss; shape (seq_len,).
expected_loss_by_seq = rc.Cumulant(loss, name="t.expected_loss_by_seq")
# Scalar mean over positions. NOTE: masked-out positions contribute 0 but still count in the denominator,
# so this averages over all seq_len positions, not only over candidate positions.
expected_loss = expected_loss_by_seq.mean(name="t.expected_loss", scalar_name="recip_seq")
# %%
# Print options for inspecting expected_loss: show shapes even where they could be omitted, annotate
# argument names, and stop expanding the tree early at the matched nodes so the printout stays compact.
printer = rc.PrintHtmlOptions(
    shape_only_when_necessary=False,
    traversal=rc.new_traversal(
        # NOTE(review): in this pattern `\.*` means "zero or more literal dots" and a bare `.` matches any
        # character; it may have been meant as something like r"a\..*\.w\..*ind" — confirm what it should match.
        term_early_at=rc.Regex(r"a\.*.w.\.*ind")
        # Exact node names at which to stop expanding.
        | rc.Matcher(
            {
                "b",
                "final.norm",
                "idxed_embeds",
                "nll",
                "t.w.pos_embeds_idxed",
                "true_toks_int",
                "induct_candidate",
            }
        )
    ),
    comment_arg_names=True,
)
if MAIN:
    expected_loss.print(printer)
# %% [markdown]
"""
## Causal scrubbing
Congratulations! You made it through the preparatory work. It's finally time to do causal scrubbing!
Recall from the writeup that we'll be running our model on two inputs. One is the original input, and we'll use its next tokens to compute the loss. The other is the random other input, and we'll run the parts of the model we claim don't matter on this one.
Exercise: make another `DiscreteVar`, `toks_int_var_other` that will be uncorrelated with `toks_int_var`.
"""
# %%
if "SOLUTION":
    # Same dataset values, but with a new default probs_and_group, so it is sampled independently of toks_int_var.
    toks_int_var_other = rc.DiscreteVar(toks_int_values, name="toks_int_var_other")
if MAIN:
    print("Your names should match these to make later validation much easier:")
assert toks_int_var.name == "toks_int_var"
assert toks_int_var_other.name == "toks_int_var_other"
# %% [markdown]
"""
Quick test:
"""
def seeder(c: rc.Circuit) -> int:
    """Map each token variable's `probs_and_group` circuit to a fixed seed.

    Passed as `seeder=` to `rc.RandomSampleSpec`, so that runs of this notebook are
    reproducible. The `probs_and_group` of `toks_int_var` gets seed 11 and that of
    `toks_int_var_other` gets seed 22.

    Raises:
        ValueError: if `c` is neither of those two circuits.
    """
    for var, seed in ((toks_int_var, 11), (toks_int_var_other, 22)):
        if c == var.probs_and_group:
            return seed
    raise ValueError(
        "Expected one of the probs_and_group we constructed earlier, but got something else!",
        c,
    )
# Quick check that the two token variables are uncorrelated.
# Take 200 seeded samples of each, look at sequence position 10, and require the
# correlation to be close to zero. Compare its magnitude: a strongly *negative*
# correlation would also mean the two variables are not uncorrelated.
sampler = rc.Sampler(rc.RandomSampleSpec((200,), simplify=False, seeder=seeder))
assert (
    torch.corrcoef(
        torch.stack(
            (
                sampler.sample(toks_int_var).evaluate()[:, 10],
                sampler.sample(toks_int_var_other).evaluate()[:, 10],
            )
        )
    )[0, 1].abs()
    < 0.1
)
# %% [markdown]
"""
### Setting up sampler
"""
# %%
def sample_and_evaluate(c: rc.Circuit, num_samples: int = 16 * 128, batch_size=32) -> float:
    """Estimate `c` by random sampling and return the estimate as a Python float.

    Sampling is seeded through `seeder`, so repeated calls give the same result.
    More samples give a better estimate, but the run time grows linearly with them.
    This function does not compute error bars; you're welcome to add them.

    Args:
        c: circuit to estimate (e.g. a Cumulant such as `expected_loss`).
        num_samples: how many samples to draw.
        batch_size: chunk size for the sample axis at evaluation time, to bound memory use.
    """

    def prepare_sampled(sampled: rc.Circuit) -> rc.Circuit:
        # Hook that runs after sampling and before the circuit is evaluated.
        # Module nodes are substituted away first, because the compiler currently complains about them.
        module_free = rc.substitute_all_modules(sampled)
        # Split the num_samples axis (axis 0) into batch_size chunks.
        # Each chunk is evaluated separately, so we don't run out of memory.
        return rc.batch_to_concat(module_free, axis=0, batch_size=batch_size)

    spec = rc.RandomSampleSpec((num_samples,), seeder=seeder)
    estimate = rc.Sampler(spec, run_on_sampled=prepare_sampled).estimate(c)
    evaluated = rc.optimize_and_evaluate(estimate)
    return evaluated.item()
# %% [markdown]
"""
### Custom printing
Below is a helpful printer: it will color things getting the random input red, things getting the original input cyan, things getting both purple, and things getting neither grey. It's good practice to play around with printing until you can clearly see what's going on in your Circuit.
"""
# %%
def scrubbed(c):
    """Return whether `c` contains the random other input `toks_int_var_other`."""
    return c.are_any_found(toks_int_var_other)


def not_scrubbed(c):
    """Return whether `c` contains the original input `toks_int_var`."""
    return c.are_any_found(toks_int_var)


def scrub_colorer(c):
    """Choose a print color for `c` based on which inputs it contains.

    Returns "purple" for both inputs, "red" for only `toks_int_var_other`,
    "cyan" for only `toks_int_var`, and "lightgrey" for neither.
    """
    uses_other, uses_orig = scrubbed(c), not_scrubbed(c)
    if uses_other:
        return "purple" if uses_orig else "red"
    return "cyan" if uses_orig else "lightgrey"
# Same options as `printer`, but colored by `scrub_colorer`. The traversal additionally
# stops at any node that does not contain `toks_int_var_other`.
scrubbed_printer = printer.evolve(
    colorer=scrub_colorer,
    traversal=rc.restrict(
        printer.traversal,
        term_early_at=lambda c: not (c.are_any_found(toks_int_var_other)),
    ),
)
# %%
# Reference loss of the unmodified model, estimated from 16 * 128 seeded samples.
unscrubbed_out = sample_and_evaluate(expected_loss, 16 * 128)
if MAIN:
    print(f"Loss with no scrubbing: {unscrubbed_out:.3f}")
    assert_close(unscrubbed_out, 0.17, atol=1e-2, rtol=1e-3)
# %% [markdown]
"""
## Establishing a Baseline
### Scrubbing all inputs
When scrubbing, we want to compute our "percent loss recovered" as a metric. While this is generally sensible, it isn't completely satisfactory for various reasons. The metric can go over 100%, and it feels like researchers can [Goodhart](https://en.wikipedia.org/wiki/Goodhart%27s_law) the metric. We're thinking about ways to make this more valid involving having an adversary, but for now we'll just take the metric as an indicator that provides some evidence where higher (up to 100%) is better.
In the [Causal Scrubbing Appendix](https://www.lesswrong.com/posts/kcZZAsEjwrbczxN2i/causal-scrubbing-appendix#2_1__Percentage_of_loss_recovered__as_a_measure_of_hypothesis_quality), we gave a formula for this using a baseline where the inputs are scrubbed.
Concretely, we run our model on random inputs, while computing the loss w.r.t. the original labels. This isn't actually the baseline we will use (we'll explain more later) but it's a good warm-up.
Exercise: implement `scrub_input` and use it to replace the inputs to the model with uncorrelated inputs. You need to define the `rc.IterativeMatcher` `unused_baseline_path` so that `scrub_input` replaces the inputs but not the labels.
"""
# %%
def scrub_input(c: rc.Circuit, in_path: rc.IterativeMatcher) -> rc.Circuit:
    """Return `c` with each `toks_int_var` reached through `in_path` replaced by `toks_int_var_other`.

    Only occurrences matched by chaining `in_path` to the "toks_int_var" name are replaced.
    """
    "SOLUTION"
    # Match on the node name, which is the same between runs of the notebook, not on the hash.
    # The hash differs between runs because probs_and_group is random, and the exercise
    # tests are structured around names.
    tok_matcher = in_path.chain("toks_int_var")
    return tok_matcher.update(c, lambda _: toks_int_var_other)
if "SOLUTION":
    # Per the exercise, this path must reach the model's token inputs but not the labels
    # used by the loss. `tests.test_all_inputs_matcher` below checks this.
    unused_baseline_path = rc.IterativeMatcher("t.call")
else:
    unused_baseline_path: rc.IterativeMatcher
    """TODO: YOUR CODE HERE"""
if MAIN:
    tests.test_all_inputs_matcher(unused_baseline_path, expected_loss)
# Warm-up baseline: every model input is scrubbed, but the labels are not.
unused_baseline = scrub_input(expected_loss, unused_baseline_path)
# %%
if MAIN:
    # Default HTML print of the unscrubbed loss circuit, for comparison with the scrubbed print.
    expected_loss.print_html()
# %%
"""
Take a look at this print: does it look like what you expected?
"""
if MAIN:
    # Colored by which inputs each node receives (see `scrub_colorer`).
    scrubbed_printer.print(unused_baseline)
# %%
# Loss when every model input is replaced by the random other input.
unused_baseline_out = sample_and_evaluate(unused_baseline)
if MAIN:
    print(f"Loss with scrubbing the whole model: {unused_baseline_out:.3f}")
    assert_close(unused_baseline_out, 0.81, atol=1e-2, rtol=1e-3)
# %% [markdown]
"""
### Rewriting the model to split up the heads
The actual baseline we want to use isn't random inputs to everything, but only random inputs to the induction heads. This represents (very roughly) a model that is working normally except that induction is "disabled".
We want to be able to pass different inputs into the induction heads and the other heads. To do this, we'll rewrite our transformer so that there's a node named "a1.ind" consisting of just heads 1.5 and 1.6, and a node "a1.not_ind" (called "a1 other" in the writeup) consisting of the other layer 1 heads.
For future experiments, we'll also want to separate the "previous token head" 0.0 into its own node named "a0.prev", and call the other layer 0 heads "a0.not_prev".
Exercise: read through the source code for `configure_transformer` and figure out how to call it so that these heads are split up. We're expecting you to use `use_pull_up_head_split=True`; the other arguments you should be able to figure out. Verify the printed circuit looks reasonable.
Warning: the tests here can be very finicky -- in particular, there are several ways to write the split_by_head config that are equivalent in meaning but that the test will reject for silly reasons (e.g. S[0] != S[:1]).
"""
# %%
# TBD: how frustrating is this exercise?
if "SOLUTION":
    by_head = configure_transformer(
        expected_loss.get_unique("t.bind_w"),
        to=To.ATTN_HEAD_MLP_NORM,
        split_by_head_config={
            # Layer 0: head 0 becomes "a0.prev"; heads 1 onward become "a0.not_prev".
            0: [(0, "prev"), (S[1:], "not_prev")],
            # Layer 1: heads 5 and 6 become "a1.ind"; the remaining listed heads become "a1.not_ind".
            # NOTE(review): the explicit index list assumes 8 heads in layer 1.
            1: [
                (S[5:7], "ind"),
                (torch.tensor([0, 1, 2, 3, 4, 7]).to(model.device), "not_ind"),
            ],
        },
        use_pull_up_head_split=True,
        check_valid=True,
    )
else:
    split_by_head_config = """TODO: YOUR CODE HERE"""  # type: ignore
    by_head = configure_transformer(
        expected_loss.get_unique("t.bind_w"),
        to=To.ATTN_HEAD_MLP_NORM,
        split_by_head_config=split_by_head_config,  # type: ignore
        use_pull_up_head_split=True,
        check_valid=True,
    )  # type: ignore
"""Sanity checks"""
assert by_head.get({"a1.ind", "a1.not_ind", "a0.prev", "a0.not_prev"}, fancy_validate=True)
"""Bit of tidying: renames, and replacing symbolic shapes with their numeric values"""
# Drop the ".keep." infix from node names.
by_head = by_head.update(lambda c: ".keep." in c.name, lambda c: c.rename(c.name.replace(".keep.", ".")))
by_head = rc.conform_all_modules(by_head)
# Stop the printed traversal at "b0", "a1.norm", "a.head" and at nodes matching the
# not_ind regex, so the new per-head split stands out.
# NOTE(review): `\.*` in r"\.*not_ind\.*" means "zero or more literal dots", so this
# effectively matches any name containing "not_ind". Confirm that this is intended.
printer = printer.evolve(
    traversal=rc.restrict(
        printer.traversal,
        term_early_at=rc.Matcher({"b0", "a1.norm", "a.head"}) | rc.Regex(r"\.*not_ind\.*"),
    )
)
if MAIN:
    printer.print(by_head)
    tests.test_by_head(by_head)
print("Updating expected_loss to use the by_head version")
# Replace the original "t.bind_w" node with the per-head rewrite.
expected_loss = expected_loss.update("t.bind_w", lambda _: by_head)
# TODO: duplicated
"""replace symbolic shapes with their real values"""
expected_loss = rc.conform_all_modules(expected_loss)
# %% [markdown]
"""
### Substitution
Note: this section can be pretty confusing! Take it slowly, talk it through with your partner, and don't hesitate to call a TA to help explain things!
Let us focus on one particular part of the circuit, a module called "a.head.on_inp" inside of our induction heads:
"""
ind_head_on_inp_subcircuit = expected_loss.get_unique("a1.ind").get_unique("a.head.on_inp")
if MAIN:
ind_head_on_inp_subcircuit.print_html(printer)
# %% [markdown]
"""
Recall that within a module, `circ_a Add ! sym_a Symbol` means the module binds the value `circ_a` (an Add node) to the symbol `sym_a`, which is then required to appear in that module's spec.
Some things to notice:
- This module is representing our induction heads as a function that takes three inputs: `a.q.input`, `a.k.input`, and `a.v.input`. These inputs are used to form the queries, keys, and values respectively.
- Normally when we run an attention head we use the same `[seq_len, hidden_size]` matrix as inputs to all three of these. However, it is possible to run the attention head on three different inputs! In fact this is necessary to replicate the causal scrubbing experiments.
- This `a.head.on_inp` module is responsible for binding these three inputs. It binds both the query and key inputs to a simple circuit which adds two symbols: `a.input` (representing the input to the attention head) and `a.pos_input` (representing the positional embeddings). The value input is bound to the same `a.input`.
We want to be able to replicate an experiment where we change some of the tokens that are upstream of the value-input to the induction heads, but not change either the query-input or key-input. This would require writing an `IterativeMatcher` that can match paths through the value-input but not the other two.
Unfortunately there is no way to do this with the current circuit. While we could write a matcher that matches only one copy of the `a.input` symbol, there's no way to chain that matcher further upstream to the embeddings. It's just a symbol; there are no embeddings upstream!
In this section we will rewrite the model so that writing these sorts of matchers is possible.
"""
# %% [markdown]
"""
So what is this `a.input` symbol? What is it doing here? And where are the embeddings?
To answer that we need to zoom out from this particular a.head.on_inp module.
A more complete sub-circuit representing the induction head would look something like this:
```
mystery_module Module
spec of mystery_module
...
a.head.on_inp Module
a.head Einsum
...
a.qk_input Add ! a.q.input
a.input
pos_embeds
a.qk_input Add ! a.k.input
a.input
pos_embeds
a.input ! a.v.input
output_of_a1_ln ! a.input
```
Here `mystery_module` is binding the `a.input` symbol to the `output_of_a1_ln` circuit.
"""
# %% [markdown]
"""
So now how can we fix the problem described above (that we can't write an iterative matcher to perform the update we want)? Well, it should be possible to rewrite the circuit shown above into this form instead:
```
spec of mystery_module
...
a.head.on_inp Module
a.head Einsum
...
a.qk_input Add ! a.q.input
output_of_a1_ln
pos_embeds
a.qk_input Add ! a.k.input
output_of_a1_ln
pos_embeds
output_of_a1_ln ! a.v.input
```
Then we could perform the above update on the copy of `a1.norm` (the `output_of_a1_ln` in the sketch above) that is bound to `a.v.input`!
This is analogous to the transformation between
```
polynomial_function = lambda x: x**2 + 3*x + 4
polynomial_function(23)
```
into `23**2 + 3*23 + 4`: we transform a function call (`polynomial_function(23)`) into the body of the function (`x**2 + 3*x + 4`), with the input symbol (`x`) replaced by its bound value (23).
This transformation can be achieved by calling `.substitute()` on the `'a1.ind_sum.norm_call'` module! This eliminates the module and performs the transformation above!
(you may need to use `cast_module()` to convince your typechecker that `a1.ind_sum.norm_call` really is a Module!)
"""
# %%
"""
Okay, enough talking! Time to perform this substitution.
Exercise: First, trace through the circuit to figure out which module binds `a.input`.
Do this by examining the below print and looking for a line that ends in `! 'a.input'` as this is the symbol we are trying to substitute.
The print below only focuses on the sub-circuit representing the attention head.
"""
if MAIN:
    # Look for the line ending in `! 'a.input'` to find the module that binds it.
    expected_loss.get_unique("b1.a.ind_sum").print_html(printer)
"""
<details>
<summary>Solution</summary>
The module `a1.ind_sum.norm_call` binds `a1.norm` to `a.input`.
</details>
"""
# %% [markdown]
"""
Now use `.update` to call substitute on this module!
"""
with_a1_ind_inputs_v1 = expected_loss
if "SOLUTION":
    # `a1.ind_sum.norm_call` binds `a1.norm` to `a.input`. Substituting the module puts the
    # bound value in place of the symbol, like inlining a function call.
    with_a1_ind_inputs_v1 = with_a1_ind_inputs_v1.update("a1.ind_sum.norm_call", lambda c: c.cast_module().substitute())
# Check: `a.input` should no longer be a free symbol of the induction heads' `a.head.on_inp`.
ind_head_on_inp_subcircuit_v1 = with_a1_ind_inputs_v1.get_unique("a1.ind").get_unique("a.head.on_inp")
ind_head_on_inp_subcircuit_v1.print_html(printer.evolve(traversal=rc.new_traversal(term_early_at={"a.head", "ln"})))
assert "a.input" not in [symb.name for symb in rc.get_free_symbols(ind_head_on_inp_subcircuit_v1)]
# %%
"""
Unfortunately we aren't done yet. There is still a single symbol that forces all three inputs to be the same. What is that symbol? What module binds it? Once again, figure out what module is binding it and substitute that module.
"""
if MAIN:
    with_a1_ind_inputs_v1.get_unique("b1.a.ind_sum").print_html(printer)
with_a1_ind_inputs_v2 = with_a1_ind_inputs_v1
if "SOLUTION":
    # Substitute the `b1.a.ind_sum` module so its bound symbol is replaced by the bound value.
    with_a1_ind_inputs_v2 = with_a1_ind_inputs_v2.update("b1.a.ind_sum", lambda c: c.cast_module().substitute())
if MAIN:
    ind_head_on_inp_subcircuit_v2 = with_a1_ind_inputs_v2.get_unique("a1.ind").get_unique(
        rc.restrict("a.head.on_inp", end_depth=2)
    )
    printer.print(ind_head_on_inp_subcircuit_v2)
    # The remaining free symbols: attention weights, the mask, positional input and `t.input`.
    assert set(symb.name for symb in rc.get_free_symbols(ind_head_on_inp_subcircuit_v2)) == {
        "a.w.q_h",
        "a.w.k_h",
        "a.mask",
        "a.w.v_h",
        "a.w.o_h",
        "a.pos_input",
        "t.input",
    }
# %%
"""
We're very close! There's still one symbol that is shared across all three inputs that prevents us from changing the token embeddings through one path but not the other.
"""
if MAIN:
    with_a1_ind_inputs_v2.get_unique("b1.a.ind_sum").print_html(printer)
with_a1_ind_inputs_v3 = with_a1_ind_inputs_v2
if "SOLUTION":
    # Substituting the `t.call` module removes `t.input` (and the other non-weight symbols)
    # from the free symbols of the induction heads' `a.head.on_inp`.
    # Compare the v2 free-symbol check above with the v3 check below.
    with_a1_ind_inputs_v3 = with_a1_ind_inputs_v3.update("t.call", lambda c: c.cast_module().substitute())
if MAIN:
    ind_head_on_inp_subcircuit_v3 = with_a1_ind_inputs_v3.get_unique("a1.ind").get_unique(
        rc.restrict("a.head.on_inp", end_depth=2)
    )
    printer.print(ind_head_on_inp_subcircuit_v3)
    print("Checking various things...")
    # Only the attention weight symbols should remain free in `a.head.on_inp`...
    assert set(symb.name for symb in rc.get_free_symbols(ind_head_on_inp_subcircuit_v3)) == {
        "a.w.q_h",
        "a.w.k_h",
        "a.w.v_h",
        "a.w.o_h",
    }
    # ...while the whole `a1.ind` subtree has no free symbols and contains the input tokens.
    a1_ind = with_a1_ind_inputs_v3.get_unique("a1.ind")
    assert not rc.get_free_symbols(a1_ind), "there should be no free symbols in a1!"
    assert a1_ind.are_any_found(
        toks_int_var
    ), "toks_int_var should appear at least once in the subcircuit rooted at a1.ind"
    tests.test_with_a1_ind_inputs(with_a1_ind_inputs_v3)
    print("Checks passed! Well done!!")
with_a1_ind_inputs = with_a1_ind_inputs_v3
"""
We can now progress on to replicating experiments.
"""
# %% [markdown]
"""
### Scrubbing Induction Heads
To complete the baseline section, we need to run the model where the induction heads are "scrubbed" - all inputs to the induction heads are replaced with inputs chosen randomly.
"""
# Baseline: only the inputs reached through "a1.ind" (the induction heads) are replaced
# by the random other input; everything else keeps the original tokens.
scrubbed_ind = scrub_input(with_a1_ind_inputs, rc.IterativeMatcher("a1.ind"))
if MAIN:
    scrubbed_printer.print(scrubbed_ind)
baseline_out = sample_and_evaluate(scrubbed_ind)
if MAIN:
    print(f"Loss with induction heads scrubbed: {baseline_out:.3f}")
def loss_recovered(l):
    """Return the fraction of the baseline-to-unscrubbed loss gap that loss `l` recovers.

    Uses the module-level `baseline_out` (induction heads scrubbed) and `unscrubbed_out`
    (no scrubbing). The result is 1.0 when `l` equals `unscrubbed_out` and 0.0 when it
    equals `baseline_out`. It can fall outside [0, 1].
    """
    recovered_gap = l - baseline_out
    full_gap = unscrubbed_out - baseline_out
    return recovered_gap / full_gap
# %% [markdown]