Today, you'll learn more features of the Circuits library and use them to write a GPT-style transformer from scratch. By the end, you'll be able to load pretrained weights and do inference on GPT-2.
- Learning Objectives
- Readings
- The Symbol class
- Useful Helper Functions
- LayerNorm in Circuits
- Substitutions and rc.Expander
- Fancy Validation
- Circuit Diffing
- Printing and Serialization
- Tensor Storage
- Redwood Research File System (RRFS)
- Circuit Simplification
- Automatic Simplification
- "Batchable" Symbols
- Loading A Reference Model
- Intro to Modules
- MLP Module
- Token Embedding
- Positional Embedding
- Computing the Attention Mask
- Self-Attention
- GPT2 Block
- Unembedding
- Full GPT-2
- Moment of Truth
- Bonus
After today's material, you should be able to:
- Explain all the individual operations in a transformer and how they combine
- Explain how batch dimensions are handled in Circuits
- Write your own `Module`s as necessary
- Language Modelling with Transformers - note that we're going to completely ignore dropout, since it isn't used at inference time.
from __future__ import annotations
import os
import sys
import torch
import torch as t
import torch.nn as nn
from tqdm.notebook import tqdm
import torch.nn.functional as F
from typing import Union, Any
import pandas as pd
from einops import rearrange, repeat
import rust_circuit as rc
from rust_circuit import Circuit, Add, Scalar, Einsum, Symbol, Array, Rearrange, Module, ModuleSpec, ModuleArgSpec
from dataclasses import dataclass
import remix_d2_utils
from remix_d2_utils import LayerWeights, GPT2Weights, get_weights
import remix_utils
pd.set_option("display.precision", 3)
MAIN = __name__ == "__main__"
@dataclass(frozen=True)
class GPTConfig:
"""Constants used throughout the GPT2 model."""
activation_function: str = "gelu"
num_layers: int = 12
num_heads: int = 12
vocab_size: int = 50257
hidden_size: int = 768
max_position_embeddings: int = 1024
dropout: float = 0.1
layer_norm_epsilon: float = 1e-05
config = GPTConfig()
In Day 1, our `Circuit` had specific sizes like 28*28 "baked in", and it would've been a pain to reuse the same `Circuit` for a situation where the structure of the network was the same but the sizes of some things were different.
It would be nice if we had a `Circuit` that represents the network architecture in the abstract. This `Circuit` wouldn't reference any weight `Array`s, and wouldn't reference specific sizes like the hidden size that can vary between models using this architecture. We could write this abstract thing once and then "substitute in" weights to obtain a `Circuit` representing a specific model we can evaluate.
While we're at it, it would be nice to solve the problem of batching. In Day 1, we manually modified our `Circuit` to have batching support. While most things "just work" with a batch dimension prepended to the front (thanks to broadcasting rules), some things didn't work without modification, like our `Einsum` string.
Both difficulties are addressed with a new subclass of `Circuit` called `Symbol`. It has a name and a shape, but no data. Trying to evaluate a `Symbol` will throw an exception. A `Symbol` is just a placeholder that needs to be replaced with something else before evaluation happens. It's safer than an `Array` as a placeholder - if we forgot to replace an `Array`, we would still be able to call `evaluate` using the wrong data.
A `Symbol` also has a UUID. If you've never heard of a UUID, all you need to know is that they are random numbers big enough that when you generate one, it's almost certainly true that no other Circuits user has ever generated the same one. Refer to the Python standard library if you want to know more.
The point of the UUID is to distinguish between two `Symbol`s with identical names and shapes. Equality on `Symbol` requires that the UUIDs are equal.
Can't we just use the memory address of the Symbol object for uniqueness?
We're going to be saving Circuits to disk, and they can't be guaranteed to be in the same memory location after reloading the Circuit. The UUID can be saved with the Symbol and thus stay the same between sessions.
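Here's a quick sketch of that behavior (using `Symbol.new_with_random_uuid`, which the `sym` helper below wraps): two `Symbol`s with the same name and shape still compare unequal because their UUIDs differ, and neither can be evaluated.
a = Symbol.new_with_random_uuid((3,), "placeholder")
b = Symbol.new_with_random_uuid((3,), "placeholder")
assert a != b  # same name and shape, but different UUIDs
try:
    a.evaluate()  # Symbols hold no data, so this raises
except Exception as e:
    print("Can't evaluate a Symbol:", e)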
To solve the issue with specific sizes being hardcoded, one option is to just use a special value that a size can't have, like -1, so a single input to a transformer would have size (-1, -1) for the sequence and hidden dimensions, for example.
One issue is we lose out on some consistency checking in this way - it's no longer possible to check that the sequence dimension is consistent everywhere in the network, for example.
The galaxy brain solution is to have a bunch of pre-determined special values, where one represents "hidden dimension", one represents "sequence dimension" and so on. If we use them consistently throughout the architecture, we'll be able to enforce consistency. In the current implementation, these are just some primes that are larger than all integers that are used for actual sizes.
This particular implementation is pretty unsafe and subject to change. Its main benefit is that all the regular shape inference code you wrote in Day 1 supports this feature without any modifications (it doesn't care whether the integer represents a real size or a symbolic size).
HIDDEN, SEQ, HEADS, MLP_PROJ, HEAD_SIZE, LOG_LIKELIHOOD_CLASSES, *__unused = rc.symbolic_sizes()
print("The hidden dimension (dimension of the residual stream", HIDDEN)
print("The number of classes (for our language model, the vocabulary size)", LOG_LIKELIHOOD_CLASSES)
print("The sequence dimension (equal to the length of the context window)", SEQ)
def sym(shape: rc.Shape, name: str):
"""Create a new symbol."""
return Symbol.new_with_random_uuid(shape, name)
ln_input = sym((HIDDEN,), "ln.input")
ln_weight = sym((HIDDEN,), "ln.w.scale")
ln_bias = sym((HIDDEN,), "ln.w.bias")
It would be really confusing to show the primes when printing. Instead, the printing and serialization code knows about the primes and renders them as "0s", "1s", "2s", etc.
print("This is a placeholder of symbolic shape (HIDDEN,).", ln_input)
print("When serialized or printed, these are rendered specially:")
ln_input.print()
print("\nSymbolic sizes are propagated as normal: ")
x = Einsum.from_einsum_string("a,b->ab", ln_weight, ln_bias)
x.print(rc.PrintOptions(shape_only_when_necessary=False))
print(
"\nSymbolic sizes can also be products of symbolic sizes (this works because they are primes and have a unique factorization)"
)
x.rearrange_str("a b -> (a b)").print(rc.PrintOptions(shape_only_when_necessary=False))
print("\nSymbolic sizes can be products of regular and symbolic sizes")
cat = rc.Concat(ln_input, ln_input, ln_input, axis=0)
cat.print(rc.PrintOptions(shape_only_when_necessary=False))
We'll need to compute `1/x` for a `Circuit` `x` today. We could use operator overloading to make this syntax work, but then we'd have no way of giving the result a `name` field.
Instead, we have functions like `rc.reciprocal`, which evaluates to the same thing as `torch.reciprocal` but allows specifying the name. `rc.reciprocal` also has metadata describing algebraic properties of this function. For example, `(1/x)[i] == 1/(x[i])`.
x_data = t.tensor([0.1, 0.2, 0.3])
x_arr = Array(x_data, name="x")
out = rc.reciprocal(x_arr, name="x_inv")
t.testing.assert_close(out.evaluate(), t.reciprocal(x_data))
t.testing.assert_close(out.evaluate(), 1 / x_data)
out.print()
Another one you'll need is `rc.rsqrt`, which stands for reciprocal square root. Compared to computing `sqrt(x)` and then doing a reciprocal, `rsqrt` is one operation, so it could be faster and/or more accurate depending on the implementation.
out = rc.rsqrt(Array(x_data), name="x_rsqrt")
print(out.evaluate())
diff = out.evaluate() - (1 / x_data) ** 0.5
print(diff.max().item())
Finally, `rc.last_dim_size(x)` evaluates to `torch.full(x.shape[:-1], x.shape[-1])`.
This is kinda weird and subject to change, but the point of using this is that symbolic sizes work properly with it.
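To make that concrete, here's a minimal sketch based on the description above:
x = Array(t.arange(6.0).reshape(2, 3), name="x")
# For a (2, 3) input, this evaluates to torch.full((2,), 3): the size of the last dimension,
# repeated once per leading index.
print(rc.last_dim_size(x).evaluate())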
Exercise: discuss with your partner why `mean_broken` below computes the wrong output. This is an important point, so call a TA if you don't feel confident. Then write a new `mean` that works correctly.
def mean_broken(x: Circuit) -> Circuit:
"""Compute the mean of x. This one doesn't work properly.
x: shape (n,)
Output: shape ()
"""
denominator = rc.reciprocal(Scalar(x.shape[-1], name="last_dim_size"))
return Einsum.from_einsum_string("a,->", x, denominator, name=f"mean({x.name})")
def mean(x: Circuit) -> Circuit:
"""Compute the mean of x.
x: shape (n,)
Output: shape ()
"""
"TODO: YOUR CODE HERE"
pass
mean_broken_circuit = mean_broken(ln_input)
replaced = mean_broken_circuit.update("ln.input", lambda _: Array(x_data))
print("Correct mean is 0.2, but broken version returned: ", replaced.evaluate().item())
mean_circuit = mean(ln_input)
replaced = mean_circuit.update("ln.input", lambda _: Array(x_data))
actual = replaced.evaluate().item()
print("Expected 0.2, got: ", actual)
Exercise: explain to your partner why `sum_last_broken` doesn't work properly. Fix the issue.
def sum_last_broken(x: Circuit) -> Circuit:
"""Compute the sum of x along the last dimension. This one doesn't work properly.
x: shape (n,)
Output: shape ()
"""
return Einsum.from_einsum_string("a->", x)
sum_broken_circuit = sum_last_broken(ln_input)
replaced = sum_broken_circuit.update("ln.input", lambda _: Array(x_data))
print("Expected 0.6, got: ", replaced.evaluate())
A PyTorch implementation of LayerNorm is provided for reference. Additionally, in the `remix_images` folder there are some PDF schematics you can refer to. To view these, either copy them to your local machine or install the vscode-pdf extension.
Your implementation can be less general than the PyTorch version. In particular, you should assume:
- The parameter `elementwise_affine` is always `True`;
- The shapes of `input`, `weight`, and `bias` are all `(HIDDEN,)` (we'll get to batching later);
- Normalization is computed over the last (and only) dimension; and
- You don't need to worry about resetting the `weight` and `bias`, as we'll only be using the Circuit for inference today.
This is sufficient for doing inference in transformers, and specializing in this way makes our job easier.
It's not important today to have an intuition about what LayerNorm is doing other than "it's surprisingly complicated". The article Re-Examining LayerNorm is worth a read at some future time.
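For reference while implementing, the computation you're reproducing (matching the `TorchLayerNorm.forward` provided below, which uses the biased variance) is:

$$\mathrm{LN}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot w + b, \qquad \mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2$$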
Exercise: Implement LayerNorm as a `Circuit`. The names of the internal nodes are up to you; it can be helpful to give unique names to things for debugging purposes, but it's also fine to ignore the names.
In addition to `Einsum`, `Add`, and `Scalar` from before, use the helper functions introduced above (`rc.reciprocal`, `rc.rsqrt`, `rc.last_dim_size`). You could write your own lambdas for these, but using the ones from the API ensures that your Circuit will be serializable (more on this later).
Tip: `Circuit` provides some helpful methods like `Circuit.add(other)` as a shorthand for `Add(self, other)`. The method `Circuit.mul` is also useful shorthand for elementwise multiplication.
Hint: help me get started!
Let's start by computing the mean with an `Einsum`. The syntax is the same as yesterday: `mean = rc.Einsum.from_einsum_string("h->", input)` or `rc.Einsum.from_fancy_string("hidden -> ", input)`.
I have a weird Scalar 0.000099930048965724 in my Circuit and I can't figure out why!
This value is equal to `1/HIDDEN`, and it probably means you accessed the input shape directly when building the Circuit. You have to use `rc.last_dim_size` to delay evaluation until after real shapes have been substituted for the symbolic shapes. This is a subtle point - ask a TA if this isn't clear!
class TorchLayerNorm(nn.Module):
"""A PyTorch implementation of nn.LayerNorm, provided for reference."""
weight: nn.Parameter
bias: nn.Parameter
def __init__(
self, normalized_shape: Union[int, tuple, t.Size], eps=1e-05, elementwise_affine=True, device=None, dtype=None
):
super().__init__()
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)
self.normalized_shape = tuple(normalized_shape)
self.normalize_dims = tuple(range(-1, -1 - len(self.normalized_shape), -1))
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(t.empty(self.normalized_shape, device=device, dtype=dtype))
self.bias = nn.Parameter(t.empty(self.normalized_shape, device=device, dtype=dtype))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
"""Initialize the weight and bias, if applicable."""
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x: t.Tensor) -> t.Tensor:
"""x and the output should both have shape (batch, *)."""
mean = x.mean(dim=self.normalize_dims, keepdim=True)
var = x.var(dim=self.normalize_dims, keepdim=True, unbiased=False)
x = x - mean
x = x / (var + self.eps) ** 0.5
if self.elementwise_affine:
x = x * self.weight
x = x + self.bias
return x
def layernorm(input: Circuit, weight: Circuit, bias: Circuit, eps=1e-05) -> Circuit:
"""Circuit computing the same thing as TorchLayerNorm, subject to the simplifications detailed in the instructions."""
"TODO: YOUR CODE HERE"
pass
ln = layernorm(ln_input, ln_weight, ln_bias)
ln.print()
Above, we didn't have to specify the shape of `ln` - it was automatically inferred to be `(HIDDEN,)` via the normal rules you implemented yesterday.
If we try to call `ln.evaluate()`, we'll see an error about it not being explicitly computable. This is telling us that we have `Symbol` instances remaining somewhere in the tree.
We could use an `Updater` like in Day 1 to replace the `Symbol`s with `Array`s, but then we'd have to update our `Einsum` again, and we'd need a separate `Updater` for each thing we want to replace.
`rc.Expander` is a class that does all the updates in one shot and also takes care of the batching. Here's how to use it. (If your output doesn't match the reference output, continue reading for troubleshooting tips.)
Help - It's not working!
Some steps to start with:
- Check the shapes in the printout and see if they all make sense. Unintended broadcasting is one bad thing that can happen.
- Give your nodes unique names - this helps with reading the printout, and also prevents bugs related to multiple nodes having the same name.
If it's still not working, don't spend much time stuck here - call a TA!
t.manual_seed(0)
actual_inputs = [("single dim", t.randn((8,))), ("batch and seq dims prepended", t.randn((5, 7, 8)))]
actual_weight = t.randn((8,))
actual_bias = t.randn((8,))
torch_ln = TorchLayerNorm((8,))
torch_ln.weight.data = actual_weight
torch_ln.bias.data = actual_bias
for name, actual_input in actual_inputs:
print(f"\nTesting LayerNorm with {name}")
expected = torch_ln(actual_input)
expander = rc.Expander(
("ln.input", lambda _: Array(actual_input)),
("ln.w.scale", lambda _: Array(actual_weight)),
("ln.w.bias", lambda _: Array(actual_bias)),
)
computable_ln = expander(ln, fancy_validate=True)
computable_ln.print(rc.PrintOptions(shape_only_when_necessary=False))
assert computable_ln.is_explicitly_computable
t.testing.assert_close(computable_ln.evaluate(), expected)
You can pass `fancy_validate=True` to have the `Expander` spend time doing additional checks - I recommend always doing this unless it's too slow.
Exercise: With `fancy_validate=True`, mess around with different inputs and see what happens. For example:
- Pass a `Matcher` that doesn't match anything
- Pass an `Array` which has an incompatible shape
- Pass two `Matcher`s that match the same node
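For example, here's a sketch of the incompatible-shape case (the exact exception type isn't specified here, so we just catch broadly and print whatever comes back):
bad_expander = rc.Expander(("ln.w.scale", lambda _: Array(t.randn((3, 3)), name="bad_scale")))
try:
    bad_expander(ln, fancy_validate=True)
    print("No error at expansion time")
except Exception as e:
    print("Expansion failed:", e)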
`print(rc.diff_circuits(circuit1, circuit2))` can take two `Circuit`s and render a colored diff between them. The problem of generating a minimal diff between two arbitrary trees is actually quite hard, and at the time of this writing this function isn't particularly smart.
If your `layernorm` didn't produce the right output, diff it now against the version produced by `remix_d2_utils.working_layernorm` and see if this helps. Play with some of the keyword arguments to `diff_circuits`. This may or may not be actually helpful, and is mainly to emphasize that doing diffs is a thing you should consider when debugging. (Also, note the argument order: the first argument is `new` and the second is `old`, so added (green) nodes will be from the first argument.) If you feel stuck, call a TA.
Exercise: Once your `layernorm` is correct, diff `working_layernorm` against `mystery_layernorm_a` and `mystery_layernorm_b`. Describe the differences in your own words.
Solution - Issues with other layernorm implementations
A forgot to use the value for epsilon passed in, which makes the results very slightly off.
B hardcodes in the scaling factor for the mean (1/hidden_size) instead of computing it on each forward pass. This is a problem because if we start with abstract circuits, this will be (1/large prime) rather than (1/actual hidden size).
"TODO: YOUR CODE HERE"
If you just do `print(circuit)` or `repr(circuit)`, this renders a small representation which is limited in depth.
`circuit.print()`, however, is a completely different beast. This takes (or constructs) an `rc.PrintOptions` instance, which has a ton of arguments you can fiddle with to customize your printing. Generally, you'll want to create your own `PrintOptions` instance which is set up the way you like, and call `printer.print(circuit)` or `printer.repr(circuit)` on that.
The most important argument to `PrintOptions` is `bijection`, which causes your printout to be a complete representation of the circuit (with some caveats like custom GeneralFunctions). When `bijection=True`, you can convert the string representation back into an equivalent tree with `rc.Parser`.
parser = rc.Parser()
printer = rc.PrintOptions(bijection=True)
text = printer.repr(ln)
print(text)
deserialized_ln = parser(text)
assert ln == deserialized_ln
Serialization requires a way to serialize any tensors inside the circuit, such as model weights. Storing the tensors directly in the text representation would be very inefficient, so instead we hash the contents of each tensor and store the tensor on disk in binary format at a path derived from its hash.
Importantly, when `bijection=True` this happens automatically, so unlike `print(circuit)`, which is cheap and has no side effects, `circuit.print()` might write a lot of data to disk if you have large tensors.
By default, a series of nested folders on your local disk is used for storage. For example, if a tensor's hash is `4f2fb80a7d72349dd5353e32c67a921368d6ab8eec58590a0c09b62492000be9`, then by default a file is created in `~/tensors_by_hash_cache/4f2f/b80a/` with the remainder of the hash as the filename.
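As an illustration of that layout, here's a small sketch (the `hash_cache_path` helper is just for illustration; it isn't part of the library):
def hash_cache_path(hash_hex: str) -> str:
    # ~/tensors_by_hash_cache/<first 4 hex chars>/<next 4 hex chars>/<remainder>
    return os.path.join(os.path.expanduser("~/tensors_by_hash_cache"), hash_hex[:4], hash_hex[4:8], hash_hex[8:])

print(hash_cache_path("4f2fb80a7d72349dd5353e32c67a921368d6ab8eec58590a0c09b62492000be9"))
# e.g. /home/<you>/tensors_by_hash_cache/4f2f/b80a/7d72349dd5353e32c67a921368d6ab8eec58590a0c09b62492000be9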
Exercise: Check the hash of `rand1` below using `Array.tensor_hash_base16()`. Print out the circuit `total`, and then check your local filesystem to find the file that was created during printing. Verify that the file is roughly the size you expect, and that the tensor is only stored once (even though there are two references to it).
t.manual_seed(0)
rand_tensor = t.randn((256,))
rand1 = Array(rand_tensor, name="rand1")
rand2 = Array(rand_tensor, name="rand2")
total = rand1.add(rand2)
"TODO: YOUR CODE HERE"
Redwood Research File System (RRFS) is just a big network filesystem managed by AWS where we store things that are too big for version control, such as model weights or datasets for experiments. If you type `ls ~/rrfs` in a terminal, you can check whether you have this filesystem mounted on your machine.
If you upload your tensors to RRFS using `sync_tensors=True`, other users can automatically download them as needed over the network. Code for this is included in `rc.Parser`, so the whole process is transparent - if you parse a `Circuit` and don't have the referenced `Tensor`s on your local machine, the library will automatically copy them over the network.
Use caution with `sync_tensors=True` to avoid using up excessive network bandwidth and disk space on tensors you don't actually care about (like the random tensor above).
If you find yourself creating lots of large temporary tensors, you may want to use `bijection=False` to avoid filling up your local disk.
For today, you won't have to upload any tensors and we'll only be downloading the pretrained weights via this mechanism.
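For example, to get a printout without writing any tensors to disk (a sketch using the `PrintOptions` API shown earlier):
# bijection=False gives a human-readable printout but skips hashing/saving the underlying tensors,
# so nothing new lands in ~/tensors_by_hash_cache (at the cost of not being parseable back).
rc.PrintOptions(bijection=False).print(total)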
In the PyTorch version, we needed to specify a flag for `elementwise_affine` and then check it in the forward pass to avoid unnecessary computation. In Circuits, if we don't want a bias in our layer norm, we can just substitute a `Scalar` with the appropriate shape and the value zero. While this won't use much memory (remember, a `Scalar` is a view into a 1-element `Tensor`), it will still cause PyTorch to perform an unnecessary addition operation. We can detect and optimize this out with a rewrite.
Exercise: implement `add_elim_zeros`, then replace the bias node in `ln` with a zero `Scalar` and call `add_elim_zeros`. Check the diff and verify that the bias node was completely removed.
Help, I get an `ExpandBatchingRankTooLowError` error
Using the `expander` construction from above to replace the bias node won't work because of the dimension mismatch, if that's what you tried. Instead use:
zero_bias = ln.update("ln.w.bias", lambda _: Scalar(0.0, ()))
def add_elim_zeros(add: Add) -> Add:
"""Rewrite "add" to eliminate any operands that are Scalar(0.0).
Return "add" unmodified if nothing was done, and a new Add instance otherwise.
"""
"TODO: YOUR CODE HERE"
pass
"TODO: YOUR CODE HERE"
The function `rc.simp` automatically performs the above rewrite, as well as around 20 more common ones, for you.
Exercise: Run `rc.simp` and diff the result against the output of your simplification. Are they equal? If not, what else did `rc.simp` do?
Solution - rc.simp
After removing the bias node, it noticed that the `Add` had only one remaining argument, meaning the `Add` can also be replaced with its argument.
"TODO: YOUR CODE HERE"
In GPT, our LayerNorm would run independently at each position of the sequence and independently on each batch element. In other words, the input would actually have shape `(BATCH, SEQ, HIDDEN)`, not `(HIDDEN,)` as in our input `Symbol`.
In fact, we could keep adding more leading dimensions to the `(HIDDEN,)` shape, and LayerNorm would work properly until we hit some internal library limitation.
When you can safely replace a `Symbol` with an object that has more leading dimensions but is otherwise compatible, we say that symbol is "batchable".
Exercise: Make a new random `Tensor` that has batch and sequence dimensions and pass it into `Expander`. Print the result and verify that the new `Einsum`s look like what you'd expect. How many leading dimensions can you add? What breaks first once you add too many?
"TODO: YOUR CODE HERE"
When writing your own implementation, it's generally easiest to start with a known-good implementation and check that your version has matching behavior along the way.
We'll load a reference copy of the pretrained GPT2-small now. When people say GPT-2, they are usually referring to the full 1.5B parameter version; the small version has only 117M parameters. You don't need to understand the loading code for today. The first time you run this, it will take some time to download over the network (on later runs, it will use the local cache).
By the end of the day, your model will produce the same output as `ref_circuit`!
def tokenize(prompts: list[str]) -> tuple[t.Tensor, t.Tensor]:
out = tokenizer(prompts, padding=True, return_tensors="pt")
input_ids = out["input_ids"]
input_ids[input_ids == tokenizer.pad_token_id] = 0
return (input_ids, out["attention_mask"])
prompts = ["Former President of the United States of America, George", "Paris is the capital city of"]
circ_dict, tokenizer, model_info = remix_utils.load_gpt2_small_circuit()
ref_input_ids_tensor, attention_mask_tensor = tokenize(prompts)
ref_input_ids = Array(ref_input_ids_tensor)
ref_attention_mask = Array(attention_mask_tensor.float())
bind_module = circ_dict["t.bind_w"]
tok_embeds = rc.GeneralFunction.gen_index(circ_dict["t.w.tok_embeds"], ref_input_ids, index_dim=0)
assert tok_embeds.shape == (2, 10, 768)
expected = circ_dict["t.w.tok_embeds"].cast_array().value[ref_input_ids_tensor]
t.testing.assert_close(tok_embeds.evaluate(), expected)
ref_circuit = model_info.bind_to_input(bind_module, tok_embeds, circ_dict["t.w.pos_embeds"], ref_attention_mask)
assert ref_circuit.shape == (2, 10, 50257)
def eval_circuit(circuit: Circuit) -> None:
logits = circuit.evaluate()
for i in range(2):
end_seq = ref_input_ids_tensor[i].nonzero().max().item()
topk = t.topk(logits[i, end_seq], k=10).indices
next_tokens = tokenizer.batch_decode(topk.reshape(-1, 1))
print("Prompt: ", prompts[i])
print("Top 10 predictions: ", next_tokens)
if i == 0:
assert " Washington" in next_tokens
assert " Bush" in next_tokens
elif i == 1:
assert " France" in next_tokens
eval_circuit(ref_circuit)
Our LayerNorm has quite a few nodes representing low-level operations, and often we just want to view and work with it as "yeah, that's a box containing a standard LayerNorm", without caring about the specifics unless we actually need to "open the box" and do rewrites that affect the internals.
In Circuits, we'll use the class `rc.ModuleSpec` to represent the concept of a "standard LayerNorm". All the LayerNorms in our transformer will share one instance of this. Inside the `ModuleSpec`, we have a Circuit with `Symbol` instances as placeholders - this is called the "spec circuit" or just "spec" for short. The output of our function `layernorm` above is already the appropriate thing.
The `ModuleSpec` also contains an `rc.ModuleArgSpec` instance, one for each unique `Symbol` in the spec circuit. This `ModuleArgSpec` stores information needed for batching.
Below, we've written `LN_SPEC` in caps to emphasize that this is an immutable constant throughout our transformer. Note that a `ModuleSpec` is not itself a Circuit and thus can't directly appear in a tree.
I feel confused about which `Symbol`s should be batchable.
The only reason to make something not batchable is to get a helpful error if we mess up and try to pass in an input of the wrong shape. In our case, we already specialized our Circuit to only work when normalizing the last dimension, so it wouldn't make sense to have a weight or bias of more than 1D. On the other hand, in the real transformer we will have an input shape of `(batch, seq, hidden)`, where the first two dimensions are treated as batch dimensions.
def layernorm_spec() -> ModuleSpec:
"""Return a ModuleSpec representing GPT's standard LayerNorm.
Note that you only ever need to call this once - all Modules can share it.
"""
sym_input = sym((HIDDEN,), "ln.input")
sym_weight = sym((HIDDEN,), "ln.w.scale")
sym_bias = sym((HIDDEN,), "ln.w.bias")
spec_circuit = layernorm(sym_input, sym_weight, sym_bias)
return ModuleSpec(
spec_circuit,
arg_specs=[
ModuleArgSpec(sym_input, batchable=True),
ModuleArgSpec(sym_weight, batchable=False),
ModuleArgSpec(sym_bias, batchable=False),
],
)
LN_SPEC = layernorm_spec()
Each time a standard LayerNorm appears in our computation tree, we will have an `rc.Module` instance representing "a box containing a standard LayerNorm with these specific weights and these specific inputs."
In other words, if the `ModuleSpec` is like a function with argument names and a body, the `Module` instance is like a function call site where specific values are supplied for the arguments.
`Module` has a method `substitute()` which uses the same underlying logic as `Expander` to substitute the arguments for the placeholders. Importantly, this happens by name and not using a `Matcher`, so it's necessary that within the `ModuleSpec` all distinct `Symbol`s have unique names. The result of substitution is a regular node - an `Add` in this case.
When you `evaluate()` a `Module`, it just does `substitute` to replace itself with a regular node, and then calls `evaluate` on the regular node.
To reiterate: each place in the graph where a LayerNorm is used will be a different `Module` instance because the inputs are different, but all the `Module` instances will share the same `ModuleSpec`.
Below we show how to make a `Module` version of LayerNorm. After this, you'll do the MLP and Attention `Module` implementations on your own.
Exercise: Run the cell and check out the notation in the printed `Module`. Note that the "!" means that when we evaluate the `Module`, all occurrences of the name on the right are replaced with the value on the left.
The first letter of "ftt" means that the first flag (batchable) is `False`; the remaining letters mean that the other two flags (which we'll completely ignore today) are `True`.
Exercise: Why do we need to unpack the dictionary literal with `**` below?
Solution - Dictionary Unpacking
It's because the symbol names have a period in them, which makes them not valid Python identifiers. If the name was "ln_input" instead, then we could just pass `ln_input=Array(actual_input)` without requiring unpacking.
print("LayerNorm module should match earlier version")
some_layernorm = Module(
spec=LN_SPEC,
name="ln1.call",
**{
"ln.input": Array(actual_input, name="ln1_actual_input"),
"ln.w.scale": Array(actual_weight, name="ln1_actual_scale"),
"ln.w.bias": Array(actual_bias, name="ln1_actual_bias"),
}
)
some_layernorm.print()
t.testing.assert_close(some_layernorm.evaluate(), computable_ln.evaluate())
print("Two LayerNorm modules should have equal specs:")
another_layernorm = Module(
spec=LN_SPEC,
name="ln2.call",
**{
"ln.input": Array(t.randn((50,)), name="ln2_real_input"),
"ln.w.scale": Array(t.randn((50,)), name="ln2_real_scale"),
"ln.w.bias": Array(t.randn((50,)), name="ln2_real_bias"),
}
)
assert some_layernorm.spec == another_layernorm.spec
The first child of a `Module` is its `ModuleSpec`'s circuit, followed by the concrete argument values in the same order as they appear in the `ModuleArgSpec`s. In my case, I wrote the input at index 1, the scale at index 2, and the bias at index 3, but yours may be different.
Exercise: Write an `IterativeMatcher(...).children_matcher(...)` that matches the `ln1_actual_scale` node when called on `some_layernorm`.
some_layernorms_scale: Array
"TODO: YOUR CODE HERE"
assert some_layernorms_scale.name == "ln1_actual_scale"
It's always a good idea to do sanity checks when you're writing specs. Two things that should be true are:
- Each `Symbol` should appear somewhere in the "body" of the spec. This is like checking a function for arguments that are provided but not used.
- Each of your `Symbol`s needs a unique name. This is like checking a function for two arguments with the same name - if this happened, the library wouldn't know which one to substitute into the body.
LN_SPEC.check_all_inputs_used()
LN_SPEC.check_unique_arg_names()
We will be writing a `Module` for the attention layer, but often we want to "open the box" on this one and interact with the internals. We need a convenient way to get rid of the `Module` and just have regular nodes that we can do regular rewrites on.
In the analogy where a `Module` node represents a function call, we want to "inline" the call: not have a function call at all, but instead substitute in the body of the function with all the arguments replaced.
Recall that when we `evaluate` a `Module`, it just calls its own `substitute` method to turn into a regular node. We can just call `substitute` ourselves to do the inlining! This new form is extensionally equal.
Exercise: compare the before and after printouts. Precisely what did `substitute` do?
print("Before substitute:")
some_layernorm.print()
print("\nAfter substitute: ")
some_layernorm.substitute().print()
t.testing.assert_close(some_layernorm.evaluate(), some_layernorm.substitute().evaluate())
Naturally, Modules can have other Modules as arguments. This works as you'd expect, and you can perform matching and update operations as usual.
Exercise: make `updated_outer` be the same as `outer_layernorm` except that the inner module `some_layernorm` is substituted away. It should be one line.
outer_layernorm = Module(
spec=LN_SPEC,
name="ln2.call",
**{
"ln.input": some_layernorm,
"ln.w.scale": Array(t.randn((8,)), name="ln2_real_scale"),
"ln.w.bias": Array(t.randn((8,)), name="ln2_real_bias"),
}
)
updated_outer: Circuit
"TODO: YOUR CODE HERE"
ln1_call = rc.IterativeMatcher("ln2.call").children_matcher({1}).get_unique(updated_outer)
assert ln1_call.name == "ln1.call"
assert isinstance(ln1_call, Add)
We'll cover one last subtlety with `Module` and then get back to building GPT-2.
Once we have multiple `Module`s inside each other, it's legal to replace a placeholder `Symbol` with another placeholder `Symbol` during substitution. This works as long as all the `Symbol` instances are gone at the end.
Exercise: inspect the following Circuit and explain to your partner what is happening. When we evaluate this Circuit, does it matter which of the two `Module`s gets removed first via `substitute()`? Why or why not?
Solution - substitution order
`x_doubler.substitute()` means replacing `a` and `b` with `x`, and `bind_x.substitute()` means replacing `x` with `x_arr`.
Just like substituting variables in a system of equations, the order of substitution doesn't affect the final evaluation result.
a_sym = sym((), name="a")
b_sym = sym((), name="b")
x_sym = sym((), name="x")
ADDER = ModuleSpec(
Add(a_sym, b_sym, name="a_plus_b"),
arg_specs=[ModuleArgSpec(a_sym, batchable=True), ModuleArgSpec(b_sym, batchable=True)],
)
dbl_module = rc.Module(ADDER, name="x_doubler", a=x_sym, b=x_sym)
BIND_SPEC = ModuleSpec(dbl_module, arg_specs=[ModuleArgSpec(x_sym, batchable=True)])
bind_x = rc.Module(BIND_SPEC, x=Array(t.tensor([2.0, 4.0]), name="x_arr"), name="bind_x")
bind_x.print()
Exercise: implement a `Module` version of the MLP in GPT-2. Again, ignore the flags besides `batchable`.
The variable `bindings` in the test code uses our internal conventions for naming the `Symbol`s - the `m` stands for MLP and `w` stands for "weight". While these are just conventions, it'll be helpful to get used to reading them, as you'll frequently be writing `Matcher`s that match these names.
Note that the input to the MLP `Module` should be declared as shape `(HIDDEN,)`. Just like the LayerNorm, this reminds us that the MLP does the same operation at every sequence position. When we use this in the full transformer, since the MLP input is `batchable`, `Expander` will take care of the leading batch and sequence dimensions for us.
For the nonlinearity, use `rc.gelu`.
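As a quick sanity check (a sketch; the MLP test below likewise assumes `rc.gelu` agrees with PyTorch's gelu):
xs = t.tensor([-1.0, 0.0, 1.0])
# The two printed tensors should agree (up to floating point noise).
print(rc.gelu(Array(xs)).evaluate(), F.gelu(xs))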
I'm confused about what symbolic shapes to use.
After the first projection, the shape in all common transformers as of this writing will be `4 * HIDDEN` - it's fine to just write it this way. To be extra generic, you could use `MLP_PROJ` instead, because there's nothing special or optimal about the number 4.
def mlp(inp: Circuit, linear1_weight: Circuit, linear1_bias: Circuit, linear2_weight: Circuit, linear2_bias: Circuit) -> Circuit:
"""Return a Circuit computing GPT's standard MLP."""
"TODO: YOUR CODE HERE"
pass
def mlp_spec() -> ModuleSpec:
"""Return a ModuleSpec representing GPT's standard MLP."""
"TODO: YOUR CODE HERE"
pass
print("MLP: Testing with random weights - should match PyTorch")
m_input = t.randn((8,))
m_proj_in = t.randn((32, 8))
m_in_bias = t.randn((32,))
m_proj_out = t.randn((8, 32))
m_out_bias = t.randn((8,))
t.manual_seed(0)
bindings: dict[str, Circuit] = {
"m.input": Array(m_input),
"m.w.proj_in": Array(m_proj_in),
"m.w.in_bias": Array(m_in_bias),
"m.w.proj_out": Array(m_proj_out),
"m.w.out_bias": Array(m_out_bias),
}
MLP_SPEC = mlp_spec()
MLP_SPEC.check_all_inputs_used()
MLP_SPEC.check_unique_arg_names()
mlp_call_site = Module(spec=MLP_SPEC, name="m.call", **bindings)
expected = F.linear(F.gelu(F.linear(m_input, m_proj_in, m_in_bias)), m_proj_out, m_out_bias)
t.testing.assert_close(mlp_call_site.evaluate(), expected)
Now let's work our way through the network and get each section working in isolation.
Circuits provides two ways of indexing. For indexing by values that are known at the time you're creating the Circuit, use `rc.Index` (but note that there are limitations on the types of indexes that `rc.Index` supports). Where the index is based on the output of a Circuit, as is the case here, use `rc.GeneralFunction.gen_index`. When calling `rc.GeneralFunction.gen_index`, leave the `batch_x` argument at the default value `False`.
Exercise: complete `token_embed`. It should be one line.
def token_embed(embed_weight: Circuit, input_ids: Circuit) -> Circuit:
"""
embed_weight: shape (vocab_size, hidden_size)
input_ids: shape (batch, seq), dtype int
out: shape (batch, seq, hidden_size)
You can assume sizes aren't symbolic.
"""
"TODO: YOUR CODE HERE"
pass
weight_arr = Array(t.randn((10, 6)))
input_ids_arr = Array(t.tensor([1, 3, 5], dtype=t.int64))
actual_embeds = token_embed(weight_arr, input_ids_arr).evaluate()
t.testing.assert_close(actual_embeds[0], weight_arr.value[1])
t.testing.assert_close(actual_embeds[1], weight_arr.value[3])
t.testing.assert_close(actual_embeds[2], weight_arr.value[5])
Exercise: complete `pos_embed`.
Tip: `Index` does support indexing by a `Tensor`, but when the `Tensor` is multidimensional, `Index` indexes into each dimension independently, which is different from PyTorch and not what you want. Use `rc.GeneralFunction.gen_index` instead.
def pos_embed(pos_embed_weight: Circuit, input_ids: Circuit) -> Circuit:
"""
pos_embed_weight: shape (max_position_embeddings, hidden_size)
input_ids: shape (batch, seq), dtype int
out: shape (batch, seq, hidden_size)
You can assume sizes aren't symbolic.
"""
"TODO: YOUR CODE HERE"
pass
max_len = 8
vocab_size = 10
hidden = 4
input_ids = t.tensor([[9, 2, 0, 0, 3], [2, 5, 3, 4, 5]])
pos_embed_weight = t.randn((max_len, hidden))
actual = pos_embed(Array(pos_embed_weight), Array(input_ids)).evaluate()
expected = t.stack((pos_embed_weight[:5], pos_embed_weight[:5]))
t.testing.assert_close(actual, expected)
Before we can do the attention module, we need to compute our attention mask.
- The tokenizer returns a dictionary with a key "attention_mask". The value is a `Tensor` with shape `(batch, seq)` which contains 1 if this is a valid token to attend to, and 0 if this is a padding token that shouldn't be attended to. To disambiguate, I'll call this the "padding mask".
- In an autoregressive transformer, we need a tensor of shape `(seq_q, seq_k)` which has `causal_mask[q][k] = 0 if k > q`. For speed, this tensor can be allocated once and shared between all the attention heads in the network. I'll call this the "causal mask".
- When we actually mask the attention scores, we have a tensor of shape `(batch, seq_q, seq_k)` with 0 for valid (this is the opposite of the meaning of 0 previously) and -10000 for invalid. This doesn't have a standard name; I'll call this the "additive attention mask" because you add it to the attention scores. You'll need to compute this in `apply_attention_mask` soon.
Exercise: implement `causal_mask`.
def causal_mask(seq_len: int) -> t.Tensor:
"""
Return shape (seq_q, seq_k) of dtype float32.
"""
"TODO: YOUR CODE HERE"
pass
mask = causal_mask(5)
for q in range(5):
for k in range(5):
assert mask[q, k].item() == (0.0 if k > q else 1.0)
Exercise: implement the self-attention block. Note that we're omitting the batch dimension - since our input is batchable, `Expander` is able to prepend a batch dimension later without any issue.
I'm confused about how to set invalid attention scores to -10000.0!
While it's possible to write a custom `GeneralFunction` for this, it's tidier to use the basic classes whenever possible to avoid problems with serialization.
You can do this using `Einsum`, `Scalar`, and `Add` by first multiplying the invalid positions by 0, then adding -10000.0 to them.
The multiply by zero is not strictly necessary. Even if the attention score was pretty large to begin with, after adding -10000 and exponentiating, the result should be basically zero (or identically zero depending on the floating point precision). We're just following the reference implementation here.
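For intuition only, the same arithmetic in plain PyTorch (not a Circuit - you'll still need to express it with `Einsum`, `Scalar`, and `Add` in your implementation) looks like this:
scores = t.randn((1, 3, 3))
mask = t.tensor([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
# Valid positions keep their score; invalid positions become exactly -10000.0.
masked = scores * mask + (mask - 1.0) * 10000.0
print(masked)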
def apply_attention_mask(attn_scores: Circuit, mask: Circuit) -> Circuit:
"""Return attn_scores with invalid scores set to exactly -10000.0.
attn_scores: shape (head, seq_q, seq_k).
mask: shape (seq_q, seq_k). Contains 1.0 if this is a valid token to attend to, and 0.0 otherwise.
"""
"TODO: YOUR CODE HERE"
pass
print("Testing attention mask with 1 head, 3 positions and 3rd position is padding:")
sample_mask = t.tensor([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
sample_scores = t.arange(9, dtype=t.float32).view((1, 3, 3))
mask_circ = apply_attention_mask(Array(sample_scores), Array(sample_mask))
actual = mask_circ.evaluate()
expected = t.tensor([[[0.0, -10000.0, -10000.0], [3.0, 4.0, -10000.0], [-10000.0, -10000.0, -10000.0]]])
t.testing.assert_close(actual, expected)
print("Testing by comparing attention mask application with reference circuit")
seq_len = 10
rand_scores = t.randn((1, seq_len, seq_len))
rand_mask = causal_mask(seq_len).float()
your_mask_circ = apply_attention_mask(Array(rand_scores), Array(rand_mask))
your_mask_circ_actual = your_mask_circ.evaluate()
ref_attention_mask_circ_expanded = remix_d2_utils.get_ref_attn_mask_expanded(
ref_circuit, Array(rand_scores), Array(rand_mask)
)
ref_attention_mask_circ_expected = ref_attention_mask_circ_expanded.evaluate()
t.testing.assert_close(your_mask_circ_actual, ref_attention_mask_circ_expected)
Exercise: implement `attention_scores`. You'll likely want a helper function to reduce duplication between the `q`, `k`, and `v` projections.
I'm confused about the shapes or the calculations here!
Try working out the test example's first head scores with pencil and paper.
If you feel stuck, look at w2d1_solution.py for some more guidance or call a TA - don't spend a lot of time stuck here.
def attention_scores(x: Circuit, q_weight: Circuit, q_bias: Circuit, k_weight: Circuit, k_bias: Circuit) -> Circuit:
"""
Return the attention pattern after scaling but before softmax or attention masking.
pattern[head, q, k] should be the match between a query at sequence position q and a key at sequence position k.
IMPORTANT: recall that the output of `rc.last_dim_size(x)` has one fewer dimension than `x`, which is kinda wonky (I hope to get this changed so it always returns a scalar). Ensure you account for this.
x: shape (seq, hidden_size)
q_weight: shape (heads, head_size, hidden_size)
q_bias: shape (heads, head_size)
k_weight: shape (heads, head_size, hidden_size)
k_bias: shape (heads, head_size)
"""
"TODO: YOUR CODE HERE"
pass
print("Testing example with 2 heads, hidden dimension of 2, and head size of 1")
print("Note this won't test your scale factor, since head size is 1.")
sample_q_weight = t.tensor([[[1.0, 0.0]], [[1.0, -1.0]]])
sample_q_bias = t.tensor([[0.0], [0.0]])
sample_k_weight = t.tensor([[[0.0, 2.0]], [[1.0, 2.0]]])
sample_k_bias = t.tensor([[1.0], [2.0]])
sample_x = t.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
attn_score_circ = attention_scores(
Array(sample_x), Array(sample_q_weight), Array(sample_q_bias), Array(sample_k_weight), Array(sample_k_bias)
)
actual = attn_score_circ.evaluate()
print(actual)
expected = t.tensor(
[
[[1.0, 3.0, 1.0, -1.0], [0.0, 0.0, 0.0, -0.0], [-1.0, -3.0, -1.0, 1.0], [0.0, 0.0, 0.0, -0.0]],
[[3.0, 4.0, 1.0, 0.0], [-3.0, -4.0, -1.0, -0.0], [-3.0, -4.0, -1.0, -0.0], [3.0, 4.0, 1.0, 0.0]],
]
)
t.testing.assert_close(actual, expected)
print("Testing attention scores with real weights, random embeddings")
batch, seq_len = ref_input_ids.shape
rand_emb = t.rand((seq_len, config.hidden_size))
pretrained_weights = get_weights(circ_dict, bind_module)
attn_weight = pretrained_weights.weights_by_layer[0].attn
your_attn_score_circ = attention_scores(
Array(rand_emb), attn_weight["a.w.q"], attn_weight["a.w.q_bias"], attn_weight["a.w.k"], attn_weight["a.w.k_bias"]
)
actual = your_attn_score_circ.evaluate()
ref_attn_score_circ_expanded = remix_d2_utils.get_ref_attn_score_expanded(ref_circuit, Array(rand_emb), attn_weight)
ref_expected = ref_attn_score_circ_expanded.evaluate()
t.testing.assert_close(actual[0], ref_expected, atol=0.0001, rtol=0.0001)
Exercise: implement `attention` by calling your previous functions, and then doing the remaining steps.
def attention(
x: Circuit,
q_weight: Circuit,
q_bias: Circuit,
k_weight: Circuit,
k_bias: Circuit,
v_weight: Circuit,
v_bias: Circuit,
o_weight: Circuit,
o_bias: Circuit,
mask: Circuit,
) -> Circuit:
"""TODO: YOUR CODE HERE"""
pass
def attention_spec() -> ModuleSpec:
"""Return a ModuleSpec representing GPT's standard Attention Layer.
This is provided since it's just the same boilerplate as before.
Note that the mask is passed in; the mask will be calculated once per batch and shared between every attention layer.
"""
sym_inp = sym((SEQ, HIDDEN), "a.input")
sym_q_weight = sym((HEADS, HEAD_SIZE, HIDDEN), "a.w.q")
sym_q_bias = sym((HEADS, HEAD_SIZE), "a.w.q_bias")
sym_k_weight = sym((HEADS, HEAD_SIZE, HIDDEN), "a.w.k")
sym_k_bias = sym((HEADS, HEAD_SIZE), "a.w.k_bias")
sym_v_weight = sym((HEADS, HEAD_SIZE, HIDDEN), "a.w.v")
sym_v_bias = sym((HEADS, HEAD_SIZE), "a.w.v_bias")
sym_o_weight = sym((HEADS, HIDDEN, HEAD_SIZE), "a.w.o")
sym_o_bias = sym((HIDDEN,), "a.w.o_bias")
mask = sym((SEQ, SEQ), "a.mask")
spec_circuit = attention(
sym_inp,
sym_q_weight,
sym_q_bias,
sym_k_weight,
sym_k_bias,
sym_v_weight,
sym_v_bias,
sym_o_weight,
sym_o_bias,
mask,
)
argspecs = [
ModuleArgSpec(sym_inp, batchable=True),
ModuleArgSpec(sym_q_weight, batchable=False),
ModuleArgSpec(sym_q_bias, batchable=False),
ModuleArgSpec(sym_k_weight, batchable=False),
ModuleArgSpec(sym_k_bias, batchable=False),
ModuleArgSpec(sym_v_weight, batchable=False),
ModuleArgSpec(sym_v_bias, batchable=False),
ModuleArgSpec(sym_o_weight, batchable=False),
ModuleArgSpec(sym_o_bias, batchable=False),
ModuleArgSpec(mask, batchable=False),
]
return ModuleSpec(spec_circuit, argspecs)
ATTN_SPEC = attention_spec()
ATTN_SPEC.check_all_inputs_used()
ATTN_SPEC.check_unique_arg_names()
print("Testing attention with real weights, random embeddings")
batch, seq_len = ref_input_ids.shape
rand_emb = t.rand((batch, seq_len, config.hidden_size))
pretrained_weights = get_weights(circ_dict, bind_module)
pretrained_weights.set_attention_mask(Array(causal_mask(seq_len).float()))
attn_weight = pretrained_weights.weights_by_layer[0].attn
your_attn = Module(spec=ATTN_SPEC, name="a.call", **attn_weight, **{"a.input": Array(rand_emb)})
actual = your_attn.evaluate()
ref_attn_expanded = remix_d2_utils.get_ref_attn_expanded(ref_circuit, Array(rand_emb), attn_weight)
expected = ref_attn_expanded.evaluate()
t.testing.assert_close(actual, expected, atol=0.0001, rtol=0.0001)
Exercise: implement `gpt2_block`. This doesn't itself need to be a `Module`.
def gpt2_block(x: Circuit, weights: LayerWeights) -> Circuit:
"""Build Modules for each part and implement the skip connections."""
"TODO: YOUR CODE HERE"
pass
print("Testing block with real weights, random embeddings")
weights = pretrained_weights.weights_by_layer[0]
block = gpt2_block(Array(rand_emb), weights)
assert block.shape == (batch, seq_len, config.hidden_size)
actual = block.evaluate()
assert actual.shape == (batch, seq_len, config.hidden_size)
ref_block_expanded = remix_d2_utils.get_ref_block_expanded(ref_circuit, Array(rand_emb), weights)
expected = ref_block_expanded.evaluate()
t.testing.assert_close(actual, expected)
Exercise: implement `unembed` - it should just be one line.
def unembed(x: Circuit, unembed_weight: Circuit) -> Circuit:
"""
x: shape (hidden,)
unembed_weight: (vocab, hidden)
Out: (vocab,)
"""
"TODO: YOUR CODE HERE"
pass
def unembed_spec() -> ModuleSpec:
"""Return a ModuleSpec representing the unembedding layer of GPT.
This is provided since it's just the same boilerplate as before."""
sym_inp = sym((HIDDEN,), "unembed.input")
sym_weight = sym((LOG_LIKELIHOOD_CLASSES, HIDDEN), "unembed.weight")
spec_circuit = unembed(sym_inp, sym_weight)
argspecs = [ModuleArgSpec(sym_inp, batchable=True), ModuleArgSpec(sym_weight, batchable=False)]
return ModuleSpec(spec_circuit, argspecs)
UNEMBED_SPEC = unembed_spec()
UNEMBED_SPEC.check_all_inputs_used()
UNEMBED_SPEC.check_unique_arg_names()
vocab_size = 10
hidden = 4
weight = Array(t.randn((vocab_size, hidden)))
unembed_x = Array(weight.value[2])
unembed_module = Module(
spec=UNEMBED_SPEC, name="unembed.call", **{"unembed.weight": weight}, **{"unembed.input": unembed_x}
)
assert unembed_module.evaluate().shape == (vocab_size,)
Exercise: implement `gpt2` to wire together the pieces. Note that in GPT-2, the unembedding weight is "tied" - it's the same as the embedding weight.
def gpt2(input_ids: Circuit, weights: GPT2Weights) -> Circuit:
"""
input_ids: shape (batch, seq), dtype int. Start by computing the sum of token and positional embeddings.
Return logits of shape (batch, seq, vocab_size)
"""
"TODO: YOUR CODE HERE"
pass
If you've done everything right, your GPT-2 will make the same predictions as the reference model from before.
Mine did not work on the first try, so don't be discouraged if yours is broken. Debugging is an Authentic ML Experience - try to devise a systematic plan for narrowing down where the issue could be and remember to make liberal use of assertions and write additional test cases.
print("Real GPT-2 said: ")
eval_circuit(ref_circuit)
print("\n\nYour GPT-2 said: ")
your_circuit = gpt2(ref_input_ids, pretrained_weights)
eval_circuit(your_circuit)
Congratulations on completing the main content for the day :)
- Practice printing parts of the model. Play with the `PrintOptions` to find a balance where you print enough to see what's going on, but don't feel overwhelmed by too much detail.