Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multihead attention #199

Open
wants to merge 64 commits into
base: main
Choose a base branch
from

Conversation

OneAdder
Copy link
Collaborator

@OneAdder OneAdder commented Feb 9, 2025

Hello, Milan! I hope, I'm not bothering you too much with my pull requests, but this is a good one. At this stage it is a draft of MultiHead Attention. It cannot be merged until work on input2d_layer and linear2d_layer is completed.
Implementation of dropout would also help improve MHA, but it can be added later.

MultiHead Attention

MultiHead Attention is the main component of Transformer architecture, which is the most advanced modern approach in the area of Natural Language Processing, as well as some other areas.
Here I propose an implementation based on the Transformer article. It works and its output conforms with SOTA implementation in PyTorch.

Problems

  1. 3D. It is 3D so far. I plan on making it 2D when the infrastructure is ready (the aforementioned PRs). When it is done, I'll also optimize iterations over attention heads. I'll also try optimizing back prop, it looks very scary now.
  2. Cross Attention. My current implementation accepts query, key and value separately which will not work in the current paradigm. What can be done: implement it as solely Self Attention -- the input will be only one, then it will be copied thee times. However, this approach will not work for Cross Attention. For that, the layer has to be implemented as 3D Layer, while Self Attention will be 2D.
  3. No working example. I'll add one along with all the connections, when the 2D problem is solved.
  4. Commits History. Here I rebased from my Linear2D Layer. When it is merged, I'll rebase from main again, so, it will not be an issue. At this stage, only those two files are to be reviewed: nf_multihead_attention.f90 and test_multihead_attention_layer.f90.

Python Reference

Here is the snippet of code that uses PyTorch to calculate MultiHead Attention:

import torch
import numpy as np


mha = torch.nn.MultiheadAttention(num_heads=2, embed_dim=4, batch_first=True)
mha.in_proj_weight.data = torch.zeros(12, 4) + 0.1
mha.in_proj_bias.data = torch.zeros(12) + 0.11
mha.out_proj.weight.data = torch.zeros(4, 4) + 0.1
mha.out_proj.bias.data = torch.zeros(4) + 0.11

x = torch.tensor(
    np.array(
        [0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12],
    ).reshape(3, 4, 1, order='F').transpose(2, 0, 1),
    dtype=torch.float32,
    requires_grad=True,
)
out, weights = mha(x, x, x)
print('Output:', np.array(np.nditer(out.detach().numpy(), order='F')))
print('Attention Weights:', np.array(np.nditer(weights.detach().numpy(), order='F')))
gradient = torch.tensor(
    [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3.],
    requires_grad=True,
).reshape(1, 3, 4)
out.backward(gradient=gradient)
print('Gradient:', np.array(np.nditer(x.grad.numpy(), order='F')))

Output:

Output: [0.98224145 1.0040786  1.0044413  0.98224145 1.0040786  1.0044413
 0.98224145 1.0040786  1.0044413  0.98224145 1.0040786  1.0044413 ]
Weights: [0.07894507 0.02281102 0.02188467 0.44750854 0.46461242 0.46472138
 0.47354645 0.5125766  0.513394  ]
Gradient: [-0.02299126  0.38148475  0.45318538 -0.02299126  0.38148475  0.45318538
 -0.02299126  0.38148475  0.45318538 -0.02299126  0.38148475  0.45318538]

It is the same as in my tests.

@ricor07
Copy link
Collaborator

ricor07 commented Feb 9, 2025

Hello Michael, I saw your pull requests and I think what you do is very interesting. Could you take a look at mine? milancurcic#2

What you have to look here is not the locally connected 1d layer but rather the reshape architecture I am trying to make. Your help would be very appreciated. Thanks

@milancurcic
Copy link
Member

Amazing, thank you! Yes, let's wrap up Input2d in #198, then Linear2d in #197, to avoid the 3d crutch.

@OneAdder OneAdder force-pushed the multihead_attention branch from 48d93b2 to 86cd7c0 Compare February 14, 2025 17:37
@ricor07
Copy link
Collaborator

ricor07 commented Feb 14, 2025

@OneAdder i see you are talking about a 2d handling. Would you like to make this together? I have to make this as well since I'm implementing a conv 1d layer

@milancurcic
Copy link
Member

Hi guys, thanks for pushing this forward. Today I'm finishing a lot of busywork with some proposals so next week I'm getting back more actively with neural-fortran work, and will be able to contribute to the reviews more actively

@OneAdder
Copy link
Collaborator Author

@ricor07 Great idea! But we have a problem with generics again. The issues is that a predict_2d and predict_1d_batch have the same rank of input. I think we should simply make a separate generic for predict_batch. @milancurcic your thoughts?

@ricor07
Copy link
Collaborator

ricor07 commented Feb 15, 2025

Yes, I think we can make a generic predict. But I suggest you to create a new branch

@milancurcic
Copy link
Member

I think it's fine to make predict_batch its own generic name because it's getting in the way. 👍

@OneAdder
Copy link
Collaborator Author

@milancurcic Done, here: #198
@ricor07 There is still one piece of the puzzle missing: a general flatten layer with interfaces for both 3D and 2D. Do you want to implement it? Or I can do it together with linear2d layer

@ricor07
Copy link
Collaborator

ricor07 commented Feb 15, 2025

You can make it. I'll work on maxpool

@OneAdder OneAdder force-pushed the multihead_attention branch 3 times, most recently from f9e7a7c to 0900990 Compare February 17, 2025 11:07
@OneAdder OneAdder force-pushed the multihead_attention branch from 6a09663 to 992da67 Compare February 17, 2025 21:53
@milancurcic
Copy link
Member

Hi Michael, just in case you weren't aware, Ondrej has a procedural implementation of GPT-2, including a MHA subroutine here: https://github.com/certik/fastGPT/blob/main/gpt2.f90. You can use it as a reference Fortran implementation if you need.

When you're ready for me to start playing with this PR, just mark it Ready for review.

@OneAdder
Copy link
Collaborator Author

@milancurcic The attention backward and forward have both been done and tested like a week ago. I mostly used Attention Is All You Need paper as a reference as it is pretty straightforward and includes most of the formulae. And, besides, that example is not particularly useful here as it doesn't implement back prop. But we can reuse at a later point the piece that loads weights from HuggingFace.
I'm now working on figuring out how to create an example for MHA. So, most of the code is ready for review

@OneAdder
Copy link
Collaborator Author

Except one thing: redundant arguments. I'll fix that

@OneAdder
Copy link
Collaborator Author

OneAdder commented Feb 19, 2025

So, my example works, the full check list:

  • Compile threaded version (I still didn't manage to figure out what Debian packages I need)
  • (Not possible, all iterations are needed) Figure out how to optimize out two extra passes through n_heads without sacrificing readability
  • Remove redundant arguments in constructors
  • Add an example

@milancurcic I have this text classification example which can also be used as an example for MHA. But it's large and doesn't make much sense until I add Positional Encoding (WIP). Should I add it here anyway or should I add it a later point in a separate PR? I can come up with a small toy example here

@milancurcic
Copy link
Member

Excellent, thanks, I'll begin reviewing tomorrow morning. How large is it? In the past, I'd upload data tarballs as an attachment in a GitHub issue and use its permanent URL to download from programs. That way data storage doesn't go into git history. But there's some size limit. Our MNIST data used here is stored like that.

@milancurcic
Copy link
Member

For example of datasets stored here, see

character(*), parameter :: keras_snippets_baseurl = &
'https://github.com/neural-fortran/keras-snippets/files'
character(*), parameter :: neural_fortran_baseurl = &
'https://github.com/modern-fortran/neural-fortran/files'
character(*), parameter :: keras_cnn_mnist_url = &
keras_snippets_baseurl // '/8892585/keras_cnn_mnist.tar.gz'
character(*), parameter :: keras_dense_mnist_url = &
keras_snippets_baseurl // '/8788739/keras_dense_mnist.tar.gz'
character(*), parameter :: keras_reshape_url = &
keras_snippets_baseurl // '/9667603/keras_reshape.tar.gz'
character(*), parameter :: mnist_url = &
neural_fortran_baseurl // '/8498876/mnist.tar.gz'

@OneAdder OneAdder marked this pull request as ready for review February 19, 2025 18:36
@OneAdder
Copy link
Collaborator Author

@milancurcic I think it's ready for review! I'll add the complicated example later. At this stage I added a simple example that converges nicely and doesn't require datasets and extra deps


implicit none

type, extends(multihead_attention_layer) :: cross_attention_layer
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is intentional that there is no plumbing for this one yet. I suggest that we add it at later stage when we have more components for seq2seq models. At this stage it can be added like this: without any public access

logical :: res

res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y)))
end function allclose
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion for future: create nf_utils.f90 (or similar) and put this procedure there

@OneAdder OneAdder requested a review from milancurcic February 19, 2025 18:47
end do
end subroutine create_attention_matrix

pure module subroutine normalize_attention_matrix(self, attention_mask)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attention_mask is not accessible to the users by design at this point. It will be used by transformer decoder later and I'll add corresponding logic later

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants