# %%
"""
# REMIX Day 1 - Intro to Circuits
Today you'll implement a simplified version of Redwood's library Circuits and understand the key differences between our approach and other methods based on PyTorch hooks.
You'll use your library to investigate the behavior of two simple neural networks trained on the classic MNIST dataset.
Both networks are identical in architecture, train to approximately human-level accuracy in a couple of seconds, and generalize well to the test set.
Should we expect their internal machinery to implement basically equivalent algorithms with minor differences due to random initialization, or is it possible that they both answer the same examples correctly for different "reasons"?
<!-- toc -->
## Learning Objectives
After today's material, you should be able to:
- Explain the differences between hooks and Circuits
- Write computational graphs using the various `Circuit` subclasses
- Match nodes and paths using various criteria
- Explain the concept and uses of graph rewrites
## Setup
Make sure you have the following installed:
`conda run -n remix pip install attrs einops jupyter ipywidgets ipykernel scikit-learn requests pandas fancy_einsum seaborn tqdm websockets get-mnist transformers tabulate plotly black`
In VS Code, we recommend you have word wrap enabled (View -> Word Wrap).
We also recommend you leave the Explorer pane on the left closed to maximize screen real estate, and open files as needed with Command-P or Ctrl-P. Finally, set your font size to one appropriate for your monitor: Command-Shift-P (or similar) opens the Command Palette, then type "font" to see available commands.
## A Note on Test Coverage
We have provided some unit tests to help you get through the material faster and reduce the time you spend stuck on bugs. These tests are *not* intended to be exhaustive, and passing all of them does not definitively prove that your code is correct.
When you are doing research, you'll have to write your own unit tests, so it's good practice to write additional test cases as needed. In particular, when you encounter issues, consider whether earlier code could be broken in ways that the earlier tests didn't detect.
I also recommend using assertions liberally throughout the exercises. I frequently assert that shapes and types are as expected, and that preconditions and postconditions are maintained.
## Getting Started
Copy the following code block into a new file, and ensure you're on an appropriate branch like `remix_d1/chris-and-max` (see the README if you don't know how to do this).
"""
# %%
# TBD lowpri: ensure that torch.Size is gone and we just have regular tuples. (they show up in print). Probably we want to use rc.Shape as an alias throughout.
# TBD lowpri: use children instead of inputs for consistency with real Circuits and note they're synonymous.
from __future__ import annotations
import os
import sys
import inspect
import itertools
import re
import string
from abc import ABC, abstractmethod
from functools import partial
from typing import (
Any,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
import attrs
import pandas as pd
import seaborn as sns
import torch
import torch as t
import torch.nn as nn
import torch.nn.functional
import fancy_einsum
from attrs import evolve
from einops import rearrange, repeat
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset
from tqdm.notebook import tqdm
import remix_utils
pd.set_option("display.precision", 3)
MAIN = __name__ == "__main__"
if "SKIP":
IS_CI = os.getenv("IS_CI")
if IS_CI:
sys.exit(0)
# %%
"""
## Model Definition
We'll be starting with a very simple architecture, with about 600K parameters. Read the model definition and make sure you understand the flow of data - making your own diagram may be helpful.
Each input is a grayscale image 28 pixels wide and 28 pixels high, and is represented as a flat vector of 28*28 floats.
The output of the model is 10 logits, one for each possible digit.
Exercise: what is `model_a.last.weight.shape`? What is the equation for the computation that happens when you run `model_a.first(x)`? If you don't know offhand, review [the official docs on Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).
<details>
<summary>Solution</summary>
The weight of a Linear is `(out_features, in_features)`, which may be the reverse of what you expect.
The equation is $y = x W^T + b$, where $W$ is `weight` and $b$ is `bias`.
The method behind this madness? Each output element is a dot product between a row of the first operand and a column of the second operand.
If the Linear weight was stored as `(in_features, out_features)`, then data for a column would be scattered through memory. This trick makes the data for a column contiguous, improving speed.
</details>
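If you'd like to sanity-check this, here's a minimal sketch (it builds a fresh `nn.Linear` rather than touching the models loaded below, so you can run it right after the imports cell):

```python
lin = nn.Linear(28 * 28, 10, bias=True)
print(lin.weight.shape)  # torch.Size([10, 784]), i.e. (out_features, in_features)

x = t.randn(2, 28 * 28)
manual = x @ lin.weight.T + lin.bias  # y = x W^T + b
assert t.allclose(manual, lin(x), atol=1e-6)
```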
"""
# %%
class TwoLayerSkip(nn.Module):
"""Model with two Linear layers and a skip connection that bypasses the first linear-relu section."""
def __init__(self):
super().__init__()
self.first = nn.Linear(28 * 28, 28 * 28, bias=True)
self.last = nn.Linear(28 * 28, 10, bias=True)
def forward(self, x: t.Tensor) -> t.Tensor:
skip = x.clone()
skip += torch.nn.functional.relu(self.first(x))
return self.last(skip)
model_a = TwoLayerSkip()
model_a.load_state_dict(t.load("./remix_d2_data/model_a.pickle", map_location=torch.device("cpu")))
model_b = TwoLayerSkip()
model_b.load_state_dict(t.load("./remix_d2_data/model_b.pickle", map_location=torch.device("cpu")))
models = [("A", model_a), ("B", model_b)]
# %%
"""
## MNIST
The details of the dataset are always important in ML. In a moment, we'll use Redwood's composable UI to explore the dataset. Some basic facts you should know:
- Each image shows a handwritten digit from 0-9, centered in the frame.
- These data have already been normalized to mean 0 and standard deviation 1.
- Human-level performance on this data is around 97-99%. It's not higher because a small number of images are written too ambiguously, and some images that humans classify correctly have a mistaken label in the dataset, so the (correct) human answer is counted as a wrong prediction.
### TensorDataset
`train_dataset` and `test_dataset` are instances of the PyTorch class `TensorDataset`. Indexing into the `TensorDataset` returns a tuple of (image, label), and the `tensors` field gives access to the full tensors.
Run the code below to view some example images.
"""
train_dataset, test_dataset = remix_utils.get_mnist()
test_inputs, test_labels = test_dataset.tensors
# %%
def plot_img(img: t.Tensor):
arr = img.detach().cpu().numpy()
fig, ax = plt.subplots()
axes_img = ax.imshow(arr.reshape(28, 28), cmap="gray_r")
fig.colorbar(axes_img)
return fig
for i in range(3):
img, label = train_dataset[i]
fig = plot_img(img)
fig.suptitle(f"Ground truth label: {label.item()}")
# %%
"""
### Crash Course on the `attrs` library
We use the [`attrs`](https://www.attrs.org/en/stable/index.html#) library for many of the classes in Circuits. `attrs` is similar to the standard library `dataclasses` but has more features, which we'll introduce as they become important.
For today, you'll need to know that when we use `@attrs.define`, various things including a constructor, a string representation, and an equality operator are automatically generated. Equality is defined as all the corresponding fields being equal.
When we use `frozen=True`, trying to assign to the object's fields will raise an exception. One benefit of using frozen objects is that multiple places can hold a reference to the same frozen object without worrying about changes in one place affecting the other.
Exercise: try assigning to `a.x`.
- Verify that your IDE's static type checking shows this as an error with a helpful message such as `Cannot assign member "x" for type "Point3d"`.
- Verify that when you run the assignment, `FrozenInstanceError` is raised at runtime.
<details>
<summary>Help - the VSCode checker isn't seeing my assignment as an error!</summary>
In the command palette, choose "Preference: Open Workspace Settings (JSON)" and ensure that you have an entry like "python.analysis.typeCheckingMode": "basic".
If this still doesn't work, call a TA as this day will be significantly harder without being able to rely on the type checker.
</details>
"""
# %%
@attrs.define(frozen=True)
class Point3d:
x: float
y: float
z: float
pa = Point3d(1.0, 2.0, -3.0)
pb = Point3d(1.0, 2.0, -3.0)
pc = Point3d(1.0, 2.0, -4.0)
print("Two objects defined via attrs are equal iff all their fields are equal:")
print("a == b", pa == pb)
print("a == c", pa == pc)
# %%
"""
### attrs.evolve
If you want to "modify" a frozen object, what you'll actually do is create a new object with new values for some of the fields.
This isn't bad with only 3 fields in our example, but it's cumbersome when you have many fields, so `attrs` provides a method called `evolve` that lets you specify only the changed fields and copies the rest.
"""
# %%
a_flipped = Point3d(-pa.x, pa.y, pa.z)
a_flipped_2 = evolve(pa, x=-pa.x)
print("Two ways to do the same thing: ", a_flipped, a_flipped_2)
# %%
"""
### Really modifying (mutating) frozen objects
A statement like `a.x = 5.0` actually calls `Point3d.__setattr__(a, 'x', 5.0)`, which has been implemented by `attrs` so that it raises an exception.
Later, we'll legitimately need to modify frozen objects (to initialize them properly), and we can do so by calling `object.__setattr__(a, 'x', 5.0)` instead.
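For example (a quick sketch using the `Point3d` instance from the previous cell):

```python
object.__setattr__(pa, "x", 5.0)  # bypasses the attrs-generated __setattr__, so this really mutates pa
print(pa)  # Point3d(x=5.0, y=2.0, z=-3.0)
```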
## Data Visualization
The next thing we'll want to do is be able to run our model and inspect its performance.
Quick note on terminology: "performance" can be ambiguous and refer to:
- The model's competence for some (loss function, dataset) pair
- The number of floating point operations required to run the model
- The wall clock time needed to run the model
For the model's competence, it's best to be verbose and spell out "accuracy on the test set" or similar, as the exact details do matter for our experiments.
The exact details of FLOPS vs wall time are much less important, so I'll loosely say "speed" to refer to them.
Since the data is small, we'll run the entire test set in one large batch of size 10,000 for simplicity. We're running on CPU by default, just to avoid having to move tensors between devices.
This code is provided for you since it's very routine stuff.
"""
# %%
@attrs.define(frozen=True, eq=False)
class TestResult:
loss: float
acc: float
logits: t.Tensor
incorrect_indexes: t.Tensor
results: dict[tuple[str, str], TestResult] = {}
def show_results():
rows = []
for (exp_name, model), result in results.items():
rows.append((exp_name, model, result.loss, result.acc))
return pd.DataFrame.from_records(
rows,
columns=["experiment", "model", "test_loss", "test_acc"],
index=("experiment", "model"),
)
def test(model: nn.Module, dataset: TensorDataset, device: Union[str, t.device] = "cpu") -> TestResult:
model.eval()
model.to(device)
all_incorrect_indexes = []
data, target = dataset.tensors
data, target = data.to(device), target.to(device)
with torch.inference_mode():
logits = model(data)
test_loss = torch.nn.functional.cross_entropy(logits, target, label_smoothing=0.1).item()
pred = logits.argmax(dim=1) # (n, )
correct = pred == target # (n, )
all_incorrect_indexes.append((~correct).nonzero().flatten())
return TestResult(
loss=test_loss,
acc=correct.sum().item() / len(dataset),
logits=logits,
incorrect_indexes=t.cat(all_incorrect_indexes),
)
for name, model in models:
results[("baseline", name)] = result = test(model, test_dataset)
show_results()
# %%
"""
### Confusion Matrix
Let's also look at the confusion matrix.
Exercise: Which digits are misclassified most often?
"""
# %%
from sklearn.metrics import confusion_matrix
logits_a = results[("baseline", "A")].logits
logits_b = results[("baseline", "B")].logits
fig, axes = plt.subplots(figsize=(5, 10), nrows=2)
for (name, model), ax, logits in zip(models, axes, [logits_a, logits_b]):
mat = confusion_matrix(test_labels, logits.argmax(dim=-1))
for i in range(10):
mat[i, i] = 0
sns.heatmap(mat, ax=ax, vmax=18)
ax.set(xlabel=f"Model {name} Prediction", ylabel="True class")
# %%
"""
## Intro to Composable UI
The composable UI (CUI) is Redwood's internal tool for visualizing multi-dimensional tensors in a web frontend.
It allows us to visualize one or more tensors with meaningful labels on each dimension and position.
You can switch between different charts without having to write additional code, and you can "link" together different charts.
Code-wise, CUI has two parts: a frontend written in React that runs on a web server, and a backend which runs in your Python kernel.
When you open your browser to the frontend web page, your browser will try to communicate with the backend over a specific port, which defaults to 6789. If you're running over SSH, you'll also need to forward the port from the remote machine to your local machine.
In VSCode, there's a tab "Ports" which you can click, then "Add Port" -> 6789. VSCode should remember this for future sessions. (The "Ports" tab will only be present when you are connected to the remote host.)
You don't have to understand anything about the `await` keyword here, and you can ignore the type-checker's complaint - it doesn't know that we're running this in a "notebook".
There's definitely a learning curve to this tool. Some things you can try to start:
- Change the "example" dimension from "axis" to "facet". This will render a separate plot for each example image.
- The default chart type is set up for text; change the chart type to "Tiny matrix". This will allow you to see the images in a natural form. Which "mistakes" of the model do you think are not really mistakes? Can you find an instance where the preprocessing didn't work correctly?
- Using "Sided matrix" instead of "Tiny matrix" will display axes labels and will also let you see the values of the tensor's entries on hover.
- Click "Add new plot", then go to the bottom and set Plot 1 to "logits" to view that tensor.
- In Plot 1, set the chart type to "Sided matrix" and set the model dimension to "facet". For the same example, in what ways are the logits different across the two models?
Note that the selections you've made are saved in the URL. If the tool crashes or you want to save your selections for later, you can copy the URL.
"""
# %%
from rust_circuit.ui import cui
from rust_circuit.ui.very_named_tensor import VeryNamedTensor
# TBD lowpri: it feels weird that there are some empty feature maps. Is this a bug in the visualizer or is it really like this?
# TBD lowpri: idk how stable the below is, but we'll get people to just skip over
remix_utils.await_without_await(lambda: cui.init(port=6789))
idx = (
(
((logits_a.argmax(-1) != 4) | (logits_b.argmax(-1) != 4))
& ((logits_a.argmax(-1) == 9) | (logits_b.argmax(-1) == 9))
& (test_labels == 4)
)
.nonzero()
.flatten()
)
vnt_incorrect = VeryNamedTensor(
test_inputs[idx].reshape(-1, 28, 28),
dim_names="example height width".split(),
dim_types="example seq seq".split(),
dim_idx_names=[idx.tolist(), range(28), range(28)],
title="4s misclassified as 9s by at least one model",
)
vnt_logits = VeryNamedTensor(
t.stack([logits_a[idx], logits_b[idx]], dim=0),
dim_names="model example pred".split(),
dim_types="model example seq".split(),
dim_idx_names=[["A", "B"], idx.tolist(), range(10)],
title="logits",
)
remix_utils.await_without_await(lambda: cui.show_tensors(vnt_incorrect, vnt_logits))
# %%
"""
## Visualizing Weights
Exercise: create a `VeryNamedTensor` for `first_weights` and `last_weights`, then take up to 5 minutes to visualize them in the CUI. Can you discover anything interesting?
<details>
<summary>Does it matter what dim_type I use?</summary>
You can use "seq" as a `dim_type` to let CUI know that this axes is sequential (ordered integers). This is used to recommend appropriate visualizations. For the model dimension, they aren't really ordered so you can use any string you want and nothing special will happen.
</details>
<details>
<summary>Spoiler - things to see</summary>
Looking at the last layer weights with Tiny Matrix (or Sided Matrix) and the model and logit dimensions set to facet, you can clearly see digit patterns in Model B but not in Model A.
Looking at the first layer weights with Tiny Matrix and the model and output dimensions set to facet, Model B's first layer is much more sparse. Many of the feature maps are empty or nearly empty, meaning the first layer doesn't write anything to that output pixel.
</details>
"""
# %%
first_weights = t.stack([model.first.weight.detach().reshape(784, 28, 28) for _, model in models], dim=0)
last_weights = t.stack([model.last.weight.detach().reshape(10, 28, 28) for _, model in models], dim=0)
vnt_first: VeryNamedTensor
vnt_last: VeryNamedTensor
if "SOLUTION":
vnt_first = VeryNamedTensor(
first_weights,
dim_names="model out inheight inwidth".split(),
dim_types="model seq seq seq".split(),
dim_idx_names=[["A", "B"], range(28 * 28), range(28), range(28)],
title="W_first",
)
vnt_last = VeryNamedTensor(
last_weights,
dim_names="model pred inheight inwidth".split(),
dim_types="model seq seq seq".split(),
dim_idx_names=[["A", "B"], range(10), range(28), range(28)],
title="W_last",
)
remix_utils.await_without_await(lambda: cui.show_tensors(vnt_first, vnt_last))
# %%
"""
Optional: Use the Composable UI to further investigate the examples that the models get wrong.
Show:
- All misclassified test set examples
- Can we include the wrong logits beside this?
- A random sample grouped by digit (digit, height, width)
- Facet on digit
- Check out what the mean digit looks like
- Can we see the variance of each pixel? This is interesting because there are some zero variance pixels.
See [here](https://www.youtube.com/watch?v=zH8YBqdIB-w) for Max's video demo of the Composable UI.
"""
# %%
"""
## Reading Intermediates
One of the most common tasks we want to do in interpretability is examine and modify intermediate `Tensor`s computed during the forward pass.
Ideally, our tooling would allow us to pick any `Tensor` in the computation we care about and operate on it generically.
Specifically, in this architecture's forward pass there are three intermediate `Tensors` computed by PyTorch.
Exercise: what are they?
<details>
<summary>Solution - intermediates</summary>
The intermediates accessible from Python are:
1) Output of `first` (same as the input of `relu`)
2) Output of `relu`
3) Output of the skip addition, i.e. the `relu` output plus the input (same as the input of `last`)
</details>
If you wrote out the algebraic equations for the calculation in terms of matrices, there are two more terms we haven't considered yet. In the actual PyTorch implementation, these terms don't exist as independent objects - I'll call this sort of tensor a "hidden temporary" or just "temporary".
Exercise: what are the two hidden temporaries?
<details>
<summary>Solution - temporaries</summary>
Each `Linear` layer does a matrix-multiply by the weight, which the docs refer to as `xA^T`.
This temporary is added to the bias to produce the output of the Linear. The actual `Linear` implementation does the matmul and the add in one atomic step, without ever creating a `torch.Tensor` with the contents `xA^T`.
</details>
### Manual Editing
One thing you could do is just duplicate and edit the model definition directly whenever you want to do an experiment. This is a thing that people do, and has the advantage of being very explicit about what's happening and easy to understand. Unfortunately, it's quite difficult to maintain many copies of the model code and the edits can be error prone. We don't recommend this approach.
### Hooks
PyTorch provides a cleaner way to inspect and modify your model that doesn't require editing your model definition every time, called hooks [(official docs)](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=nn%20module#torch.nn.Module.register_forward_hook).
This brief introduction to hooks is mainly beneficial because I expect you'll have to read other people's code that uses them; our claim is that our library for circuits (called Circuits from here on) can do everything hooks can do and more.
A hook is a user-defined function registered to a specific `nn.Module` by calling the `register_forward_hook(hook)` method on that module. Later, when the `forward` method of that module is run, the hook is called with three arguments: the module, a tuple of positional inputs, and its output. Your user-defined function can then do whatever it wants.
To allow you to disable the hook, `register_forward_hook` returns a handle object; calling `handle.remove()` unregisters the hook from the module. The benefit is that you can run various experiments against one model definition.
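Here's a minimal sketch of the mechanics (the hook name `shape_hook` is just for illustration, not something used later):

```python
def shape_hook(module: nn.Module, input: tuple[t.Tensor, ...], output: t.Tensor) -> None:
    print(f"{type(module).__name__} produced output of shape {tuple(output.shape)}")

handle = model_a.first.register_forward_hook(shape_hook)
model_a(test_inputs[:1])  # the hook fires during this forward pass
handle.remove()           # unregister so later experiments aren't affected
```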
Exercise: Implement `plot_img_hook` and register it on the `last` layer, so that it calls `plot_img` on the input to `last` when called. Run both models on an input example such as `img`. What do you see?
<details>
<summary>Help! My hook is stuck on a module and I don't have access to the handle to remove it, possibly because my code threw an exception.</summary>
Sadly, this happens all the time when using hooks. The hooks are stored in a dictionary called `_forward_hooks`, so you can clear this to manually remove them. `remix_utils.remove_all_hooks(module)` is also provided as a helper function to do this.
</details>
"""
# %%
def plot_img_hook(module: nn.Module, input: tuple[t.Tensor, ...], output: t.Tensor) -> None:
"""Plot an image, assuming the module takes the image Tensor as the single positional argument."""
"SOLUTION"
(x,) = input
fig = plot_img(x)
if "SOLUTION":
for name, model in models:
handle = model.last.register_forward_hook(plot_img_hook)
model(img)
handle.remove()
else:
"""Run the model here."""
# %%
"""
## Zero Ablation with Hooks
We could use CUI to look at a lot of activations in this way to get a better picture, but we're going to speedrun through hooks instead.
One intervention that's easy and quick to run is to replace the first layer's activation (output of the ReLU) with zero and see what happens to the loss.
If we think about a specific neuron in the first layer, it might be the case that the neuron detects some pattern such as a vertical stripe using a weighted sum of pixels. The bias of the neuron can be set to a negative value such that if the stripe isn't sufficiently present, the ReLU clamps the neuron output to zero, otherwise the neuron starts to activate (providing increasing evidence for a 1 digit). Under this theory, if we set all the neuron outputs to zero, this is like disabling the information flow through the first layer by turning all the neurons to the "feature not sufficiently present" state.
This was a hand-waving argument and we would need to do more observations and carefully justify the conditions before this would be valid. A hazard of zero ablation is that it's very easy to do it even when it isn't valid and the results don't really mean anything. For now, let's just practice using hooks and run the experiment.
There are two ways to modify the output: if the return value of a hook is not `None`, the return value will replace the output of the hooked module. I generally prefer this way, but it's also fine to modify the `output` argument in-place.
Exercise: evaluate each model on the test set again, but with the first layer activations set to zero. How is the loss affected?
<details>
<summary>I'm confused about how to do this!</summary>
A limitation of hooks is that you can only hook `nn.Module` instances. A correct but tedious option is to go through and change the model definition to use `nn.ReLU()` instead of the functional form. It's easier to register a hook on the `first` module and modify the output there to be 0. Since relu(0) = 0, this achieves the goal.
</details>
"""
# %%
def zero_ablate_hook(module: nn.Module, input: tuple[t.Tensor, ...], output: t.Tensor) -> t.Tensor:
"""Return a zero tensor with the same shape and dtype as output."""
"SOLUTION"
return t.zeros_like(output)
if "SKIP":
def zero_ablate_hook_alternate(module: nn.Module, input: tuple[t.Tensor, ...], output: t.Tensor):
output.fill_(0)
if "SOLUTION":
for name, model in models:
# Annoying that this doesn't typecheck when it's officially supported but I don't care
handle = model.first.register_forward_hook(zero_ablate_hook) # type: ignore
results[("zero", name)] = test(model, test_dataset)
handle.remove()
else:
"""Register zero_ablate_hook, run test, and store the results in `results[("zero", model_name)]`."""
show_results()
# %%
"""
### Results of Zero Ablation
Model A's accuracy dropped to near random chance, while Model B's accuracy only dropped somewhat.
This suggests a hypothesis that Model A doesn't use the skip connection much and prefers to use features calculated by the first layer, while Model B substantially uses the input pixels directly via the skip. You may also have seen supporting evidence for this in the CUI weight visualization.
It might be tempting to treat this large effect size as strong evidence that we're measuring something real here. However, it might be more likely that zero ablation wasn't appropriate here and this experiment doesn't mean much. Again, let's keep moving and practice further with hooks.
## Mean Ablation
Another easy and almost as quick intervention is to replace the activation with the mean activation over some dataset. Here's a sketch of why mean ablation could be better than zero ablation:
Suppose you have a neuron in the first layer whose weights are all zero and whose bias is +2. Since `relu` is the identity for positive inputs, this neuron always outputs +2. Neurons in later layers are able to learn that this neuron is a constant feature and depend on it. For example, a neuron that "needs" a bias of +1 could achieve this by setting a weight of 0.5 on the constant feature and leaving its own bias at 0.
Zero ablating this constant +2 activation will increase our loss, but this doesn't actually correspond to the notion of "disabling information flow" that we hoped for. Mean ablation will set the activation to +2, which is intuitively sensible.
This argument is also hand-waving and again, it's easy to do mean ablation when it doesn't actually produce valid evidence; we'll cover this with more rigor later.
Exercise: suppose we want to perform mean ablation on the first layer's activation (after the ReLU). How easy is this to do?
<details>
<summary>Spoiler</summary>
We were able to just hook `first` for zero ablation and this was the same as hooking after the ReLU since `relu(0) == 0`. For mean ablation this usually doesn't work, except we only have one batch, so you can do something similar...
If we think about the mean of one element in the tensor, it's the scalar `(1/N) * sum_i(relu(x_i))`, where `N` is the number of dataset examples and `i` ranges over `[0, N)`.
We can compute this by hooking before the ReLU to get `x_i`, manually applying the ReLU inside the hook to get `relu(x_i)`, and taking the mean over the batch.
</details>
Exercise: perform mean ablation on the first layer output (after the ReLU). Take the mean over the full test set.
<details>
<summary>I'm confused about how to do this!</summary>
Normally you would need two hooks and a first pass to compute the mean, followed by a second pass to replace the output of the ReLU with the stored mean. But since we only have one batch, we can do everything in the same hook.
Hook `first`, compute the `ReLU` manually, and calculate the mean as described in the Spoiler. The `relu` function has the property that `relu(relu(x)) = relu(x)` so it's okay that the output of the first layer already has the `relu` applied.
Ensure that the mean has the same shape as what it's replacing.
</details>
"""
# %%
if "SOLUTION":
def mean_ablation_hook(model, input, output) -> t.Tensor:
batch, hidden = output.shape
mean = output.relu().mean(0)
assert mean.shape == (hidden,)
return repeat(mean, "hidden -> batch hidden", batch=batch)
for name, model in models:
remix_utils.remove_all_hooks(model.first)
handle = model.first.register_forward_hook(mean_ablation_hook)
results[("mean", name)] = test(model, test_dataset)
handle.remove()
else:
"""Your experiment here. Store the results in `results[("mean", model_name)]`"""
show_results()
# %%
"""
### Results of Mean Ablation
You should observe that the performance of model A drops to around 39%, while that of model B drops to around 87%.
### Mean Ablation - Limitations
The computational cost of mean ablation is higher than zero ablation, but the mean computation only has to be done once per model-dataset pair and then you can save the mean to disk and reuse it easily in different experiments.
A bigger issue is that the mean activation might be very different from the notion of a "typical" activation. Say that a first layer neuron outputs +1 when the feature it detects is not present, and +5 when the feature is present. If the feature is present in 50% of the dataset examples, the mean activation is (1+5)/2 = +3.
If we have a later neuron with a weight of 1 and bias of -1, then when we are running real examples this neuron outputs 0 when the first feature is absent, and +4 when it's present. Running on the mean, it outputs +2 all the time.
Intuitively, the idea of mean ablation is to set activations to a "typical" value, but in this example the mean of a bi-modal distribution can be very atypical and in fact the later neuron would never output +2 on real data.
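To make the numbers concrete, here's a tiny sketch of that scenario (the values match the paragraph above; nothing here is measured from the actual models):

```python
acts = t.tensor([1.0] * 50 + [5.0] * 50)  # feature absent -> +1, present -> +5, 50/50 split

def later_neuron(a: t.Tensor) -> t.Tensor:
    return (1.0 * a - 1.0).relu()  # weight 1, bias -1, then relu

print(later_neuron(acts).unique())  # tensor([0., 4.]) on real activations
print(later_neuron(acts.mean()))    # tensor(2.) under mean ablation - a value never seen on real data
```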
## Resampling Ablation
Hopefully it seems natural at this point to try replacing intermediates not with zero or the mean, but with the intermediates from some other dataset example. One advantage of this is that it's valid anywhere in the network, whereas our hand-waving arguments for the other ablations relied on facts about relu.
The effect of this will be noisy and random, but by doing it repeatedly we can hope to get results that are more representative.
The full setup for resampling ablation is a bit finicky, so let's do a simplified version now.
Exercise: for each model, perform resampling ablation on the output of `first` (before the relu). That is, for each example, the output is replaced by that from a randomly chosen (with replacement) example. Run the ablation 3 times - does the result vary much?
<details>
<summary>I'm confused about how to do this!</summary>
Run the test set with a hook that appends first layer activations to a list. Concatenate the list into one big tensor.
Then, run the test set again with a hook that uses `torch.randint` to generate the indexes to select.
</details>
"""
# %%
if "SOLUTION":
outputs: list[t.Tensor] = []
all_outputs: t.Tensor
def save_hook(model, input, output):
outputs.append(output)
def random_ablation_hook(model, input, output) -> t.Tensor:
batch, _ = output.shape
idx = t.randint(0, len(all_outputs), (batch,))
return all_outputs[idx]
for name, model in models:
outputs = []
remix_utils.remove_all_hooks(model.first)
total = None
count = 0
handle = model.first.register_forward_hook(save_hook)
test(model, test_dataset)
handle.remove()
all_outputs = t.concat(outputs, dim=0)
for i in range(3):
handle = model.first.register_forward_hook(random_ablation_hook)
results[(f"resample_{i}", name)] = test(model, test_dataset)
handle.remove()
else:
"""Your experiment here"""
show_results()
# %%
"""
### Resampling Ablation - results
I found that the results were very consistent between random runs. A's accuracy dropped to around 12%, and B's accuracy dropped to around 70%. If your numbers are more than a few percent off of this, it's probably a bug.
Resampling ablation did support our initial hypothesis that there's a difference between models in how important the first layer features are.
## Path Patching
Next, suppose we want to ablate in some way just the pixel information passing through the skip connection, while leaving the first layer input alone. Our hypothesis predicts that we would observe the reverse effect of the first experiment: model A should preserve its ability, while model B should lose performance. If we observe something else, this is even more valuable as it might suggest that our code has a bug or that the behavior is more complex than it appears.
Conceptually, our computation tree currently has one instance of the pixel data `x`:
```mermaid
graph TD
logits --> last --> add --> relu --> first --> x["x"]
add --> x["x"]
```
(Computation trees are often drawn with the arrows the other way round. Here, we're using the convention that arrows go from nodes to their inputs.)
(If you can't see this diagram in VS Code, either look at it on GitHub, or install the "Markdown Preview Mermaid Support" extension and reload your window.)
But we want to counterfactually think about what would happen if we only changed one of the "paths" by which `x` affects the logits. Conceptually, this looks like a tree where the `x` are different:
```mermaid
graph TD
logits --> last --> add --> relu --> first --> x1["x_to_first"]
add --> x2["x_to_skip"]
```
### Limitations of Hooks
In our current setup, running this experiment is troublesome. We can only replace the output of `nn.Module` instances, but `skip` is only named during the forward method, and the `+=` operator isn't a module. So in order to use hooks here, I would go back and rewrite the model to use `nn.Module` in every place that I might want to hook.
This is always possible and can be done without thinking too hard. One option is to write new `nn.Module` subclasses that apply the operator in their forward method. In some cases these already exist, such as `nn.ReLU`. Another option is to insert do-nothing modules like `nn.Identity` just so you can have additional hook points.
Taking the latter approach leads to a library like Neel Nanda's EasyTransformer or David Bau's baukit (likely others as well), but in Circuits we've chosen to take a third approach which we claim is better.
It's now time to put interpretability on hold and take a long detour into building out your own version of Circuits!
## Introducing the Circuits library
Our model only has two layers and very simple operators, but the code is already starting to get messy. On future days, you'll be doing much more complicated experiments on transformers which have a large number of moving parts. In particular, the number of possible paths grows rapidly with the number of layers and it gets really gnarly to do path patching with hooks. We'll also find that causal scrubbing will be much easier with Circuits than with hooks.
The fundamental difficulty so far is that we have many different types of "thing" in our network:
- `Tensor` instances
- `Module` instances
- Functions like `torch.nn.functional.relu`
- Operators like `+=` (these translate to methods on Tensor)
- Temporaries (spooky, non-referenceable ghosts)
There are also different relationships between "things". How many ways are there to add two things?
- We can have a `Module` that owns a `Parameter` and takes a `Tensor` as a forward argument
- We can have a `Module` that owns two `Module`s and runs them both and then adds the result together
- We can have a function that takes two `Tensors` and returns a new `Tensor`
- We can use an in-place operator to modify one of the operands
To make things clean and general, we're going to reduce the number of kinds of thing to two:
- `Tensor` instances
- `Circuit` instances
Within `Circuit`, we'll have subclasses:
- An `Array` subclass owns a `Tensor`, which could represent either a learned parameter or an input to the network. These are the same "thing" for our purposes in that they don't depend on any other input to compute their value.
- A subclass for each other operation. For example, an `Add` subclass will own a `Circuit` instance for each operand.
Every `Circuit` instance can be evaluated by first recursively evaluating its inputs, then running its operation. For this to always terminate, the graph of `Circuit` instances can't have any reference loops (formally, it's a directed acyclic graph).
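As a toy illustration of the recursive-evaluation idea (this is not the `Circuit` class you'll build below, just the concept in miniature with made-up `MiniLeaf`/`MiniAdd` classes):

```python
class MiniLeaf:
    def __init__(self, value: t.Tensor):
        self.value = value

    def evaluate(self) -> t.Tensor:
        return self.value

class MiniAdd:
    def __init__(self, *inputs):
        self.inputs = inputs

    def evaluate(self) -> t.Tensor:
        # evaluate every input first, then apply this node's own operation
        total = self.inputs[0].evaluate()
        for inp in self.inputs[1:]:
            total = total + inp.evaluate()
        return total

root = MiniAdd(MiniLeaf(t.tensor([1.0, 2.0])), MiniLeaf(t.tensor([3.0, 4.0])))
print(root.evaluate())  # tensor([4., 6.])
```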
I'll use "node" as a shorthand for "`Circuit` instance". Note that a node holds references to its input nodes, but doesn't know which other nodes (if any) consume its output.
This means if we hold a reference to the final output (in our case, the instance computing the logits), we can traverse the tree until we reach the inputs. Usually we'll draw this with the root (logits) at the top and the leaves (inputs) at the bottom, but this is arbitrary and I'll refer to the direction of the inputs as "input-ward".
It turns out that many confusing problems are avoided by having `Circuit` instances be `frozen` once created. For example, this immediately guarantees that we don't have any reference cycles.
Exercise: why?
<details>
<summary>Solution - Frozen instances and Cycles</summary>
The basic idea is that if we construct circuit C before circuit D, then it's impossible for C to depend on D as an input (because C's inputs were set in stone before D was created). Therefore, one valid way to evaluate any tree of `Circuit` is to evaluate them in the order they were constructed.
</details>
### Desiderata
Buck and/or Ryan will likely have presented on the design goals behind Circuits, but in brief some of the intents are:
- All `Tensors` computed are observable and modifiable in one and only one way, independent of whether they're an input, an activation, a parameter, or a temporary.
- All operations have a single representation: a node that computes a single output from zero or more input nodes.
- Algebra on trees will produce new trees which are equivalent by default. Common operations are provided by the library and extensively tested for correctness.
- The user doesn't have to think about execution speed. Evaluating a tree efficiently is the library's problem, not the user's problem.
- Trees of nodes are serializable and computations on them can be efficiently memoized (cached). This is the biggest aspect that we won't tackle today, but in general if you see something that seems overly complicated, it's probably done that way to better support this feature.
## Build Your Own Circuits
We're going to take a long detour next and re-implement key parts of Circuits. The real version is implemented in Rust as the `rust_circuit` library (`rc` for short).
Your version will be largely API compatible, but we'll omit edge cases, tackle some things in less generality, and ignore some flags that are unimportant for conceptual understanding.
The goal of this is to introduce the API step by step and build a mental model of what each class in the library is doing so that when you get error messages they will make sense to you.
Just copy over the base class in the block below - there's nothing to implement here.
### Circuit Names
The name field of the `Circuit` doesn't do anything during evaluation, and is intended for the experimenter's convenience.
You can use whatever conventions you like for names. Names aren't required to be unique, but I recommend giving nodes unique names wherever possible to aid debugging.
### Circuit Shape
In the constructor, Circuit instances should compute and store the shape of their output. Note that this is always computable given the shape of the inputs and the standard PyTorch broadcasting rules.
There is some weirdness in PyTorch where shapes can be a `tuple` of `int`, or `torch.Size` which is a subclass of `tuple`. To keep things simple, we'll only work with `tuple`.
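For example, one way to infer the output shape of an elementwise operation from its inputs' shapes is PyTorch's broadcasting helper, converting back to a plain `tuple` (a sketch; any correct shape computation is fine):

```python
input_shapes = [(1, 28 * 28), (28 * 28,)]
out_shape = tuple(torch.broadcast_shapes(*input_shapes))  # torch.Size -> plain tuple
print(out_shape)  # (1, 784)
```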
The last couple of functions, `update` and `get`, won't work yet - ignore them for now.
"""
# %%
if "SKIP":
# Only need this for testing in dev
import rust_circuit as rc
class Circuit:
"""A node in a computational graph.
For technical reasons, this isn't a frozen class, but treat it as if it was.
You won't need to create Circuit instances; this is just a base class to be subclassed.
Can evaluate itself by recursively evaluating its inputs.
Like rc.Circuit.
"""
name: str
"""A non-unique identifier that can follow any convention you like."""
inputs: list["Circuit"]
"""Other circuits that this circuit depends on in order to evaluate itself. Not intended to be mutated."""
shape: tuple[int, ...]
"""The shape of the tensor returned by self.evaluate(). Must be specified for Scalar, Symbol and is inferred for others."""
def __init__(self, *inputs: "Circuit", name: str = ""):
self.name = name
self.inputs = list(inputs)
def evaluate(self) -> t.Tensor:
"""Return the value of this Circuit as a Tensor."""
raise NotImplementedError
def __repr__(self):
return f"{self.name}({', '.join(inp.name for inp in self.inputs)})"
def children(self) -> list["Circuit"]:
return self.inputs
def visit(self, visitor: Callable[["Circuit"], bool]) -> None:
"""Call visitor(self), and recurse into our inputs if the visitor returns True."""
# Note: cache would be instantiated here, see visit_circuit_non_free
def recurse(c: Circuit):
should_recurse = visitor(c)
if should_recurse:
for inp in c.inputs:
recurse(inp)
recurse(self)
def cast_einsum(self) -> Einsum:
"""Downcast to Einsum, raising an error if we aren't actually an Einsum."""
assert isinstance(self, Einsum)
return cast(Einsum, self)
def cast_add(self) -> Add:
"""Downcast to Add, raising an error if we aren't actually an Add."""
assert isinstance(self, Add)
return cast(Add, self)
def cast_array(self) -> Array:
"""Downcast to Array, raising an error if we aren't actually an Array."""
assert isinstance(self, Array)
return cast(Array, self)
def update(self, matcher, transform: "TransformIn") -> "Circuit":
"""Return a new circuit with the transform applied to matching nodes."""
u: Union[Updater, BasicUpdater]
try:
u = Updater(transform)
except NameError:
u = BasicUpdater(transform)
return u.update(self, matcher)
def get(self, matcher) -> set["Circuit"]:
"""Return all nodes that match the criteria."""
m: Union[IterativeMatcher, Matcher]
try:
m = IterativeMatcher(matcher)
except NameError:
m = Matcher(matcher)
return m.get(self)
def get_unique(self, matcher: "IterativeMatcherIn") -> Circuit:
"""Convenience function."""
return IterativeMatcher(matcher).get_unique(self)
def print(self) -> None:
"""Simplified version of Circuit.print that doesn't take a PrintOptions."""
print(repr_tree(self))
# %%
"""
### Circuit Traversal
The `visit` method allows you to call a `Callable` on the node and recursively on each subtree. (The standard library type `Callable` includes regular functions, lambda functions, or any object that defines a `__call__` method).
Make sure you understand `visit`, since you'll be writing variants of it later. It's like the classic "visitor pattern" except we have the ability to skip parts of the tree.
Exercise: when you call `visit`, in what order are the nodes visited? Refer to [Wikipedia](https://en.wikipedia.org/wiki/Tree_traversal) if you're not sure what the names are.
<details>
<summary>Traversal Order</summary>
This is a pre-order traversal - first we call the visitor on the current node, then recurse from "left" (inputs[0]) to "right" (inputs[-1]).
</details>
Exercise: if we visit `e` in the below tree and `visitor(c)` returns False, which nodes are visited?
<details>
<summary>Solution</summary>
All descendants of `c` are skipped, meaning `a` and `b` are skipped.
It's important to understand that the traversal continues in other siblings of `c` though, so `d` is visited.
</details>
"""
# %%
a = Circuit(name="a")
b = Circuit(name="b")
c = Circuit(a, b, name="c")
d = Circuit(name="d")
e = Circuit(c, d, name="e")
def example_visitor(current: Circuit) -> bool:
print("Visiting: ", current.name)
return True
e.visit(example_visitor)
# %%
"""
### Array
Exercise: implement the methods so that the tests pass. In this case, `Array` doesn't depend on any other nodes to evaluate itself - it can just return its value field.
"""
# %%
class Array(Circuit):
"""Represents a learned parameter or a specific input to the network.