Multihead attention #199
Merged 73 commits on Feb 21, 2025. The file changes below are shown from 64 of the 73 commits.

Commits
- 49e8507 linear2d_layer forward implementation (OneAdder, Feb 2, 2025)
- feb7112 linear2d_layer: temporarily remove api (OneAdder, Feb 14, 2025)
- 8f320f0 Don't expose the concrete layer type via nf (milancurcic, Feb 16, 2025)
- af4a5d7 Plumbing of linear2d with input2d and linear2d (milancurcic, Feb 16, 2025)
- 549d4e6 linear2d_layer: add flatten2d layer (OneAdder, Feb 16, 2025)
- 3218be0 linear2d_layer: make linear2d layer work with input2d and flatten2d (OneAdder, Feb 16, 2025)
- 39636f4 update cmake (OneAdder, Feb 16, 2025)
- 4cc7d1d linear2d_layer: remove flatten2d layer (OneAdder, Feb 16, 2025)
- d863ce7 linear2d_layer: remove public api (OneAdder, Feb 16, 2025)
- 78eb17a linear2d_layer: update cmakelists (OneAdder, Feb 16, 2025)
- 567abc4 Add linear2d example (milancurcic, Feb 17, 2025)
- 32ac10d linear2d_layer: remove redundant constructor args (OneAdder, Feb 17, 2025)
- edd169d linear2d_layer: make example converge (OneAdder, Feb 17, 2025)
- aa5b83f linear2d_layer: add loss stopping and more iterations (OneAdder, Feb 17, 2025)
- dd3ce33 start impementing MultiHeadAttention (OneAdder, Jan 31, 2025)
- 0ed77e8 scaled dot product attention (OneAdder, Jan 31, 2025)
- d6e6f3e combine attention heads (OneAdder, Jan 31, 2025)
- eb58006 forward (not working) (OneAdder, Jan 31, 2025)
- 452032e rearrange attention dimensions in more efficient way (OneAdder, Feb 5, 2025)
- e06d39b initial forward implementation for multi-head attention (OneAdder, Feb 5, 2025)
- 519a6c8 tests for multihead_attention%forward (OneAdder, Feb 6, 2025)
- 9fdc7ae multihead_attention: move most logic to subroutines (performance) (OneAdder, Feb 6, 2025)
- bc67331 multihead_attention: update tests (OneAdder, Feb 6, 2025)
- a0a6fc4 multihead_attention: concurrency (OneAdder, Feb 6, 2025)
- f8101af multihead_attention: proof of concept backward (works, but not mathem… (OneAdder, Feb 8, 2025)
- 63cce11 multihead_attention: fix minor scaling issue (OneAdder, Feb 9, 2025)
- dfb8842 multihead_attention: complete backward implementation (OneAdder, Feb 9, 2025)
- adcf5e6 multihead_attention: add comments for forward prop (OneAdder, Feb 9, 2025)
- 650e47c multihead_attention: add tests for backward (OneAdder, Feb 9, 2025)
- 3d16161 multihead_attention: adjust expected test values for updated scaling (OneAdder, Feb 9, 2025)
- dcae5d6 multihead_attention: calculate scaling factor only once (OneAdder, Feb 9, 2025)
- 9fceae7 multihead_attention: use heap-allocated arrays during back prop (OneAdder, Feb 9, 2025)
- 248e124 multihead_attention: use heap-allocated arrays in forward (OneAdder, Feb 9, 2025)
- 4693028 multihead_attention: set values from correct shape to tests (OneAdder, Feb 9, 2025)
- 32dd628 multihead_attention: fix issues with shapes (softmax prime became eve… (OneAdder, Feb 9, 2025)
- 33c33b9 multihead_attention: minor refactoring and optimization (OneAdder, Feb 9, 2025)
- 40c3f2b multihead_attention: fix comments (OneAdder, Feb 9, 2025)
- 6a607b0 multihead_attention: tests, add checks for attention weights (OneAdder, Feb 9, 2025)
- 5fc5a5b multihead_attention: remove some of the copypaste comments (OneAdder, Feb 9, 2025)
- 65fd88d multihead_attention: optimize shapes (OneAdder, Feb 12, 2025)
- fbc132d multihead_attention: params api (OneAdder, Feb 14, 2025)
- 5422e4c multihead_attention: fix incorrect dw bug (OneAdder, Feb 14, 2025)
- 39637e7 multihead_attention: tests for updated parameters (OneAdder, Feb 14, 2025)
- 60a49db multihead_attention: remove reshape crutches (OneAdder, Feb 16, 2025)
- 7ab7769 multihead_attention: rename common forward and backward calls (OneAdder, Feb 16, 2025)
- 20c5eb0 multihead_attention: tidy mha up (OneAdder, Feb 16, 2025)
- 6098533 multihead_attention: self attention (OneAdder, Feb 16, 2025)
- 66b5023 multihead_attention: add cross attention (OneAdder, Feb 16, 2025)
- ac813aa multihead_attention: add more comments (OneAdder, Feb 16, 2025)
- 6b70f6b multihead_attention: arrange attention into submodule (OneAdder, Feb 16, 2025)
- b622d55 multihead_attention: update cmakelists (OneAdder, Feb 16, 2025)
- ce03b39 multihead_attention: update attention in accordance with linear2d (OneAdder, Feb 17, 2025)
- 41a80cd multihead_attention: remove redundand constructor args for attention … (OneAdder, Feb 17, 2025)
- a84efd3 multihead_attention: use pure and elemental where necessary (OneAdder, Feb 17, 2025)
- 52c94c4 multihead_attention: plumbing (OneAdder, Feb 17, 2025)
- 66b539b multihead_attention: add reference (OneAdder, Feb 17, 2025)
- 992da67 multihead_attention: remove rebase artifact (OneAdder, Feb 17, 2025)
- d93be41 multihead_attention: remove redundant args (OneAdder, Feb 19, 2025)
- 70272cb multihead_attention: update tests (OneAdder, Feb 19, 2025)
- cb717f5 multihead_attention: add the most important lines to tests (OneAdder, Feb 19, 2025)
- b7a6d06 multihead_attention: simple MHA example (OneAdder, Feb 19, 2025)
- cb26afb multihead_attention: update cmake (OneAdder, Feb 19, 2025)
- 4c92e9c multihead_attention: remove debug line from tests (OneAdder, Feb 19, 2025)
- df5f4cf multihead_attention: set slightly higher margin for fp imprecision (d… (OneAdder, Feb 19, 2025)
- 46786d6 Merge upstream/main (milancurcic, Feb 21, 2025)
- 6162783 Rename mha_simple example (milancurcic, Feb 21, 2025)
- 89abf22 Update src/nf/nf_multihead_attention.f90 (milancurcic, Feb 21, 2025)
- 29b7d2e Update src/nf/nf_multihead_attention.f90 (milancurcic, Feb 21, 2025)
- e901479 Update src/nf/nf_multihead_attention.f90 (milancurcic, Feb 21, 2025)
- 1eaee95 Update src/nf/nf_multihead_attention.f90 (milancurcic, Feb 21, 2025)
- 588ecb1 Tidy up (milancurcic, Feb 21, 2025)
- 20ffe05 Add self_attention to the layers table (milancurcic, Feb 21, 2025)
- e4c6548 Merge branch 'multihead_attention' of github.com:OneAdder/neural-fort… (milancurcic, Feb 21, 2025)
CMakeLists.txt (4 changes: 4 additions & 0 deletions)
@@ -20,6 +20,7 @@ add_library(neural-fortran
src/nf/nf_base_layer.f90
src/nf/nf_conv2d_layer.f90
src/nf/nf_conv2d_layer_submodule.f90
src/nf/nf_cross_attention_layer.f90
src/nf/nf_datasets.f90
src/nf/nf_datasets_submodule.f90
src/nf/nf_datasets_mnist.f90
@@ -45,6 +46,8 @@ add_library(neural-fortran
src/nf/nf_maxpool2d_layer.f90
src/nf/nf_maxpool2d_layer_submodule.f90
src/nf/nf_metrics.f90
src/nf/nf_multihead_attention.f90
src/nf/nf_multihead_attention_submodule.f90
src/nf/nf_network.f90
src/nf/nf_network_submodule.f90
src/nf/nf_optimizers.f90
@@ -53,6 +56,7 @@ add_library(neural-fortran
src/nf/nf_random.f90
src/nf/nf_reshape_layer.f90
src/nf/nf_reshape_layer_submodule.f90
src/nf/nf_self_attention_layer.f90
src/nf/io/nf_io_binary.f90
src/nf/io/nf_io_binary_submodule.f90
)
example/CMakeLists.txt (1 change: 1 addition & 0 deletions)
@@ -6,6 +6,7 @@ foreach(execid
simple
sine
quadratic
mha_simple
)
add_executable(${execid} ${execid}.f90)
target_link_libraries(${execid} PRIVATE
example/mha_simple.f90 (37 changes: 37 additions & 0 deletions)
@@ -0,0 +1,37 @@
program simple
  use nf, only: dense, input, network, sgd, self_attention, flatten
  implicit none
  type(network) :: net
  real, allocatable :: x(:, :), y(:)
  integer, parameter :: num_iterations = 500
  integer :: n

  print '("Simple")'
  print '(60("="))'

  net = network([ &
    input(3, 8), &
    self_attention(4), &
    flatten(), &
    dense(2) &
  ])

  call net % print_info()

  allocate(x(3, 8))
  call random_number(x)

  y = [0.123456, 0.246802]

  do n = 0, num_iterations

    call net % forward(x)
    call net % backward(y)
    call net % update(optimizer=sgd(learning_rate=1.))

    if (mod(n, 50) == 0) &
      print '(i4,2(3x,f8.6))', n, net % predict(x)

  end do

end program simple
src/nf.f90 (4 changes: 3 additions & 1 deletion)
@@ -3,7 +3,7 @@ module nf
use nf_datasets_mnist, only: label_digits, load_mnist
use nf_layer, only: layer
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, self_attention
use nf_loss, only: mse, quadratic
use nf_metrics, only: corr, maxabs
use nf_network, only: network
@@ -12,4 +12,6 @@ module nf
gaussian, linear, relu, leaky_relu, &
sigmoid, softmax, softplus, step, tanhf, &
celu
use nf_linear2d_layer, only: linear2d_layer
use nf_multihead_attention_layer, only: multihead_attention_layer
end module nf
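
The two use statements added above re-export the concrete linear2d_layer and multihead_attention_layer types through nf, so they can be imported and driven directly, outside of a network. A minimal, hedged sketch of what that enables (not part of the PR; it assumes the base type's constructor takes only n_heads, mirroring the self and cross attention constructors below, and that init_base and common_forward behave as the cross attention code in this diff uses them):

program mha_direct
  ! Hedged sketch (not part of this PR): drive the multihead attention base type directly.
  ! Assumptions: the constructor takes only n_heads, and init_base expects
  ! [sequence_length, model_dimension], as the cross attention code suggests.
  use nf, only: multihead_attention_layer
  implicit none
  type(multihead_attention_layer) :: mha
  real :: query(5, 8), key(5, 8), value(5, 8)

  mha = multihead_attention_layer(n_heads=4)
  call mha % init_base([5, 8])

  call random_number(query)
  call random_number(key)
  call random_number(value)

  call mha % common_forward(query, key, value)
end program mha_direct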
src/nf/nf_cross_attention_layer.f90 (66 changes: 66 additions & 0 deletions)
@@ -0,0 +1,66 @@
module nf_cross_attention_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_activation, only: softmax
  use nf_linear2d_layer, only: linear2d_layer
  use nf_multihead_attention_layer, only: multihead_attention_layer

  implicit none

  type, extends(multihead_attention_layer) :: cross_attention_layer
Review comment (OneAdder, collaborator, author): It is intentional that there is no plumbing for this one yet. I suggest that we add it at a later stage, when we have more components for seq2seq models. At this stage it can be added like this, without any public access.

    !! Cross Attention Layer
    !! Source:
    !! Bahdanau, D. (2014)
    !! Neural machine translation by jointly learning to align and translate.
    !! https://arxiv.org/pdf/1409.0473
    real, allocatable :: gradient(:, :, :)
  contains
    procedure :: forward
    procedure :: backward
    procedure :: init
  end type cross_attention_layer

  interface cross_attention_layer
    module function cross_attention_layer_cons(n_heads) result(res)
      !! This function returns the `cross_attention_layer` instance.
      integer, intent(in) :: n_heads
      type(cross_attention_layer) :: res
    end function cross_attention_layer_cons
  end interface cross_attention_layer

contains
  module function cross_attention_layer_cons(n_heads) result(res)
    !! This function returns the `cross_attention_layer` instance.
    integer, intent(in) :: n_heads
    type(cross_attention_layer) :: res
    res % n_heads = n_heads
  end function cross_attention_layer_cons

  pure module subroutine backward(self, input, gradient)
    !! Cross Attention Back propagation
    class(cross_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)
    real, intent(in) :: gradient(:, :)

    call self % common_backward(input(1, :, :), gradient)
    self % gradient(1, :, :) = self % query_layer % gradient
    self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
  end subroutine backward

  pure module subroutine forward(self, input)
    !! Cross Attention Forward propagation
    !! Input Shape (kind, sequence_length, model_dimension)
    !! where kind is 1 for Query and 2 for Key-Value
    class(cross_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)

    call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
  end subroutine forward

  module subroutine init(self, input_shape)
    class(cross_attention_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    call self % init_base(input_shape)
    allocate(self % gradient(2, self % sequence_length, self % model_dimension))
  end subroutine init
end module nf_cross_attention_layer
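
Since cross_attention_layer has no network plumbing yet (see the review comment above), the only way to exercise it is to drive the type directly. A minimal, hedged sketch of that, not taken from the PR; it assumes init_base interprets the shape as [sequence_length, model_dimension] and sets up the internal query/key/value sub-layers:

program cross_attention_demo
  ! Hedged sketch (not part of this PR) of using cross_attention_layer directly.
  ! The packed input follows the convention documented in forward above:
  ! input(1,:,:) is the query sequence, input(2,:,:) the key-value sequence.
  use nf_cross_attention_layer, only: cross_attention_layer
  implicit none
  type(cross_attention_layer) :: attn
  real :: input(2, 3, 4)    ! (kind, sequence_length=3, model_dimension=4)
  real :: d_output(3, 4)    ! gradient w.r.t. the attention output

  attn = cross_attention_layer(n_heads=2)
  call attn % init([3, 4])  ! assumed [sequence_length, model_dimension]

  call random_number(input)
  call random_number(d_output)

  call attn % forward(input)
  call attn % backward(input, d_output)
  ! attn % gradient(1,:,:) now holds the query gradient,
  ! attn % gradient(2,:,:) the combined key-value gradient.
end program cross_attention_demo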
src/nf/nf_layer_constructors.f90 (12 changes: 11 additions & 1 deletion)
@@ -8,7 +8,7 @@ module nf_layer_constructors
implicit none

private
public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, self_attention

interface input

@@ -195,6 +195,16 @@ module function linear2d(out_features) result(res)
!! Resulting layer instance
end function linear2d

module function self_attention(n_heads) result(res)
!! Rank-2 (sequence_length, model_dimension) self attention constructor.
!! sequence_length and model_dimension are determined at layer initialization, based on the
!! output shape of the previous layer.
integer, intent(in) :: n_heads
!! Number of attention heads
type(layer) :: res
!! Resulting layer instance
end function self_attention

end interface

end module nf_layer_constructors
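
The constructor above infers sequence_length and model_dimension from whatever 2-d layer precedes it, so self_attention composes naturally with input and linear2d. A hedged variation on example/mha_simple.f90, not taken from the PR, where linear2d first projects the input features to a model dimension that the heads then split (presumably the head count must divide the model dimension, as in the usual multi-head formulation):

program self_attention_stack
  ! Hedged sketch (not part of this PR): self_attention after a linear2d projection.
  ! Layer shapes are inferred at network construction time.
  use nf, only: dense, flatten, input, linear2d, network, self_attention
  implicit none
  type(network) :: net

  net = network([ &
    input(16, 12), &      ! sequence_length = 16, 12 features per position
    linear2d(8), &        ! project features to model_dimension = 8
    self_attention(4), &  ! 4 heads over the 8-dimensional model space; output stays (16, 8)
    flatten(), &          ! 16 * 8 = 128 features
    dense(4) &
  ])

  call net % print_info()
end program self_attention_stack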
src/nf/nf_layer_constructors_submodule.f90 (9 changes: 9 additions & 0 deletions)
@@ -10,6 +10,7 @@
use nf_maxpool2d_layer, only: maxpool2d_layer
use nf_reshape_layer, only: reshape3d_layer
use nf_linear2d_layer, only: linear2d_layer
use nf_self_attention_layer, only: self_attention_layer
use nf_activation, only: activation_function, relu, sigmoid

implicit none
@@ -160,4 +161,12 @@ module function linear2d(out_features) result(res)

end function linear2d

module function self_attention(n_heads) result(res)
integer, intent(in) :: n_heads
type(layer) :: res

res % name = 'self_attention'
allocate(res % p, source=self_attention_layer(n_heads))
end function self_attention

end submodule nf_layer_constructors_submodule
src/nf/nf_layer_submodule.f90 (45 changes: 44 additions & 1 deletion)
@@ -10,6 +10,7 @@
use nf_maxpool2d_layer, only: maxpool2d_layer
use nf_reshape_layer, only: reshape3d_layer
use nf_linear2d_layer, only: linear2d_layer
use nf_self_attention_layer, only: self_attention_layer
use nf_optimizers, only: optimizer_base_type

contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
call this_layer % backward(prev_layer % output, gradient)
type is(linear2d_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(self_attention_layer)
call this_layer % backward(prev_layer % output, gradient)
end select

end select
@@ -72,6 +75,19 @@ pure module subroutine backward_2d(self, previous, gradient)
call this_layer % backward(prev_layer % output, gradient)
type is(linear2d_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(self_attention_layer)
call this_layer % backward(prev_layer % output, gradient)
end select

type is(self_attention_layer)

select type(prev_layer => previous % p)
type is(input2d_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(linear2d_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(self_attention_layer)
call this_layer % backward(prev_layer % output, gradient)
end select

end select
@@ -219,6 +235,20 @@ pure module subroutine forward(self, input)
call this_layer % forward(prev_layer % output)
type is(linear2d_layer)
call this_layer % forward(prev_layer % output)
type is(self_attention_layer)
call this_layer % forward(prev_layer % output)
end select

type is(self_attention_layer)

! Upstream layers permitted: input2d, linear2d, self_attention
select type(prev_layer => input % p)
type is(input2d_layer)
call this_layer % forward(prev_layer % output)
type is(linear2d_layer)
call this_layer % forward(prev_layer % output)
type is(self_attention_layer)
call this_layer % forward(prev_layer % output)
end select

end select
@@ -258,6 +288,8 @@ pure module subroutine get_output_2d(self, output)
allocate(output, source=this_layer % output)
type is(linear2d_layer)
allocate(output, source=this_layer % output)
type is(self_attention_layer)
allocate(output, source=this_layer % output)
class default
error stop '2-d output can only be read from an input2d, linear2d, or self_attention layer.'

@@ -301,7 +333,7 @@ impure elemental module subroutine init(self, input)
call this_layer % init(input % layer_shape)
end select

! The shape of linear2d, conv2d, maxpool2d, or flatten layers is not known
! The shape of self_attention, linear2d, conv2d, maxpool2d, or flatten layers is not known
! until we receive an input layer.
select type(this_layer => self % p)
type is(conv2d_layer)
Expand All @@ -312,6 +344,8 @@ impure elemental module subroutine init(self, input)
self % layer_shape = shape(this_layer % output)
type is(linear2d_layer)
self % layer_shape = shape(this_layer % output)
type is(self_attention_layer)
self % layer_shape = shape(this_layer % output)
end select

self % input_layer_shape = input % layer_shape
@@ -359,6 +393,8 @@ elemental module function get_num_params(self) result(num_params)
num_params = 0
type is (linear2d_layer)
num_params = this_layer % get_num_params()
type is (self_attention_layer)
num_params = this_layer % get_num_params()
class default
error stop 'Unknown layer type.'
end select
@@ -388,6 +424,8 @@ module function get_params(self) result(params)
! No parameters to get.
type is (linear2d_layer)
params = this_layer % get_params()
type is (self_attention_layer)
params = this_layer % get_params()
class default
error stop 'Unknown layer type.'
end select
@@ -417,6 +455,8 @@ module function get_gradients(self) result(gradients)
! No gradients to get.
type is (linear2d_layer)
gradients = this_layer % get_gradients()
type is (self_attention_layer)
gradients = this_layer % get_gradients()
class default
error stop 'Unknown layer type.'
end select
@@ -467,6 +507,9 @@ module subroutine set_params(self, params)
type is (linear2d_layer)
call this_layer % set_params(params)

type is (self_attention_layer)
call this_layer % set_params(params)

type is (maxpool2d_layer)
! No parameters to set.
write(stderr, '(a)') 'Warning: calling set_params() ' &