Multihead attention #199
Merged: milancurcic merged 73 commits into modern-fortran:main from OneAdder:multihead_attention on Feb 21, 2025.
Commits (73 total; the diff below shows changes from 64 commits)
49e8507 (OneAdder) linear2d_layer forward implementation
feb7112 (OneAdder) linear2d_layer: temporarily remove api
8f320f0 (milancurcic) Don't expose the concrete layer type via nf
af4a5d7 (milancurcic) Plumbing of linear2d with input2d and linear2d
549d4e6 (OneAdder) linear2d_layer: add flatten2d layer
3218be0 (OneAdder) linear2d_layer: make linear2d layer work with input2d and flatten2d
39636f4 (OneAdder) update cmake
4cc7d1d (OneAdder) linear2d_layer: remove flatten2d layer
d863ce7 (OneAdder) linear2d_layer: remove public api
78eb17a (OneAdder) linear2d_layer: update cmakelists
567abc4 (milancurcic) Add linear2d example
32ac10d (OneAdder) linear2d_layer: remove redundant constructor args
edd169d (OneAdder) linear2d_layer: make example converge
aa5b83f (OneAdder) linear2d_layer: add loss stopping and more iterations
dd3ce33 (OneAdder) start impementing MultiHeadAttention
0ed77e8 (OneAdder) scaled dot product attention
d6e6f3e (OneAdder) combine attention heads
eb58006 (OneAdder) forward (not working)
452032e (OneAdder) rearrange attention dimensions in more efficient way
e06d39b (OneAdder) initial forward implementation for multi-head attention
519a6c8 (OneAdder) tests for multihead_attention%forward
9fdc7ae (OneAdder) multihead_attention: move most logic to subroutines (performance)
bc67331 (OneAdder) multihead_attention: update tests
a0a6fc4 (OneAdder) multihead_attention: concurrency
f8101af (OneAdder) multihead_attention: proof of concept backward (works, but not mathem…
63cce11 (OneAdder) multihead_attention: fix minor scaling issue
dfb8842 (OneAdder) multihead_attention: complete backward implementation
adcf5e6 (OneAdder) multihead_attention: add comments for forward prop
650e47c (OneAdder) multihead_attention: add tests for backward
3d16161 (OneAdder) multihead_attention: adjust expected test values for updated scaling
dcae5d6 (OneAdder) multihead_attention: calculate scaling factor only once
9fceae7 (OneAdder) multihead_attention: use heap-allocated arrays during back prop
248e124 (OneAdder) multihead_attention: use heap-allocated arrays in forward
4693028 (OneAdder) multihead_attention: set values from correct shape to tests
32dd628 (OneAdder) multihead_attention: fix issues with shapes (softmax prime became eve…
33c33b9 (OneAdder) multihead_attention: minor refactoring and optimization
40c3f2b (OneAdder) multihead_attention: fix comments
6a607b0 (OneAdder) multihead_attention: tests, add checks for attention weights
5fc5a5b (OneAdder) multihead_attention: remove some of the copypaste comments
65fd88d (OneAdder) multihead_attention: optimize shapes
fbc132d (OneAdder) multihead_attention: params api
5422e4c (OneAdder) multihead_attention: fix incorrect dw bug
39637e7 (OneAdder) multihead_attention: tests for updated parameters
60a49db (OneAdder) multihead_attention: remove reshape crutches
7ab7769 (OneAdder) multihead_attention: rename common forward and backward calls
20c5eb0 (OneAdder) multihead_attention: tidy mha up
6098533 (OneAdder) multihead_attention: self attention
66b5023 (OneAdder) multihead_attention: add cross attention
ac813aa (OneAdder) multihead_attention: add more comments
6b70f6b (OneAdder) multihead_attention: arrange attention into submodule
b622d55 (OneAdder) multihead_attention: update cmakelists
ce03b39 (OneAdder) multihead_attention: update attention in accordance with linear2d
41a80cd (OneAdder) multihead_attention: remove redundand constructor args for attention …
a84efd3 (OneAdder) multihead_attention: use pure and elemental where necessary
52c94c4 (OneAdder) multihead_attention: plumbing
66b539b (OneAdder) multihead_attention: add reference
992da67 (OneAdder) multihead_attention: remove rebase artifact
d93be41 (OneAdder) multihead_attention: remove redundant args
70272cb (OneAdder) multihead_attention: update tests
cb717f5 (OneAdder) multihead_attention: add the most important lines to tests
b7a6d06 (OneAdder) multihead_attention: simple MHA example
cb26afb (OneAdder) multihead_attention: update cmake
4c92e9c (OneAdder) multihead_attention: remove debug line from tests
df5f4cf (OneAdder) multihead_attention: set slightly higher margin for fp imprecision (d…
46786d6 (milancurcic) Merge upstream/main
6162783 (milancurcic) Rename mha_simple example
89abf22 (milancurcic) Update src/nf/nf_multihead_attention.f90
29b7d2e (milancurcic) Update src/nf/nf_multihead_attention.f90
e901479 (milancurcic) Update src/nf/nf_multihead_attention.f90
1eaee95 (milancurcic) Update src/nf/nf_multihead_attention.f90
588ecb1 (milancurcic) Tidy up
20ffe05 (milancurcic) Add self_attention to the layers table
e4c6548 (milancurcic) Merge branch 'multihead_attention' of github.com:OneAdder/neural-fort…
New file (+37 lines): the simple multi-head attention example program.
program simple
  use nf, only: dense, input, network, sgd, self_attention, flatten
  implicit none
  type(network) :: net
  real, allocatable :: x(:, :), y(:)
  integer, parameter :: num_iterations = 500
  integer :: n

  print '("Simple")'
  print '(60("="))'

  net = network([ &
    input(3, 8), &
    self_attention(4), &
    flatten(), &
    dense(2) &
  ])

  call net % print_info()

  allocate(x(3, 8))
  call random_number(x)

  y = [0.123456, 0.246802]

  do n = 0, num_iterations

    call net % forward(x)
    call net % backward(y)
    call net % update(optimizer=sgd(learning_rate=1.))

    if (mod(n, 50) == 0) &
      print '(i4,2(3x,f8.6))', n, net % predict(x)

  end do

end program simple
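As background (not part of the diff): self_attention(4) presumably splits the model dimension of 8 across 4 heads, so each head works on d_head = 8 / 4 = 2 features and, per the commit history ("scaled dot product attention", "calculate scaling factor only once"), computes the standard scaled dot-product attention

  Attention(Q, K, V) = softmax(Q K^T / sqrt(d_head)) V

before the heads are combined and projected back to the model dimension.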
New file (+66 lines): the nf_cross_attention_layer module.
module nf_cross_attention_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_activation, only: softmax
  use nf_linear2d_layer, only: linear2d_layer
  use nf_multihead_attention_layer, only: multihead_attention_layer

  implicit none

  type, extends(multihead_attention_layer) :: cross_attention_layer
    !! Cross Attention Layer
    !! Source:
    !! Bahdanau, D. (2014)
    !! Neural machine translation by jointly learning to align and translate.
    !! https://arxiv.org/pdf/1409.0473
    real, allocatable :: gradient(:, :, :)
  contains
    procedure :: forward
    procedure :: backward
    procedure :: init
  end type cross_attention_layer

  interface cross_attention_layer
    module function cross_attention_layer_cons(n_heads) result(res)
      !! This function returns the `cross_attention_layer` instance.
      integer, intent(in) :: n_heads
      type(cross_attention_layer) :: res
    end function cross_attention_layer_cons
  end interface cross_attention_layer

contains

  module function cross_attention_layer_cons(n_heads) result(res)
    !! This function returns the `cross_attention_layer` instance.
    integer, intent(in) :: n_heads
    type(cross_attention_layer) :: res
    res % n_heads = n_heads
  end function cross_attention_layer_cons

  pure module subroutine backward(self, input, gradient)
    !! Cross Attention back propagation
    class(cross_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)
    real, intent(in) :: gradient(:, :)

    call self % common_backward(input(1, :, :), gradient)
    self % gradient(1, :, :) = self % query_layer % gradient
    self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
  end subroutine backward

  pure module subroutine forward(self, input)
    !! Cross Attention forward propagation
    !! Input shape: (kind, sequence_length, model_dimension),
    !! where kind is 1 for query and 2 for key-value
    class(cross_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)

    call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
  end subroutine forward

  module subroutine init(self, input_shape)
    class(cross_attention_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    call self % init_base(input_shape)
    allocate(self % gradient(2, self % sequence_length, self % model_dimension))
  end subroutine init

end module nf_cross_attention_layer
Review comment: It is intentional that there is no plumbing for this one yet. I suggest we add it at a later stage, once we have more components for seq2seq models. For now it can be added like this, without any public access.
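As a rough illustration of that point, here is a minimal, hypothetical sketch of exercising the layer directly, with no public plumbing through nf. It is not code from this PR: it assumes init_base interprets input_shape as [sequence_length, model_dimension] (as the example's input(3, 8) suggests) and follows the (kind, sequence_length, model_dimension) input convention documented in forward, with kind 1 holding the query and kind 2 the key-value sequence.

program cross_attention_sketch
  ! Hypothetical usage sketch; not part of this PR's diff.
  use nf_cross_attention_layer, only: cross_attention_layer
  implicit none

  type(cross_attention_layer) :: attn
  real :: query(3, 8), key_value(3, 8)
  real :: input(2, 3, 8)    ! (kind, sequence_length, model_dimension)
  real :: d_output(3, 8)    ! upstream gradient w.r.t. the attention output

  ! Construct with 4 heads, then initialize; assumes init_base accepts
  ! input_shape = [sequence_length, model_dimension].
  attn = cross_attention_layer(n_heads=4)
  call attn % init([3, 8])

  call random_number(query)
  call random_number(key_value)
  call random_number(d_output)

  ! kind 1 carries the query sequence, kind 2 the key-value sequence.
  input(1, :, :) = query
  input(2, :, :) = key_value

  call attn % forward(input)
  call attn % backward(input, d_output)

  ! After backward, attn % gradient(1, :, :) holds the gradient w.r.t. the
  ! query and attn % gradient(2, :, :) the gradient w.r.t. the key-value input.
end program cross_attention_sketch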