-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathremix_d5_part3_solution.py
1602 lines (1209 loc) · 66.5 KB
/
remix_d5_part3_solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# %%
"""
# REMIX Day 5, Part 3 - creating your tools
This notebook gives you a tour of the classic tools for exploring language model internals. We'll build the tools ourselves and run our first experiments. We'll also learn to interpret (sometimes confusing) experimental results.
<!-- toc -->
## Learning Objectives
After today's material, you should be able to:
- Understand the motivation and the implementation of path patching
- Use `iterative_path_patching` to implement your own experiments
- Understand the motivation behind "moving pieces experiments" and their implementation
- Gather attention patterns and visualize them with the CUI
- Use helper functions to handle the grouping of nodes and the creation of causal scrubbing hypotheses and matchers
## Readings
- [Exploratory interp exercises presentation](https://docs.google.com/document/d/1qyHT4W9TtVL77AMKN514SjXT9fyNS70DJH9FFQ7YiDg/edit?usp=sharing)
- [Introduction to path patching](https://docs.google.com/document/d/1FWJUwnD50-IMrr92K6w3LIaAjhJp-HBN7ixvXBJM70o/edit?usp=sharing)
* The [slides from the lecture](https://docs.google.com/presentation/d/13Bvmo8E6N5qhgj1yCXq5O7zNRzNNXZLzexlgdzdgZ_E/edit?usp=sharing) for the terminology.
"""
import os
import sys
import torch
from rust_circuit.causal_scrubbing.dataset import Dataset
from rust_circuit.causal_scrubbing.experiment import (
Experiment,
ExperimentEvalSettings,
)
from rust_circuit.ui import cui
from rust_circuit.ui.very_named_tensor import VeryNamedTensor
# %%
from remix_d5_utils import (
IOIDataset,
load_and_split_gpt2,
load_logit_diff_model,
)
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Device: ", DEVICE)

# True only when this file is executed directly. The notebook may also be imported, and the
# long-running experiments below must not run during an import.
MAIN = __name__ == "__main__"

if "SKIP":  # string marker (always truthy at runtime)
    # In CI, stop here for now so that GPT-2 does not have to be downloaded.
    IS_CI = os.getenv("IS_CI")
    if IS_CI:
        sys.exit(0)

# TBD: the autoreload setup below is disabled because it may break CI.
# if "SKIP":
#     if MAIN:
#         get_ipython().run_line_magic("load_ext", "autoreload")
#         get_ipython().run_line_magic("autoreload", "2")
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
import plotly.express as px
import rust_circuit as rc
import torch
import torch as t
from rust_circuit.causal_scrubbing.hypothesis import (
CondSampler,
Correspondence,
ExactSampler,
InterpNode,
UncondSampler,
chain_excluding,
corr_root_matcher,
)
import remix_utils
from remix_d5_utils import HeadOrMlpType, AttnSuffixForGpt, MLPHeadAndPosSpec
# %%
"""
## Claim 1
We take as a running example the claim: "Attention heads directly influencing the logits are either not influencing IO and S logits, or are increasing the IO logits more than the S.". At the end of this notebook, you should be able to have a nuanced view of what this claim means and to what extent it is correct.
### Setup
Before that, we'll need to import the model and the dataset using the code we wrote in the first two notebooks.
### Creating the dataset.
Although our dataset object supports multiple templates that are not aligned (e.g. the position of IO varies from sequence to sequence), we will only use one template for this demo.
Thus, the name position is the same for all sequences. Because sentences are aligned, we can define global variables for the position of the tokens.
"""
# %%
ioi_dataset = IOIDataset(prompt_type="BABA", N=50, seed=42, nb_templates=1, device=DEVICE)
MAX_LEN = ioi_dataset.prompts_toks.shape[1]

# A single template is used, so each special word must sit at the same position in every prompt.
for k, idx in ioi_dataset.word_idx.items():
    assert (idx == idx[0]).all()

# Since the prompts are aligned, the position of each word is stored as a plain global int.
END_POS, IO_POS, S1_POS, S2_POS = (
    int(ioi_dataset.word_idx[word][0].item()) for word in ("END", "IO", "S1", "S2")
)
# %%
"""
### Defining Dataset Variation
We use the `gen_flipped_prompts` method to create the datasets used in this notebook. We define them once, at the start of the notebook, to make sure they are the same for all experiments (running a generation cell twice could otherwise lead to different results).
Exercise: print the first 5 sentences of the `flipped_IO_dataset` and `flipped_S_dataset` and make sure you understand what information they hold. What are the families of all these datasets?
"""
# %%
# These datasets are generated once, here, so that every experiment reuses exactly the same prompts.
# A dataset where the IO token is flipped.
flipped_IO_dataset = ioi_dataset.gen_flipped_prompts("IO")
# A dataset where S is flipped (both occurrences replaced by the same random name).
flipped_S_dataset = ioi_dataset.gen_flipped_prompts("S")
# Both IO and S flipped. NOTE(review): this regenerates the IO flip instead of reusing
# `flipped_IO_dataset`, so it may use different names if flipping is random — confirm intent.
flipped_IO_S_dataset = ioi_dataset.gen_flipped_prompts("IO").gen_flipped_prompts("S")
# Presumably swaps the order of IO and S1 in the prompt — TODO confirm against IOIDataset.
flipped_IO_S1_order = ioi_dataset.gen_flipped_prompts("order")
# %%
"""
### Model Loading
We will then load the model using the steps described previously. First the main circuit, then the logit diff (ld for short) circuit for path patching experiments. `group` will be used to keep the labels in sync with the inputs.
"""
circuit = load_and_split_gpt2(MAX_LEN)
# One (IO, S) label pair per prompt: column 0 holds the IO token id, column 1 the S token id.
io_s_labels = torch.stack([ioi_dataset.io_tokenIDs, ioi_dataset.s_tokenIDs], dim=1)
# `group` keeps the labels in sync with the token inputs when sampling.
ld_circuit, group = load_logit_diff_model(circuit, io_s_labels)
ld_circuit = rc.cast_circuit(ld_circuit, rc.TorchDeviceDtypeOp(device=DEVICE))
# %%
"""
It's always important to check that our model is working as expected before running any experiments.
"""
# Plug the IOI prompts into the logit-diff circuit as a DiscreteVar sharing the label group.
c = ld_circuit.update(
    "tokens",
    lambda _: rc.DiscreteVar(
        rc.Array(ioi_dataset.prompts_toks.to(DEVICE), name="tokens"),
        probs_and_group=group,
    ),
)

if MAIN:
    # Replace the DiscreteVars of `group` by their sampled values, then evaluate.
    transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group]))
    results = transform.sample(c).evaluate()
    print(f"Logit difference for the first 5 prompts: {results[:5]}")
    print(f"Average logit difference: {results.mean()} +/- {results.std()}")
    ref_ld = results.mean()
    assert 2.5 < ref_ld < 4  # usual range
# %%
"""
## Experiments
Now that we're all set, let's think about experiments!
We want to prove (or disprove) the claim: "Attention heads directly influencing the logits are either not influencing IO and S logits, or are increasing the IO logits more than the S."
Pause for a moment to think about what would be the first step to test this claim.
* How can you divide this claim into smaller claims?
* What is the first experiment you want to run?
<details>
<summary>Solution</summary>
* First, we have to find heads that directly influence the logits. We're looking at their effect by only considering the final layer norm and unembeddings as intermediate computations.
* Then, we need to identify the *direction* of the effect of these heads on the IO and S logits. A way to summarize this is to look at the logit difference.
Experimentally, we can implement this using the following techniques:
* Path patching to the logits
* Simple causal scrubbing hypothesis
* Projecting the output of the head using the unembedding matrix (the logit lens).
Here we'll focus on path patching. This technique is the most well-suited for this kind of experiment. Go through the docs (including the exercises!) introducing [path patching](https://docs.google.com/document/d/1FWJUwnD50-IMrr92K6w3LIaAjhJp-HBN7ixvXBJM70o/edit?usp=sharing) and then come back here.
</details>
"""
# %%
"""
## Build your tools
### Path patching
Let's implement path patching. From the description in the doc above, it may look like causal scrubbing: both rely on the operation "change the input to one branch, but not to another". However, the implementation will be much simpler because there is no complicated dataset computation: only two datasets (the baseline and the patch) are necessary.
Note: in this document, we'll use "matcher" to sometimes refer to `Matcher` and sometimes to `IterativeMatcher` - hopefully it will be clear from the context.
In this section we will create two main functions:
* `path_patching` takes as input a matcher (specifying the path to patch) and returns the patched circuit where inputs through the path are replaced by the patched input and all other inputs are set to a baseline input.
Once we can perform path patching, we are often interested in answering questions of the form "Given that I found the influence A->B, what is directly influencing A?". To answer this question, we iterate over each node N that comes before A and run path patching on `N->A->B` (note that downstream effects are filtered out because the path must end at B). Then we select the Ns that lead to the largest effect sizes: they are the ones directly influencing A. This is the motivation behind `iterative_path_patching`: make it easy to iterate over candidate steps like "connect N to A" that extend already known paths.
* Concretely, the idea of `iterative_path_patching` is to expand hypothesis H by making some nodes grow "by one step". For instance, "grow by one step" can mean "starting on node A, add the direct path to head 5.2", where A is a parameter. This "grow by one step" operation is implemented using _matcher extenders_. A matcher extender is a function of type signature `IterativeMatcher -> IterativeMatcher`. `iterative_path_patching` takes as arguments a list of matcher extenders. In practice, a single matcher extender consists of a `.chain(node)` operation where `node` is a fixed node.
The loop implemented by `iterative_path_patching` is:
```
For each matcher extender E:
patching_matcher = empty matcher
For each node to connect in H:
patching_matcher = patching_matcher U E(node)
path_patching(patching_matcher)
```
Technical detail: there is no straightforward way to define an empty matcher, you might want to initialize `patching_matcher` with the first matcher extender.
In the code, we use causal scrubbing hypotheses for convenience, but they are never run. We use them as an easy way to store matchers and to specify which node we want to connect.
If `hypothesis` is a `Correspondence`, you can access the matcher of the `InterpNode` `node` using `hypothesis.corr[node]`.
Exercise: Implement `path_patching` and `iterative_path_patching` using the following skeleton. We provide you with the function `replace_inputs` that will replace the input starting from a given matcher by a `DiscreteVar` sampling from a given `Tensor`.
* For `path_patching` you can use a different `array_suffix` in `replace_inputs` when you replace by the baseline data and patched data. Then by inspecting the graph you can see if the correct inputs are replaced.
* For `iterative_path_patching` you have to implement the inner loop. You also have to increment the variable `nb_not_found` to count the number of times the inputs are not found after applying the matcher extender. This can happen if the extender tries to connect a node N that is at a later layer than A. In this case, the path `N->A` does not exist.
In the next cell, you can find the definition of matcher extender to debug your implementation.
"""
# %%
def replace_inputs(
    c: rc.Circuit,
    x: torch.Tensor,
    input_name: str,
    m: rc.IterativeMatcher,
    group: rc.Circuit,
    array_suffix: str = "_array",
):
    """
    Swap the input `input_name` found along the branch selected by `m` for a DiscreteVar over `x`.

    The input inside `c` is expected to be non-batched; `x` must have at least one dimension.
    The DiscreteVar keeps the name `input_name` and samples with `group`. The Array it wraps is
    named `input_name + array_suffix`, so that differently-replaced inputs can be told apart when
    inspecting the circuit.
    """
    assert x.ndim >= 1

    def to_discrete_var(_old_input):
        samples = rc.Array(x, name=input_name + array_suffix)
        return rc.DiscreteVar(samples, name=input_name, probs_and_group=group)

    return c.update(m.chain(input_name), to_discrete_var)
def path_patching(
    circuit: rc.Circuit,
    baseline_data: torch.Tensor,
    patch_data: torch.Tensor,
    matcher: rc.IterativeMatcher,
    group: rc.Circuit,
    input_name: str,
) -> rc.Circuit:
    """Patch `patch_data` into the inputs reached through `matcher`; feed `baseline_data` everywhere else.

    The returned circuit has DiscreteVar inputs sampled with `group`. Baseline arrays carry the
    suffix "_baseline" and patched arrays the suffix "_patched".
    """
    "SOLUTION"
    # First, every path to the input receives the baseline data.
    patched = replace_inputs(
        circuit, baseline_data, input_name, corr_root_matcher, group, array_suffix="_baseline"
    )
    # Then, only if `matcher` matches something in the original circuit, overwrite the input on
    # the matched paths with the patch data.
    if matcher.get(circuit):
        patched = replace_inputs(
            patched, patch_data, input_name, matcher, group, array_suffix="_patched"
        )
    return patched
def iterative_path_patching(
    circuit: rc.Circuit,
    hypothesis: Correspondence,
    nodes_to_connect: List[InterpNode],
    baseline_data: torch.Tensor,
    patch_data: torch.Tensor,
    group: rc.Circuit,
    matcher_extenders: List[Callable[[rc.IterativeMatcher], rc.IterativeMatcher]],
    input_name: str,
    output_shape: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    """
    Run one path patching experiment per matcher extender and concatenate the circuit outputs.

    For each extender E, the matchers of all `nodes_to_connect` are extended with E and merged
    into a single union matcher. Inputs reached through that union receive `patch_data`; every
    other input receives `baseline_data` (see `path_patching`).

    * circuit - the circuit to patch
    * hypothesis - a causal scrubbing hypothesis. No causal scrubbing is run; it is only a
      convenient container linking InterpNodes to matchers.
    * nodes_to_connect - the InterpNodes of `hypothesis` from which the paths are extended.
      Must be non-empty.
    * baseline_data - the baseline data to use for the replacement
    * patch_data - the patch data to use for the replacement
    * group - the sampling group for the DiscreteVars
    * matcher_extenders - functions that take a matcher and return an extended matcher; they
      define how the hypothesis grows. For example, one element can be a `.chain` operation that
      matches one specific attention head through its most direct path. Must be non-empty.
    * input_name - the name of the input in the circuit
    * output_shape - optional reshaping of the result. If None, the output shape is
      `(len(matcher_extenders),) + <shape of one circuit output>`. The circuit output shape can
      differ from `circuit.shape` when sampling adds a batch dimension.

    Raises ValueError if `nodes_to_connect` or `matcher_extenders` is empty.
    """
    # Fail early with a clear message instead of an IndexError / empty torch.cat further down.
    if not nodes_to_connect:
        raise ValueError("nodes_to_connect must contain at least one InterpNode")
    if not matcher_extenders:
        raise ValueError("matcher_extenders must contain at least one matcher extender")

    start = time.perf_counter()  # monotonic clock, suited to measuring elapsed time
    circuits = []
    sampler = rc.Sampler(rc.RunDiscreteVarAllSpec([group]))
    nb_not_found = 0
    for matcher_extender in matcher_extenders:
        if "SOLUTION":
            extended = [matcher_extender(hypothesis.corr[node]) for node in nodes_to_connect]
            # There is no straightforward empty matcher, so the union starts from the first one.
            union_matcher = extended[0]
            for matcher in extended[1:]:
                union_matcher = union_matcher | matcher
            # No match happens e.g. when the extender targets a node in a later layer than the
            # node it should connect to: that path does not exist.
            if len(union_matcher.get(circuit)) == 0:
                nb_not_found += 1
            patched_circuit = path_patching(circuit, baseline_data, patch_data, union_matcher, group, input_name)
        else:
            raise NotImplementedError("Inner loop not implemented!")
        patched_circuit = sampler(patched_circuit)  # we replace discrete vars by the real arrays
        circuits.append(patched_circuit)

    if nb_not_found > 0:
        print(f"Warning: No match found for {nb_not_found} matcher extenders")

    # Evaluate all circuits at once; they share many tensors, which this call exploits.
    results = rc.optimize_and_evaluate_many(
        circuits,
        rc.OptimizationSettings(scheduling_simplify=False, scheduling_naive=True),
    )
    print(f"Time for path patching :{time.perf_counter() - start:.2f} s")

    if output_shape is None:
        return torch.stack(list(results), dim=0)
    return torch.cat(results).reshape(output_shape)
# %%
# Test matcher: from "final.input", go one level deeper (depth window start_depth=1, end_depth=2)
# to "a2.p_bias". The target is arbitrary; it only serves to check `path_patching`.
matcher = rc.Matcher("final.input").chain(
    rc.restrict(rc.Matcher("a2.p_bias"), start_depth=1, end_depth=2)
)
# Its complement: the same depth window, but every node except "a2.p_bias". Negation (`~`) is
# only available on a plain Matcher, so it is applied before wrapping in `restrict`.
complement_matcher = rc.Matcher("final.input").chain(
    rc.restrict(~rc.Matcher("a2.p_bias"), start_depth=1, end_depth=2)
)
if MAIN:
    patched_circuit = path_patching(
        ld_circuit,
        baseline_data=ioi_dataset.prompts_toks,
        patch_data=flipped_IO_dataset.prompts_toks,
        group=group,
        matcher=matcher,
        input_name="tokens",
    )
    patched_circuit.print_html()

    # Test your implementation: the path selected by `matcher` must end on the patched array,
    # and the complementary paths on the baseline array.
    patched_array = matcher.chain("tokens").chain(rc.Array).get(patched_circuit)
    non_patched_array = complement_matcher.chain("tokens").chain(rc.Array).get(patched_circuit)
    assert len(patched_array) == 1 and len(non_patched_array) == 1
    assert next(iter(patched_array)).name == "tokens_patched"
    assert next(iter(non_patched_array)).name == "tokens_baseline"
# %% path patching debug example
def _chain_direct_path(m: rc.IterativeMatcher, target_name: str, end_depth: int = 9) -> rc.IterativeMatcher:
    """Chain `m` to the node named `target_name`, terminating the traversal at the first match.

    `end_depth` bounds the traversal depth; keeping it small is what selects only the most direct
    path to the target.
    """
    return m.chain(
        rc.restrict(
            rc.Matcher(target_name),
            end_depth=end_depth,
            term_if_matches=True,
        )
    )


def extender1(m: rc.IterativeMatcher) -> rc.IterativeMatcher:
    """Debug matcher extender: extend `m` to `m8.p_bias` through the most direct path."""
    return _chain_direct_path(m, "m8.p_bias")


def extender2(m: rc.IterativeMatcher) -> rc.IterativeMatcher:
    """Debug matcher extender: extend `m` to `a2.p_bias` through the most direct path."""
    return _chain_direct_path(m, "a2.p_bias")
if MAIN:
    # The hypothesis has a single node, its root: the logits.
    corr = Correspondence()
    i_root = InterpNode(ExactSampler(), name="logits")
    m_root = corr_root_matcher
    corr.add(i_root, m_root)

    results = iterative_path_patching(
        circuit=ld_circuit,
        hypothesis=corr,
        nodes_to_connect=[i_root],
        baseline_data=ioi_dataset.prompts_toks.to(DEVICE),
        patch_data=flipped_IO_dataset.prompts_toks.to(DEVICE),
        group=group,
        matcher_extenders=[extender1, extender2],
        input_name="tokens",
    )
    print(results.shape, results.mean(dim=-1))

    # Compare the first prompt of each experiment with reference values. They depend on the
    # generated datasets: make sure you run the dataset definition cell only once!
    assert torch.isclose(
        results[:, 0], torch.tensor([1.6240, 1.6022], device=DEVICE), atol=1e-3
    ).all()
# %%
"""
### Our first path patching experiment
Instead of defining matcher extenders by hand, we often define an extender factory that takes a parameter (e.g. the layer and head number) and returns an extender that reaches this targeted head from an arbitrary starting point.
If our target head is H and we investigate its _direct effect_ on the logit, we don't allow paths of the form H->A->logits where A is an arbitrary node. To translate this into matchers, we'll use the `term_early_at` argument of `restrict`. This argument allows us to stop the traversal when we reach a certain node. In our case, we want to stop the traversal when we reach a node that is _not_ one we specify.
To this end, we define `ALL_NODES_NAMES` to be the set of all the MLP and attention heads at a particular position. We can then recover the set of names of all but the target: each time we come across a name in this set, we should stop.
To define `ALL_NODES_NAMES`, we use the class `MLPHeadAndPosSpec` that handles attention head and MLP nodes. It includes helpful methods such as `to_name` that return the name of the node given a prefix.
Advanced details:
If you look at the code below, you can spot two additional details in addition to this story. What are they?
<details>
<summary>Click here to see the answer</summary>
* We add a `qkv` parameter to the extender factory. This parameter is used to restrict the extender to a particular qkv head. This is useful to allow paths that go through only Q, K or V of heads.
* We add `rc.new_traversal(start_depth=1, end_depth=2)` before specifying the direct path. The starting point of the matcher is itself part of `ALL_NODES_NAMES`, and we don't want the traversal to stop at the root! So we force the matcher to go one level deeper.
</details>
Note: we heavily rely on `end_depth` when defining matchers. This makes them easier to understand, but they are also much more brittle! A single rewrite of the circuit can mess up the depth of the nodes we are interested in. Beware when copy-pasting such definitions in your project, and always print your circuit to be sure what you're matching.
"""
# %%
# Names of every attention head and MLP at every position:
# 12 layers x (12 heads + "mlp") x MAX_LEN positions.
ALL_NODES_NAMES = {
    MLPHeadAndPosSpec(layer, cast(HeadOrMlpType, head), position).to_name("")
    for layer in range(12)
    for head in [*range(12), "mlp"]  # type: ignore
    for position in range(MAX_LEN)
}
def extender_factory(node: MLPHeadAndPosSpec, qkv: Optional[str] = None):
    """
    Build a matcher extender that reaches `node` from an arbitrary starting matcher by a direct path.

    The returned function maps a matcher `m` to a matcher that starts at the end of `m`,
    optionally goes through the `qkv` input of the attention block, and then reaches `node`.
    Reaching any other node of `ALL_NODES_NAMES` terminates the traversal early, so paths going
    through another head or MLP are excluded.

    Args:
        node: the attention head / MLP and position to reach.
        qkv: the input of the attention block the path must go through ("q", "k" or "v"), or
            None to not restrict the path to a particular input.

    Raises:
        ValueError: if `qkv` is not "q", "k", "v" or None.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if qkv not in ("q", "k", "v", None):
        raise ValueError(f"qkv must be one of 'q', 'k', 'v' or None, got {qkv!r}")
    node_name = node.to_name("")
    # Every other head/MLP/position: reaching one of them means the path is not direct.
    nodes_to_ban = ALL_NODES_NAMES - {node_name}
    if qkv is None:
        attn_block_input = rc.new_traversal(start_depth=0, end_depth=1)
    else:
        attn_block_input = rc.restrict(f"a.{qkv}", term_if_matches=True, end_depth=8)
    # The starting node is itself in ALL_NODES_NAMES: skip one level first so that the traversal
    # does not terminate at the root. Built once here, not on every extender call.
    direct_path = rc.new_traversal(start_depth=1, end_depth=2).chain(
        rc.restrict(
            rc.Matcher(node_name),
            term_early_at=rc.Matcher(nodes_to_ban),
            term_if_matches=True,
        )
    )

    def matcher_extender(m: rc.IterativeMatcher):
        return m.chain(attn_block_input).chain(direct_path)

    return matcher_extender
# One extender per (layer, head-or-mlp) pair, all targeting the END position. The order is
# layer-major, matching the (12 layers, 13 columns = 12 heads + mlp) reshape used below.
matcher_extenders = [
    extender_factory(MLPHeadAndPosSpec(layer, cast(HeadOrMlpType, head), END_POS), qkv=None)
    for layer in range(12)
    for head in [*range(12), "mlp"]  # type: ignore
]
"""
### Let's run our first experiment!
Question: Does the experiment affect the output of heads and MLPs that come after the patched connection? (i.e. do we filter for downstream effects?)
<details>
<summary>Click here to see the answer</summary>
Because we're path patching N->logits, any potential downstream effect would come _after_ the logits. However, the logits are the output of our model, so there is nothing to filter here. Filtering downstream effects doesn't mean anything for this particular experiment.
</details>
"""
# %%
if MAIN:
    # The hypothesis has a single node, its root (the logits); every extender grows from it.
    corr = Correspondence()
    i_root = InterpNode(ExactSampler(), name="logits")
    m_root = corr_root_matcher
    corr.add(i_root, m_root)

    # Direct effect on the logits of each head/MLP at END when patching from flipped-IO prompts,
    # laid out as (12 layers, 13 columns = 12 heads + mlp, prompts).
    results_IO = iterative_path_patching(
        circuit=ld_circuit,
        hypothesis=corr,
        nodes_to_connect=[i_root],
        baseline_data=ioi_dataset.prompts_toks.to(DEVICE),
        patch_data=flipped_IO_dataset.prompts_toks.to(DEVICE),
        group=group,
        matcher_extenders=matcher_extenders,
        input_name="tokens",
        output_shape=(12, 13, -1),
    )
# %%
"""
#### Visualizing the results
We use [plotly](https://plotly.com/python/) to plot the results. It produces interactive graphs made from HTML. You can hover over the graph to see the exact values of each entry.
Some tricks that make results easier to visualize:
Format the results:
* We reshape the results to have a 12x13 matrix, where the 13th column is the mlp
* We average the results over the prompts (the last dimension)
Plotly tricks:
* Center the color map so white is zero
* Add labels to the axis
Feel free to reuse `show_mtx` for your experiments. The default value of the "title" variable is a nudge to encourage you to always define it ;)
"""
def show_mtx(mtx, title="NO TITLE :(", color_map_label="Logit diff variation", x_labels=None):
    """Show a plotly heatmap of `mtx` with a color map centered on zero.

    Designed to display results of path patching experiments: rows are layers and, by default,
    columns are the 12 attention heads followed by the MLP.

    Args:
        mtx: 2D tensor/array to display; row i is labelled "i".
        title: figure title. The default is a nudge to always set one.
        color_map_label: label of the color bar.
        x_labels: column labels. Defaults to ["h0", ..., "h11", "mlp"]; pass your own labels to
            display matrices with a different number of columns.
    """
    # Center the color scale on zero by using the symmetric range (-max_abs, +max_abs).
    max_val = float(max(abs(mtx.min()), abs(mtx.max())))
    if x_labels is None:
        x_labels = [f"h{i}" for i in range(12)] + ["mlp"]
    fig = px.imshow(
        mtx,
        title=title,
        labels=dict(x="Head", y="Layer", color=color_map_label),
        color_continuous_scale="RdBu",
        range_color=(-max_val, max_val),
        x=x_labels,
        y=[str(i) for i in range(mtx.shape[0])],
        aspect="equal",
    )
    fig.show()
if MAIN:
    # Average over prompts (last dimension), relative to the unpatched logit difference.
    variation_ld_flipped_IO = results_IO.mean(dim=-1) - ref_ld
    show_mtx(
        variation_ld_flipped_IO.cpu(),
        title="Logit diff variation (flipped IO)",
    )
# %%
"""
Question: How do you interpret this plot? (you can try to answer by dividing two parts: "Observation" and "Interpretation"). What can you conclude about the claim just by looking at this plot?
Reminder: we're trying to investigate the claim "Attention heads directly influencing the logits are either not influencing IO and S logits, or are increasing the IO logits more than the S."
<details>
<summary>Click here to see the answer</summary>
### Observation
First, we observe that most of the heads in layers earlier than 7 don't influence the logit diff (the difference in logit diff is < 1% of `ref_ld`).
We can identify heads that are directly influencing logits. After path patching, some lead to a decrease in logit diff (e.g. 9.9, 9.6, and 10.0) while for others, we observe a higher logit diff after patching them (e.g. 10.7 and 11.10).
### Interpretation
For the heads causing a decrease: they are run on an input with an unrelated IO (let's call it A), so they push to increase logit A - logit S instead of logit IO - logit S. The intervention leads to a decrease in total logit diff.
For the heads causing an increase in logit diff: when they are run on unrelated input, the final logit diff is higher. This means that on a fixed sentence for a random A, their contribution to logit A - logit S is greater than logit IO - logit S. This suggests that in normal conditions, they are pushing against the IO logit.
However, we cannot conclude the claim yet. We only compared the contribution to logit IO with a logit A, for a random name. We need to intervene the S logit to compare the relative contribution to IO _and_ S.
</details>
"""
# %%
"""
### Exercise!
We used `flipped_IO_dataset = ioi_dataset.gen_flipped_prompts("IO")` to generate the dataset to patch from.
Question: What would happen if we use the `flipped_S_dataset` defined by `flipped_S_dataset = ioi_dataset.gen_flipped_prompts("S")`? (i.e. the dataset is still from the IOI family but both occurrences of S are replaced by the same random name)
(It's not an easy exercise, it's unlikely you will be able to predict the result. However, it's a good way to practice thinking ahead about what would surprise you before running an experiment.)
Question: Run the experiment and show the results on the `flipped_S_dataset`. How to conclude the claim given these new pieces of evidence?
Hint: We introduce notation to help you formalize the intuition developed in the previous explanation and apply it to this new case.
We call $H(T, x)$ the "contribution" of head $H$ to the logit of token $T$ when run on $x$. This can be seen as the projection along the unembedding vector of $T$ if we neglect the role of the layer norm.
We call $x$ the baseline input and $z$ the patched input. If a head H is pushing to increase the logit diff, then we have $H(IO_x, x)$ >> $H(S_x, x)$.
"""
# %%
variation_ld_flipped_S: t.Tensor
if MAIN:
    # Same experiment as with flipped IO, this time patching from prompts where S is replaced.
    # The result is averaged over prompts and taken relative to the unpatched logit difference.
    variation_ld_flipped_S = iterative_path_patching(
        circuit=ld_circuit,
        hypothesis=corr,
        nodes_to_connect=[i_root],
        baseline_data=ioi_dataset.prompts_toks.to(DEVICE),
        patch_data=flipped_S_dataset.prompts_toks.to(DEVICE),
        group=group,
        matcher_extenders=matcher_extenders,
        input_name="tokens",
        output_shape=(12, 13, -1),
    ).mean(dim=-1) - ref_ld
if MAIN:
    # The `if "SOLUTION"` / `else` structure is kept on purpose (presumably read by the tooling
    # that generates the exercise version). At runtime the string is truthy, so only the first
    # branch runs.
    if "SOLUTION":
        show_mtx(variation_ld_flipped_S.cpu(), title="Logit diff variation (flipped S)")
    else:
        # Exercise version: set the condition to False to think about the expected plot first.
        if True:
            show_mtx(variation_ld_flipped_S.cpu(), title="Logit diff variation (flipped S)")
        else:
            print("Think about what you expect before showing this plot!")
# %%
r"""
**Disclaimer: I tried to make the descriptions as precise as possible. If at some point you feel they are overly detailed because you already understood the intuition, feel free to skip them.**
<details>
<summary>Answer</summary>
We observe a really different plot!
### Observation
First, the effect size is about half as big as the previous one. Most heads lead to a positive variation in logit diff after patching. Moreover, there are new heads that appear that were not present in the previous experiment. The plot is less sparse than the first one.
The heads appearing in the first plot have an effect of the opposite sign in the second plot. Their absolute effect size is smaller in the second plot, ranging from about 50% smaller (e.g. 10.0) to 7x smaller (e.g. 9.9).
### Interpretation
Let $H$ be one of the heads that leads to a strong positive variation in logit diff after patching H->logits.
The logit diff is measuring $H(IO_x, z)-H(S_x, z)$.
* $S_x$ is a random name from the point of view of $H$ run on $z$.
* $IO_x$ is the same in both $x$ and $z$. And $x$ and $z$ are similar up to the value of $S$, so we expect $H(IO_x, z) \simeq H(IO_x, x)$
* H leads to a positive variation in logit diff. We isolated the direct path H->logits, so we can consider that the variation in the global logit diff is in fact the variation in the contribution of H to the logit diff. We thus have $H(IO_x, x)-H(S_x, x)$ << $H(IO_x, z)-H(S_x, z)$ (pre patching << post patching)
* Finally, we have $H(S_x, x)$ >> $H(S_x, z)$
The heads we observe in the plot are the heads writing in the direction of $S$. They are responsible for the probability of $S$ being much higher than that of a random name.
### Conclusion
The same argument can be made to formalize the first plot. The only difference is that the effect is reversed (due to the negative sign in the logit diff). The first plot shows heads writing IO more than a random name.
The heads that appear in the two plots with a flipped sign are writing both S and IO more than random names.
By comparing the two plots, we can conclude the relative importance between $H(IO_x, x)$ and $H(S_x, x)$ (by looking at a single plot, the only reference was a random name). The effect size is much higher for the first plot: there is a sparse set of heads (10.0, 9.6, 9.9) specialized in identifying IO and boosting it much more than a random name. In the second plot, the large set of heads seems to implement a mechanism like "push tokens from the context", with a smaller effect size than in the case of IO.
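To find the heads that have higher $H(IO_x, x)$ than $H(S_x, x)$, we can visualize the sum of the two plots (no need to take the difference). Writing $R$ for a random name, the flipped-IO variation of a head is approximately $H(R, x) - H(IO_x, x)$, and the flipped-S variation is approximately $H(S_x, x) - H(R, x)$. Their sum is therefore approximately $H(S_x, x) - H(IO_x, x)$: the random-name baseline cancels out. Positive values are heads that push S more than IO. Negative values are heads that push IO more than S.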
### Back to the claim
Plotting this is the most accurate way to answer the claim, as we control for the baseline effect of "writing all names in context more than a random name". In practice, the plot we obtain is really close to the first plot. Still, we can see for instance that 10.0 pushes IO over S by less than the first plot suggests.
To conclude, we can clearly observe attention heads significantly contributing to the logit, and pushing for S more than IO (e.g. 10.7 and 11.10). The claim is thus **false**.
</details>
"""
# %%
if MAIN:
    # Per-head sum of the two path-patching sweeps (flipped S + flipped IO).
    # As argued in the answer above: positive entries are heads pushing S more than IO,
    # negative entries are heads pushing IO more than S.
    total = variation_ld_flipped_S + variation_ld_flipped_IO  # type: ignore
    show_mtx(
        total.cpu(),  # moved to CPU before plotting (presumably required by show_mtx)
        title="Sum of the two logit diff variation (flipped IO + flipped S)",
    )
# %%
r"""
That was quite a convoluted explanation! Can you think of a way to show the same thing with only one iterative path patching experiment, while still using inputs from the IOI family?
<details>
<summary>Answer</summary>
If we consider the final logit diff to be the sum of the head contribution, we have:
$\text{logit diff} = \sum_H \left[H(IO_x, x) - H(S_x, x)\right]$.
To isolate the term of the sum we're interested in, we can patch a head $H$ with an input $z$ that has a different S _and_ a different IO. The difference in logit diff before and after patching is then $[H(IO_x, x) - H(S_x, x)] - [H(IO_x, z) - H(S_x, z)]$. Both $H(S_x, z)$ and $H(IO_x, z)$ are about the same value (the average logit for names unrelated to the sentence $z$), so they cancel out. We're left with $H(IO_x, x) - H(S_x, x)$, which is what we care about.
If you run this experiment, you can see that the results are close to those of the previous cell (the sum of the first two experiments). In fact, by computing `variation_ld_flipped_S+variation_ld_flipped_IO-variation_ld_flipped_S_IO` you can evaluate the error of our simple model, which decomposes the logit diff into a sum of head contributions: if the model were perfect, the error would be zero.
</details>
Exercise: write an experiment that addresses the claim using a single run of `iterative_path_patching`
"""
# %%
# `if "SOLUTION":` is always true at runtime; presumably a marker used to strip the
# solution from the exercise version of this notebook.
if "SOLUTION":
    if MAIN:
        # Single sweep answering the claim: patch with inputs where BOTH S and IO are
        # replaced, so each head's patched contribution only involves names unrelated to z.
        # Result: (mean patched logit diff) - ref_ld, one entry per (layer, head/MLP) cell.
        variation_ld_flipped_S_IO = (
            iterative_path_patching(
                circuit=ld_circuit,
                hypothesis=corr,
                nodes_to_connect=[i_root],  # connect patched nodes to the root only (the direct path to the logits, presumably)
                baseline_data=ioi_dataset.prompts_toks.to(DEVICE),
                patch_data=flipped_IO_S_dataset.prompts_toks.to(DEVICE),
                group=group,
                matcher_extenders=matcher_extenders,
                input_name="tokens",
                output_shape=(12, 13, -1),  # assumed (layer, 12 heads + MLP, samples) — TODO confirm
            ).mean(dim=-1)  # average over the last axis (presumably the samples)
            - ref_ld
        )
# %%
"Plot the difference between your experiment and the previous best guess we had (variation_ld_flipped_S+variation_ld_flipped_IO)"
# `if "SOLUTION":` is always true at runtime; presumably a solution-stripping marker.
if "SOLUTION":
    if MAIN:
        # Residual of the additive "logit diff = sum of head contributions" model:
        # if that model were exact, patching S and IO together would equal the sum of
        # the two single patches and this matrix would be all zeros.
        show_mtx(
            (variation_ld_flipped_S_IO - (variation_ld_flipped_S + variation_ld_flipped_IO)).cpu(),  # type: ignore
            title="Residual error from linear model (variation_ld_flipped_S_IO-(variation_ld_flipped_S+variation_ld_flipped_IO) ).",
        )
# %%
"""
### Takeaway from the previous experiments
We went through a convoluted path to answer the claim, and we thoroughly detailed every experimental result. It is meant to simulate a realistic research thought process:
* Start with an experiment (the flipped IO)
* Realize that it did not show what you expected it to show
* Think about a way to make additional experiments to show your point (the flipped S)
* Realize that you could have done things more directly (the flipped S_IO)
In practice, you're of course encouraged to think about the more direct experiment first. Moreover, you might want to detail your thoughts less than we did here; otherwise, you risk being able to address only a few of the claims. This section shows a standard of "if you think carefully, this is how far you can get by interpreting the results of a single experiment".
### Checkpoint!
Read section 3.1 of the [IOI paper](https://arxiv.org/pdf/2211.00593.pdf). (Don't read section 3.2 to avoid spoilers)
Notice how the discovery of name movers is different (here we did not use the ABC dataset). It should give you more context on the heads you just identified.
## Advanced Tooling
We could stop here. However, to demonstrate more tools, we'll push the investigation of these newly identified heads further. We'll demonstrate:
* Getting activations of the heads
* Visualizing attention patterns with CUI
* Simple causal scrubbing experiments
* Example of moving pieces experiments
### Helper functions
Before delving into more advanced tools, we'll need to define a few helper functions. We will make our hypothesis grow step by step by adding new nodes that connect to previously discovered nodes. For instance here, we'd like to connect the name movers to the logit nodes.
`extend_corr` enables us to handle hypotheses more easily. It adds an `InterpNode` connected to a node and creates a new matcher by applying a matcher extender.
"""
# %%
def extend_corr(
    corr: Correspondence,
    from_name: str,
    to_name: str,
    matcher_extender: Callable[[rc.IterativeMatcher], rc.IterativeMatcher],
    cond_sampler: CondSampler,
    other_inputs_sampler: CondSampler = UncondSampler(),
):
    """Grow the hypothesis `corr` by one node, in place.

    A new interpretation node called `to_name` is created as a descendant of the
    existing node called `from_name`, using the given samplers. Its matcher is the
    parent's matcher passed through `matcher_extender`. Both are registered in `corr`.
    Returns None.
    """
    parent = corr.get_by_name(from_name)
    # Read the parent's matcher before creating the child, as the original ordering did.
    parent_matcher = corr.corr[parent]
    child = parent.make_descendant(
        name=to_name,
        cond_sampler=cond_sampler,
        other_inputs_sampler=other_inputs_sampler,
    )
    corr.add(child, matcher_extender(parent_matcher))
# %%
"""
To specify its counterpart in the computation graph, we will define a matcher extender that matches the new head through a direct path.
This matcher extender is slightly different from the one we described above. We will reuse part of the code from `extender_factory`, but instead of reaching a single target (the `MLPHeadAndPosSpec` object), we want to reach a group of nodes given by their names.
Exercise: inspired by the first matcher extender (see the definition of `extender_factory`), write the body of `add_path_to_group`.
Be careful: here we want to match a set of nodes given by their names.
<details>
<summary>Hint</summary>
You need to use the `term_if_matches` parameter of `restrict`. If we want to reach a set of heads `{H1, H2}` that are not at the same layer, sometimes we want to avoid paths of the form `logits -> heads H1 -> heads H2`. To do that, we'll use the flag `term_if_matches`: once a node is matched, we stop the traversal.
This restriction was unnecessary for the `extender_factory` we defined earlier: since we were targeting a node whose name appears only once in the circuit, there was no risk of composition.
</details>
"""
# TBD adding exercises here?
def add_path_to_group(
    m: rc.IterativeMatcher,
    nodes_group: List[str],
    term_if_matches=True,
    qkv: Optional[str] = None,
):
    """Chain onto matcher `m` a path that ends at any of the nodes named in `nodes_group`.

    The extension is made of three chained pieces:
    1. an entry step: `rc.new_traversal(start_depth=0, end_depth=1)` when `qkv` is None,
       otherwise `rc.restrict("a.<qkv>", term_if_matches=True, end_depth=8)`;
    2. `rc.new_traversal(start_depth=1, end_depth=2)`;
    3. a restriction matching the names in `nodes_group`, terminating early at every
       other name of `ALL_NODES_NAMES` and, when `term_if_matches` is True, stopping
       the traversal at the matched nodes.

    If `term_if_matches=False` and `qkv` is not None, the `qkv` restriction only applies
    to the path up to the first target nodes found from `m`; indirect effects through
    them are not restricted by `qkv`.

    Raises AssertionError when `qkv` is not one of "q", "k", "v" or None.
    """
    assert qkv in ["q", "k", "v", None]
    # Every known node that is not a target: reaching one of them ends the search early.
    banned_names = ALL_NODES_NAMES.difference(set(nodes_group))
    attn_block_input = (
        rc.new_traversal(start_depth=0, end_depth=1)
        if qkv is None
        else rc.restrict(f"a.{qkv}", term_if_matches=True, end_depth=8)
    )
    # `if "SOLUTION":` is always true; presumably a marker for the exercise version.
    if "SOLUTION":
        reach_group = rc.restrict(
            rc.Matcher(*nodes_group),
            term_early_at=rc.Matcher(banned_names),
            term_if_matches=term_if_matches,
        )
        step_then_group = rc.new_traversal(start_depth=1, end_depth=2).chain(reach_group)
        return m.chain(attn_block_input).chain(step_then_group)
    else:
        raise NotImplementedError("You need to implement this function")
def extend_matcher(
    match_nodes: List[str],
    term_if_matches=True,
    restrict=True,
    qkv: Optional[str] = None,
):
    """Return a matcher extender (IterativeMatcher -> IterativeMatcher) targeting `match_nodes`.

    When `restrict` is truthy, the returned function delegates to `add_path_to_group`
    (banning non-target nodes and honouring `qkv`). Otherwise it simply chains a
    `rc.restrict` on the target names, and `qkv` is ignored.
    """

    def match_func(m: rc.IterativeMatcher):
        if not restrict:
            # Unrestricted variant: any path to the targets, no qkv filter, no banned nodes.
            return m.chain(rc.restrict(rc.Matcher(*match_nodes), term_if_matches=term_if_matches))
        return add_path_to_group(m, match_nodes, term_if_matches=term_if_matches, qkv=qkv)

    return match_func
"""
### Matcher debugging: show path by groups of nodes
Before putting those helper functions in action, we'll create a debugger tool that can show succinctly the path our matcher is taking. Let's begin by giving a name to the nodes we just discovered. For consistency, we'll stick to the old names (name movers). You'll not have the chance to practice creative naming this time!
#### Debugger info
`print_all_heads_paths` prints the paths matched by an `IterativeMatcher`. Because paths are often long and involve nodes we're not interested in (e.g. layer norms), it applies several filters:
* Only show nodes that are part of `ALL_NODES_NAMES` (plus the `logits` node)
* If show_qkv is `True`, also show qkv nodes
* If a node name is a key of the dict `short_names`, print the value in this dict instead
* Never print the same string twice (within the paths to a given target)
Exercise: write the body of the function `keep_nodes_on_path` to filter the nodes we want to keep. You can run the next three cells to debug your implementation.
We'll see how it works in action in a second.
"""
# Names "a<layer>.<q|k|v>" for the 12 layers: all q names first, then all k, then all v.
qkv_names = [f"a{layer}.{part}" for part in ("q", "k", "v") for layer in range(12)]
def keep_nodes_on_path(path: list[rc.Circuit], nodes_to_keep: set[str]) -> list[str]:
    """Return the names of the nodes of `path` whose name is in `nodes_to_keep`.

    The names come out in the same order as their nodes appear along `path`.
    """
    filtered_path = []
    # `if "SOLUTION":` is always true; presumably a marker for the exercise version.
    if "SOLUTION":
        # Drop every node the caller did not whitelist (layer norms, etc.).
        filtered_path = [node.name for node in path if node.name in nodes_to_keep]
    return filtered_path
def print_all_heads_paths(
    matcher: rc.IterativeMatcher,
    show_qkv=False,
    short_names: Union[Dict[str, str], None] = None,
):
    """Print the paths matched by `matcher` in the module-level `circuit`, grouped by target.

    Each path is reduced to the names of its nodes that belong to `ALL_NODES_NAMES`,
    plus "logits" (and "a.q"/"a.k"/"a.v" when `show_qkv` is True). When `short_names`
    is given, each kept name is replaced by its entry in that dict if it has one. The
    names are joined with "->". Within one target, a rendered path string is printed
    only the first time it appears, tagged with the index of the path that produced it.
    """
    nodes_to_include = set(ALL_NODES_NAMES)
    if show_qkv:
        nodes_to_include.update(["a.q", "a.k", "a.v"])
    nodes_to_include.add("logits")

    for target, paths in matcher.get_all_paths(circuit).items():
        print()
        print(f"--- paths to {target.name} ---")
        seen_renderings = set()  # reset per target: duplicates are only skipped within one target
        for path_idx, path in enumerate(paths):
            kept_names = keep_nodes_on_path(path, nodes_to_include)
            if short_names is not None:
                kept_names = [short_names.get(name, name) for name in kept_names]
            rendered = "->".join(kept_names)
            if rendered in seen_renderings:
                continue
            seen_renderings.add(rendered)
            print(f"Path {path_idx} : {rendered}")
# %%
"""
Let's add names to our nodes! We'll call them "POS_NM" and "NEG_NM" for Positive / Negative Name Movers (short names are easier to print along paths).
We will update this list of heads each time we find new nodes to add. Several variables store information about the found nodes, and it's useful to keep them all in sync, so we create a function `add_node_to_pokedex` to handle this bookkeeping for us.
"""
# %%
# Circuit node name (spec.to_name("")) -> short class name such as "POS_NM".
short_names: dict[str, str] = {}
# Class name -> node names / node specs of that class (two parallel lists).
grouped_nodes_name: dict[str, list[str]] = {}
grouped_nodes_spec: dict[str, list[MLPHeadAndPosSpec]] = {}


def add_node_to_pokedex(nodes: list[Tuple[MLPHeadAndPosSpec, str]]):
    """Register `(spec, class_name)` pairs in the module-level bookkeeping dicts.

    For each pair:
    * `short_names[spec.to_name("")]` is set to `class_name`, unless that node already
      has a short name (the first name given to a node is kept);
    * the node's name and spec are appended to `grouped_nodes_name[class_name]` and
      `grouped_nodes_spec[class_name]`, unless an equal spec (by `==`) is already
      registered in that class, so re-running a registration cell adds no duplicates.

    The three dicts are mutated in place (no rebinding), so no `global` is needed.
    """
    for node, name in nodes:
        node_name = node.to_name("")
        # `short_names` is keyed by node-name strings: test the name, not the spec
        # object (testing the spec was always true, so names were silently overwritten).
        if node_name not in short_names:
            short_names[node_name] = name
        names_in_class = grouped_nodes_name.setdefault(name, [])
        specs_in_class = grouped_nodes_spec.setdefault(name, [])
        if node in specs_in_class:
            # Already registered in this class (e.g. the cell ran twice): keep lists duplicate-free.
            continue
        names_in_class.append(node_name)
        specs_in_class.append(node)
# %%
"""
To avoid duplicate entries in the global variable above, don't run cells like these twice!
"""
# Positive name movers found above (10.0, 9.6, 9.9), all taken at the END position.
add_node_to_pokedex(
    [
        (MLPHeadAndPosSpec(layer, head, END_POS), "POS_NM")
        for layer, head in ((10, 0), (9, 6), (9, 9))
    ]
)