diff --git a/CMakeLists.txt b/CMakeLists.txt index eda96b28..c1bf2231 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(neural-fortran src/nf/nf_base_layer.f90 src/nf/nf_conv2d_layer.f90 src/nf/nf_conv2d_layer_submodule.f90 + src/nf/nf_cross_attention_layer.f90 src/nf/nf_datasets.f90 src/nf/nf_datasets_submodule.f90 src/nf/nf_datasets_mnist.f90 @@ -45,6 +46,8 @@ add_library(neural-fortran src/nf/nf_maxpool2d_layer.f90 src/nf/nf_maxpool2d_layer_submodule.f90 src/nf/nf_metrics.f90 + src/nf/nf_multihead_attention.f90 + src/nf/nf_multihead_attention_submodule.f90 src/nf/nf_network.f90 src/nf/nf_network_submodule.f90 src/nf/nf_optimizers.f90 @@ -53,6 +56,7 @@ add_library(neural-fortran src/nf/nf_random.f90 src/nf/nf_reshape_layer.f90 src/nf/nf_reshape_layer_submodule.f90 + src/nf/nf_self_attention_layer.f90 src/nf/io/nf_io_binary.f90 src/nf/io/nf_io_binary_submodule.f90 src/nf/nf_dropout_layer.f90 diff --git a/README.md b/README.md index a0eee745..a04ac32a 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,9 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). | Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ | | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) | | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | +| Linear (2-d) | `linear2d` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ | +| Self-attention | `self_attention` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ | | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ | -| Linear (2-d) | `linear2d` | `input2d`, `linear2d` | 2 | ✅ | ✅ | | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ | (*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset. 
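For reference, a sketch of the computation the new `self_attention` layer performs, as implemented in the `nf_multihead_attention` module added below (here $d_\mathrm{head}$ denotes the per-head size, `model_dimension / n_heads`):

$$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_\mathrm{head}}}\right) V $$

The queries $Q$, keys $K$, and values $V$ are produced by `linear2d` projections of the input, each head is evaluated independently, and the concatenated heads are passed through a final `linear2d` output projection; `self_attention` feeds the same input as query, key, and value.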
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index 28cf71a7..f4b706b8 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -6,6 +6,7 @@ foreach(execid
   simple
   sine
   quadratic
+  mha_simple
 )
   add_executable(${execid} ${execid}.f90)
   target_link_libraries(${execid} PRIVATE
diff --git a/example/mha_simple.f90 b/example/mha_simple.f90
new file mode 100644
index 00000000..2daa5ac2
--- /dev/null
+++ b/example/mha_simple.f90
@@ -0,0 +1,37 @@
+program mha_simple
+  use nf, only: dense, input, network, sgd, self_attention, flatten
+  implicit none
+  type(network) :: net
+  real, allocatable :: x(:, :), y(:)
+  integer, parameter :: num_iterations = 500
+  integer :: n
+
+  print '("Simple")'
+  print '(60("="))'
+
+  net = network([ &
+    input(3, 8), &
+    self_attention(4), &
+    flatten(), &
+    dense(2) &
+  ])
+
+  call net % print_info()
+
+  allocate(x(3, 8))
+  call random_number(x)
+
+  y = [0.123456, 0.246802]
+
+  do n = 0, num_iterations
+
+    call net % forward(x)
+    call net % backward(y)
+    call net % update(optimizer=sgd(learning_rate=1.))
+
+    if (mod(n, 50) == 0) &
+      print '(i4,2(3x,f8.6))', n, net % predict(x)
+
+  end do
+
+end program mha_simple
diff --git a/src/nf.f90 b/src/nf.f90
index 7a989ea3..39f67ea3 100644
--- a/src/nf.f90
+++ b/src/nf.f90
@@ -3,7 +3,15 @@ module nf
   use nf_datasets_mnist, only: label_digits, load_mnist
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
-    conv2d, dense, dropout, flatten, input, linear2d, maxpool2d, reshape
+    conv2d, &
+    dense, &
+    dropout, &
+    flatten, &
+    input, &
+    linear2d, &
+    maxpool2d, &
+    reshape, &
+    self_attention
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network
@@ -12,4 +20,6 @@ module nf
     gaussian, linear, relu, leaky_relu, &
     sigmoid, softmax, softplus, step, tanhf, &
     celu
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
 end module nf
diff --git a/src/nf/nf_cross_attention_layer.f90 b/src/nf/nf_cross_attention_layer.f90
new file mode 100644
index 00000000..b3167a13
--- /dev/null
+++ b/src/nf/nf_cross_attention_layer.f90
@@ -0,0 +1,66 @@
+module nf_cross_attention_layer
+  use iso_fortran_env, only: stderr => error_unit
+  use nf_activation, only: softmax
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
+
+  implicit none
+
+  type, extends(multihead_attention_layer) :: cross_attention_layer
+    !! Cross Attention Layer
+    !! Source:
+    !! Bahdanau, D. (2014)
+    !! Neural machine translation by jointly learning to align and translate.
+    !! https://arxiv.org/pdf/1409.0473
+    real, allocatable :: gradient(:, :, :)
+  contains
+    procedure :: forward
+    procedure :: backward
+    procedure :: init
+  end type cross_attention_layer
+
+  interface cross_attention_layer
+    module function cross_attention_layer_cons(n_heads) result(res)
+      !! This function returns the `cross_attention_layer` instance.
+      integer, intent(in) :: n_heads
+      type(cross_attention_layer) :: res
+    end function cross_attention_layer_cons
+  end interface cross_attention_layer
+
+contains
+  module function cross_attention_layer_cons(n_heads) result(res)
+    !! This function returns the `cross_attention_layer` instance.
+    integer, intent(in) :: n_heads
+    type(cross_attention_layer) :: res
+    res % n_heads = n_heads
+  end function cross_attention_layer_cons
+
+  pure module subroutine backward(self, input, gradient)
+    !!
Cross Attention Back propagation + class(cross_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :, :) + real, intent(in) :: gradient(:, :) + + call self % common_backward(input(1, :, :), gradient) + self % gradient(1, :, :) = self % query_layer % gradient + self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient + end subroutine backward + + pure module subroutine forward(self, input) + !! Cross Attention Forward propagation + !! Input Shape (kind, sequence_length, model_dimension) + !! where kind is 1 for Query and 2 for Key-Value + class(cross_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :, :) + + call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :)) + end subroutine forward + + module subroutine init(self, input_shape) + class(cross_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + call self % init_base(input_shape) + allocate(self % gradient(2, self % sequence_length, self % model_dimension)) + end subroutine init +end module nf_cross_attention_layer diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 87ceeeea..db60cf0f 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,16 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, dropout, flatten, input, linear2d, maxpool2d, reshape + public :: & + conv2d, & + dense, & + dropout, & + flatten, & + input, & + linear2d, & + maxpool2d, & + reshape, & + self_attention interface input @@ -213,6 +222,16 @@ module function linear2d(out_features) result(res) !! Resulting layer instance end function linear2d + module function self_attention(num_heads) result(res) + !! Rank-2 (sequence_length, out_features) self attention constructor. + !! sequence_length and model_dimension are determined at layer initialization, based on the + !! output shape of the previous layer. + integer, intent(in) :: num_heads + !! Number of attention heads + type(layer) :: res + !! 
Resulting layer instance + end function self_attention + end interface end module nf_layer_constructors diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 9558a0bc..9e5322c1 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -11,6 +11,7 @@ use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_linear2d_layer, only: linear2d_layer + use nf_self_attention_layer, only: self_attention_layer use nf_activation, only: activation_function, relu, sigmoid implicit none @@ -170,4 +171,12 @@ module function linear2d(out_features) result(res) end function linear2d + module function self_attention(num_heads) result(res) + integer, intent(in) :: num_heads + type(layer) :: res + + res % name = 'self_attention' + allocate(res % p, source=self_attention_layer(num_heads)) + end function self_attention + end submodule nf_layer_constructors_submodule diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 701dfe29..ecdeb41d 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -11,6 +11,7 @@ use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_linear2d_layer, only: linear2d_layer + use nf_self_attention_layer, only: self_attention_layer use nf_optimizers, only: optimizer_base_type contains @@ -57,6 +58,8 @@ pure module subroutine backward_1d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(linear2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(self_attention_layer) + call this_layer % backward(prev_layer % output, gradient) end select end select @@ -79,6 +82,19 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(linear2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(self_attention_layer) + call this_layer % backward(prev_layer % output, gradient) + end select + + type is(self_attention_layer) + + select type(prev_layer => previous % p) + type is(input2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(linear2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(self_attention_layer) + call this_layer % backward(prev_layer % output, gradient) end select end select @@ -240,6 +256,20 @@ module subroutine forward(self, input) call this_layer % forward(prev_layer % output) type is(linear2d_layer) call this_layer % forward(prev_layer % output) + type is(self_attention_layer) + call this_layer % forward(prev_layer % output) + end select + + type is(self_attention_layer) + + ! Upstream layers permitted: input2d, linear2d + select type(prev_layer => input % p) + type is(input2d_layer) + call this_layer % forward(prev_layer % output) + type is(linear2d_layer) + call this_layer % forward(prev_layer % output) + type is(self_attention_layer) + call this_layer % forward(prev_layer % output) end select end select @@ -279,6 +309,8 @@ pure module subroutine get_output_2d(self, output) allocate(output, source=this_layer % output) type is(linear2d_layer) allocate(output, source=this_layer % output) + type is(self_attention_layer) + allocate(output, source=this_layer % output) class default error stop '2-d output can only be read from an input2d or linear2d layer.' 
@@ -322,8 +354,8 @@ impure elemental module subroutine init(self, input)
       call this_layer % init(input % layer_shape)
     end select

-    ! The shape of conv2d, dropout, flatten, linear2d, or maxpool2d layers
-    ! is not known until we receive an input layer.
+    ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
+    ! self_attention layers is not known until we receive an input layer.
     select type(this_layer => self % p)
       type is(conv2d_layer)
         self % layer_shape = shape(this_layer % output)
@@ -333,6 +365,8 @@ impure elemental module subroutine init(self, input)
         self % layer_shape = shape(this_layer % output)
       type is(linear2d_layer)
         self % layer_shape = shape(this_layer % output)
+      type is(self_attention_layer)
+        self % layer_shape = shape(this_layer % output)
       type is(maxpool2d_layer)
         self % layer_shape = shape(this_layer % output)
     end select
@@ -389,6 +423,8 @@ elemental module function get_num_params(self) result(num_params)
         num_params = 0
       type is (linear2d_layer)
         num_params = this_layer % get_num_params()
+      type is (self_attention_layer)
+        num_params = this_layer % get_num_params()
       class default
         error stop 'Unknown layer type.'
     end select
@@ -420,6 +456,8 @@ module function get_params(self) result(params)
         ! No parameters to get.
       type is (linear2d_layer)
         params = this_layer % get_params()
+      type is (self_attention_layer)
+        params = this_layer % get_params()
       class default
         error stop 'Unknown layer type.'
     end select
@@ -451,6 +489,8 @@ module function get_gradients(self) result(gradients)
         ! No gradients to get.
       type is (linear2d_layer)
         gradients = this_layer % get_gradients()
+      type is (self_attention_layer)
+        gradients = this_layer % get_gradients()
       class default
         error stop 'Unknown layer type.'
     end select
@@ -506,6 +546,9 @@ module subroutine set_params(self, params)

       type is (linear2d_layer)
         call this_layer % set_params(params)
+      type is (self_attention_layer)
+        call this_layer % set_params(params)
+
       type is (maxpool2d_layer)
         ! No parameters to set.
         write(stderr, '(a)') 'Warning: calling set_params() ' &
diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90
new file mode 100644
index 00000000..80a59dfb
--- /dev/null
+++ b/src/nf/nf_multihead_attention.f90
@@ -0,0 +1,165 @@
+module nf_multihead_attention_layer
+  use iso_fortran_env, only: stderr => error_unit
+  use nf_activation, only: softmax
+  use nf_base_layer, only: base_layer
+  use nf_linear2d_layer, only: linear2d_layer
+
+  implicit none
+
+  private
+  public :: multihead_attention_layer
+
+  type, extends(base_layer) :: multihead_attention_layer
+    !! MultiHead Attention
+    !! The attention mechanism is widely used in machine learning, particularly
+    !! in natural language processing, and is the basis of modern language models.
+    !! Attention creates a saliency map between tokens that helps the model
+    !! achieve a deeper contextual understanding of the data.
+    !! This implementation is based on the Transformer architecture and
+    !! uses attention heads to help parallelize computations.
+    !! Source:
+    !! Vaswani A. et al. Attention is all you need.
+    !!
https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf + integer :: sequence_length, model_dimension, n_heads, head_size + + type(linear2d_layer) :: query_layer + type(linear2d_layer) :: key_layer + type(linear2d_layer) :: value_layer + type(linear2d_layer) :: output_layer + + type(softmax) :: softmax_func + + real, allocatable :: attention_matrix(:, :, :) + real, allocatable :: sdpa(:, :, :) + real, allocatable :: output(:, :) + + real :: scaling_factor + + real, allocatable :: q_input(:, :) + real, allocatable :: k_input(:, :) + real, allocatable :: v_input(:, :) + real, allocatable :: o_input(:, :) + contains + + procedure :: common_backward + procedure :: common_forward + procedure :: get_num_params + procedure :: get_params + procedure :: get_gradients + procedure :: set_params + procedure :: init_base + procedure :: init => init_base ! in case general MHA needs to be used + + ! FIXME: those should be private but accessible by tests + procedure :: split_heads + procedure :: create_attention_matrix + procedure :: normalize_attention_matrix + procedure :: scaled_dot_product_attention + procedure :: combine_heads + end type multihead_attention_layer + + interface multihead_attention_layer + module function multihead_attention_layer_cons(n_heads) result(res) + !! This function returns the `multihead_attention_layer` instance. + integer, intent(in) :: n_heads + type(multihead_attention_layer) :: res + end function multihead_attention_layer_cons + end interface multihead_attention_layer + + interface + + pure module subroutine common_backward(self, input, gradient) + !! General backprop for MultiHead Attention mechanism + !! Might be used for both Self and Cross Attention + !! Self Attention: sum output gradients + !! Cross Attention: use them separately + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :) + real, intent(in) :: gradient(:, :) + end subroutine common_backward + + pure module subroutine common_forward(self, query, key, value) + !! General forward propagation for MultiHead Attention Mechanism + !! Might be used for both Self and Cross Attention + !! Self Attention: pass the same value thrice + !! Cross Attention: pass three values for your query, key and value + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :), key(:, :), value(:, :) + end subroutine common_forward + + pure module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(multihead_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + end subroutine init + + pure module function split_heads(self, input) result(output) + !! Split inputs into heads + !! + !! Example with two heads: + !! input (3, 4) + !! output (3, 2, 2) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :) + real :: output(self % sequence_length, self % head_size, self % n_heads) + end function split_heads + + pure module subroutine create_attention_matrix(self, query, key) + !! Create attention matrix for query and key + !! Output dimensions: sequence_length, sequence_length, n_heads + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :, :) + real, intent(in) :: key(:, :, :) + end subroutine create_attention_matrix + + pure module subroutine normalize_attention_matrix(self, attention_mask) + !! 
Normalize the attention matrix with a softmax over the last sequence_length dimension
+      !! Output dims: sequence_length, sequence_length, n_heads
+      class(multihead_attention_layer), intent(in out) :: self
+      !! (sequence_length, sequence_length, n_heads)
+      real, optional, intent(in) :: attention_mask(:, :, :)
+      !! (sequence_length, sequence_length, n_heads)
+    end subroutine normalize_attention_matrix
+
+    pure module subroutine scaled_dot_product_attention(self, value)
+      !! Compute the scaled dot product attention
+      !! Output dims: sequence_length, head_size, n_heads
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: value(:, :, :)
+    end subroutine scaled_dot_product_attention
+
+    pure module function combine_heads(self, input) result(output)
+      class(multihead_attention_layer), intent(in) :: self
+      real, intent(in) :: input(:, :, :)
+      !! (sequence_length, head_size, n_heads)
+      real :: output(self % sequence_length, self % model_dimension)
+    end function combine_heads
+
+    elemental module function get_num_params(self) result(num_params)
+      class(multihead_attention_layer), intent(in) :: self
+      integer :: num_params
+    end function get_num_params
+
+    module function get_params(self) result(params)
+      class(multihead_attention_layer), intent(in), target :: self
+      real, allocatable :: params(:)
+    end function get_params
+
+    module function get_gradients(self) result(gradients)
+      class(multihead_attention_layer), intent(in), target :: self
+      real, allocatable :: gradients(:)
+    end function get_gradients
+
+    module subroutine set_params(self, params)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in), target :: params(:)
+    end subroutine set_params
+
+    module subroutine init_base(self, input_shape)
+      class(multihead_attention_layer), intent(in out) :: self
+      integer, intent(in) :: input_shape(:)
+    end subroutine init_base
+  end interface
+end module nf_multihead_attention_layer
diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90
new file mode 100644
index 00000000..d0e43a2e
--- /dev/null
+++ b/src/nf/nf_multihead_attention_submodule.f90
@@ -0,0 +1,343 @@
+submodule(nf_multihead_attention_layer) nf_multihead_attention_layer_submodule
+  ! use iso_fortran_env, only: stderr => error_unit
+  use nf_activation, only: softmax
+  use nf_base_layer, only: base_layer
+  use nf_linear2d_layer, only: linear2d_layer
+
+  implicit none
+
+contains
+  module function multihead_attention_layer_cons(n_heads) result(res)
+    integer, intent(in) :: n_heads
+    type(multihead_attention_layer) :: res
+
+    res % n_heads = n_heads
+  end function multihead_attention_layer_cons
+
+  pure module subroutine common_backward(self, input, gradient)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :)
+    real, intent(in) :: gradient(:, :)
+
+    real, allocatable :: d_output(:, :, :)
+    real, allocatable :: v_heads(:, :, :)
+    real, allocatable :: k_heads(:, :, :)
+    real, allocatable :: q_heads(:, :, :)
+    real, allocatable :: dv(:, :, :)
+    real, allocatable :: d_sdpa(:, :)
+    real, allocatable :: jacobian(:, :)
+    real, allocatable :: d_normalize(:, :, :)
+    real, allocatable :: dq(:, :, :)
+    real, allocatable :: dk(:, :, :)
+    integer :: head, seq, i, j
+
+    !
allocate temporary storages for backward computation + allocate(d_output(self % sequence_length, self % head_size, self % n_heads)) + allocate(v_heads(self % sequence_length, self % head_size, self % n_heads)) + allocate(k_heads(self % sequence_length, self % head_size, self % n_heads)) + allocate(q_heads(self % sequence_length, self % head_size, self % n_heads)) + + allocate(dv(self % sequence_length, self % head_size, self % n_heads)) + allocate(d_sdpa(self % sequence_length, self % sequence_length)) + allocate(jacobian(self % sequence_length, self % sequence_length)) + allocate(d_normalize(self % sequence_length, self % sequence_length, self % n_heads)) + allocate(dq(self % sequence_length, self % head_size, self % n_heads)) + allocate(dk(self % sequence_length, self % head_size, self % n_heads)) + + ! calculate output layer delta + call self % output_layer % backward(self % o_input, gradient) + + ! split heads from output gradient + d_output = self % split_heads(self % output_layer % gradient) + v_heads = self % split_heads(self % value_layer % output) + k_heads = self % split_heads(self % key_layer % output) + q_heads = self % split_heads(self % query_layer % output) + + ! iterate over heads to calculate deltas for each of them + do concurrent(head = 1: self % n_heads) + dv(:, :, head) = matmul(transpose(self % attention_matrix(:, :, head)), d_output(:, :, head)) + + ! calculate delta for attention matrix + d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head))) + + ! this monstrosity below is scaled derivative of softmax + do concurrent(seq = 1: self % sequence_length) + ! create jacobian matrix + do concurrent(i = 1: self % sequence_length, j = 1: self % sequence_length) + ! jacobian matrix is used to calculate derivative of softmax (temporary storage) + ! the idea behind this if-else is that for diagonal elements, the jacobian temp + ! should be: `softmax(x_i) * (1 - softmax(x_i))` + ! for off-diagonal: `-softmax(x_i) * softmax(x_j)` + if (i == j) then + jacobian(i, j) = & + self % attention_matrix(seq, i, head) & + * (1 - self % attention_matrix(seq, i, head)) + else + jacobian(i, j) = & + - self % attention_matrix(seq, i, head) & + * self % attention_matrix(seq, j, head) + end if + end do + ! attention normalization delta, the last step of softmax derivative: + ! multiply output of softmax by temp jacobian matrix + ! For computational efficiency (avoid more temp storages), scaling is also done here + ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3] + d_normalize(seq, :, head) = reshape(matmul(& + reshape(d_sdpa(seq, :), [1, self % sequence_length]),& + jacobian * self % scaling_factor& + ), [self % sequence_length]) + end do + + ! calculate delta for query + dq(:, :, head) = matmul(d_normalize(:, :, head), k_heads(:, :, head)) + + ! calculate delta for key, attention matrix should be transposed unlike for query + dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head)) + end do + + ! calculate deltas for input layers + call self % value_layer % backward(self % v_input, self % combine_heads(dv)) + call self % key_layer % backward(self % k_input, self % combine_heads(dk)) + call self % query_layer % backward(self % q_input, self % combine_heads(dq)) + + ! 
free temporary storages + deallocate(d_output) + deallocate(v_heads) + deallocate(k_heads) + deallocate(q_heads) + deallocate(d_sdpa) + deallocate(jacobian) + deallocate(d_normalize) + deallocate(dq) + deallocate(dk) + end subroutine common_backward + + pure module subroutine common_forward(self, query, key, value) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :), key(:, :), value(:, :) + + real, allocatable :: q(:, :, :) + real, allocatable :: k(:, :, :) + real, allocatable :: v(:, :, :) + + ! allocate storage for intermidiate stages + allocate(q(self % sequence_length, self % head_size, self % n_heads)) + allocate(k(self % sequence_length, self % head_size, self % n_heads)) + allocate(v(self % sequence_length, self % head_size, self % n_heads)) + + self % q_input = query + self % k_input = key + self % v_input = value + + ! run inputs through linear layers (trainable params) + call self % query_layer % forward(query) + call self % key_layer % forward(key) + call self % value_layer % forward(value) + + ! split attention heads for more efficient computation + q = self % split_heads(self % query_layer % output) + k = self % split_heads(self % key_layer % output) + v = self % split_heads(self % value_layer % output) + + ! create key by value matrix + call self % create_attention_matrix(q, k) + ! apply softmax and scaling + call self % normalize_attention_matrix() + ! multiply attention matrix by value + call self % scaled_dot_product_attention(v) + + self % o_input = self % combine_heads(self % sdpa) + call self % output_layer % forward(self % o_input) + self % output = self % output_layer % output + + ! free temp vars from memory + deallocate(q) + deallocate(k) + deallocate(v) + end subroutine common_forward + + pure module function split_heads(self, input) result(output) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :) + real :: output(self % sequence_length, self % head_size, self % n_heads) + output = reshape(input, [self % sequence_length, self % head_size, self % n_heads]) + end function split_heads + + pure module subroutine create_attention_matrix(self, query, key) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: query(:, :, :) + real, intent(in) :: key(:, :, :) + integer :: head + ! create attention matrix for each sequence in each batch + do concurrent(head = 1: self % n_heads) + self % attention_matrix(:, :, head) = matmul(query(:, :, head), transpose(key(:, :, head))) + end do + end subroutine create_attention_matrix + + pure module subroutine normalize_attention_matrix(self, attention_mask) + class(multihead_attention_layer), intent(in out) :: self + real, optional, intent(in) :: attention_mask(:, :, :) + real, allocatable :: output(:, :, :) + integer :: head, seq + + ! temporary storage + allocate(output(self % sequence_length, self % sequence_length, self % n_heads)) + + ! scale dowm by square root of each head's size + self % attention_matrix = self % attention_matrix * self % scaling_factor + ! attention mask is used to mask out some of the tokens if necessary + if (present(attention_mask)) then + self % attention_matrix = self % attention_matrix + attention_mask + end if + ! 
softmax by last sequnce_length + do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length) + output(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head)) + end do + self % attention_matrix = output + + deallocate(output) + end subroutine normalize_attention_matrix + + pure module subroutine scaled_dot_product_attention(self, value) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in) :: value(:, :, :) + integer :: head + + do concurrent(head = 1: self % n_heads) + self % sdpa(:, :, head) = matmul(self % attention_matrix(:, :, head), value(:, :, head)) + end do + end subroutine scaled_dot_product_attention + + pure module function combine_heads(self, input) result(output) + class(multihead_attention_layer), intent(in) :: self + real, intent(in) :: input(:, :, :) + real :: output(self % sequence_length, self % model_dimension) + integer :: seq + + do concurrent(seq = 1: self % sequence_length) + output(seq, :) = reshape(transpose(input(seq, :, :)), [self % model_dimension]) + end do + end function combine_heads + + elemental module function get_num_params(self) result(num_params) + class(multihead_attention_layer), intent(in) :: self + integer :: num_params + + num_params = & + self % query_layer % get_num_params() & + + self % key_layer % get_num_params() & + + self % value_layer % get_num_params() & + + self % output_layer % get_num_params() + end function get_num_params + + module function get_params(self) result(params) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: params(:) + + params = [& + self % query_layer % weights,& + self % key_layer % weights,& + self % value_layer % weights,& + self % output_layer % weights,& + self % query_layer % biases,& + self % key_layer % biases,& + self % value_layer % biases,& + self % output_layer % biases& + ] + end function get_params + + module function get_gradients(self) result(gradients) + class(multihead_attention_layer), intent(in), target :: self + real, allocatable :: gradients(:) + + gradients = [ & + self % query_layer % dw,& + self % key_layer % dw,& + self % value_layer % dw,& + self % output_layer % dw,& + self % query_layer % db,& + self % key_layer % db,& + self % value_layer % db,& + self % output_layer % db& + ] + end function get_gradients + + module subroutine set_params(self, params) + class(multihead_attention_layer), intent(in out) :: self + real, intent(in), target :: params(:) + real, pointer :: p_(:,:) => null() + integer :: i, j, window + + ! check if the number of parameters is correct + if (size(params) /= self % get_num_params()) then + error stop 'Error: number of parameters does not match' + end if + + ! FIXME: looks clumsy, better ideas? 
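+    ! Note: the flat parameter vector is unpacked in the same order in which
+    ! get_params packs it: query, key, value, and output projection weights
+    ! (model_dimension**2 values each), followed by the corresponding four
+    ! bias vectors (model_dimension values each).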
+    window = self % model_dimension * self % model_dimension
+    i = 1
+    j = window
+    self % query_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % key_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % value_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension])
+    i = j + 1
+    j = i + window - 1
+    self % output_layer % weights = reshape(params(i: j), [self % model_dimension, self % model_dimension])
+
+    window = self % model_dimension
+    i = j + 1
+    j = i + window - 1
+    self % query_layer % biases = params(i: j)
+    i = j + 1
+    j = i + window - 1
+    self % key_layer % biases = params(i: j)
+    i = j + 1
+    j = i + window - 1
+    self % value_layer % biases = params(i: j)
+    i = j + 1
+    j = i + window - 1
+    self % output_layer % biases = params(i: j)
+  end subroutine set_params
+
+  module subroutine init_base(self, input_shape)
+    class(multihead_attention_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    if (size(input_shape) /= 2) then
+      error stop "MultiHead Attention accepts 2D input"
+    end if
+    self % sequence_length = input_shape(1)
+    self % model_dimension = input_shape(2)
+
+    if (mod(self % model_dimension, self % n_heads) /= 0) then
+      write(stderr, '(a)') 'Model dimension must be divisible by the number of heads'
+      error stop
+    end if
+    self % head_size = self % model_dimension / self % n_heads
+    self % softmax_func = softmax()
+
+    self % query_layer = linear2d_layer(self % model_dimension)
+    self % key_layer = linear2d_layer(self % model_dimension)
+    self % value_layer = linear2d_layer(self % model_dimension)
+    self % output_layer = linear2d_layer(self % model_dimension)
+    call self % query_layer % init([self % sequence_length, self % model_dimension])
+    call self % key_layer % init([self % sequence_length, self % model_dimension])
+    call self % value_layer % init([self % sequence_length, self % model_dimension])
+    call self % output_layer % init([self % sequence_length, self % model_dimension])
+
+    allocate(self % attention_matrix(self % sequence_length, self % sequence_length, self % n_heads))
+    allocate(self % sdpa(self % sequence_length, self % head_size, self % n_heads))
+    allocate(self % output(self % sequence_length, self % model_dimension))
+
+    self % scaling_factor = sqrt(1 / real(self % head_size))
+
+    allocate(self % q_input(self % sequence_length, self % model_dimension))
+    allocate(self % k_input(self % sequence_length, self % model_dimension))
+    allocate(self % v_input(self % sequence_length, self % model_dimension))
+    allocate(self % o_input(self % sequence_length, self % model_dimension))
+  end subroutine init_base
+end submodule nf_multihead_attention_layer_submodule
\ No newline at end of file
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index dd632d96..f344c5c5 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -10,6 +10,7 @@
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic
@@ -160,6 +161,8 @@ module subroutine backward(self, output, loss)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
          type is(linear2d_layer)
            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+          type is(self_attention_layer)
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
        end select

      end if
diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90
new file mode 100644
index 00000000..15e8f40c
--- /dev/null
+++ b/src/nf/nf_self_attention_layer.f90
@@ -0,0 +1,69 @@
+module nf_self_attention_layer
+  use iso_fortran_env, only: stderr => error_unit
+  use nf_activation, only: softmax
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
+
+  implicit none
+
+  type, extends(multihead_attention_layer) :: self_attention_layer
+    !! Self Attention Layer
+    !! Source:
+    !! Parikh, A. P., Taeckstroem, O., Das, D., & Uszkoreit, J. (2016)
+    !! A decomposable attention model for natural language inference.
+    !! https://arxiv.org/pdf/1606.01933
+    real, allocatable :: gradient(:, :)
+  contains
+    procedure :: forward
+    procedure :: backward
+    procedure :: init
+  end type self_attention_layer
+
+  interface self_attention_layer
+    module function self_attention_layer_cons(n_heads) result(res)
+      !! This function returns the `self_attention_layer` instance.
+      integer, intent(in) :: n_heads
+      type(self_attention_layer) :: res
+    end function self_attention_layer_cons
+  end interface self_attention_layer
+
+contains
+  module function self_attention_layer_cons(n_heads) result(res)
+    !! This function returns the `self_attention_layer` instance.
+    integer, intent(in) :: n_heads
+    type(self_attention_layer) :: res
+    res % n_heads = n_heads
+  end function self_attention_layer_cons
+
+  pure module subroutine backward(self, input, gradient)
+    !! Self Attention backpropagation
+    !! Returns sum of Query, Key and Value gradients
+    class(self_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :)
+    real, intent(in) :: gradient(:, :)
+
+    call self % common_backward(input, gradient)
+    self % gradient = &
+        self % query_layer % gradient &
+        + self % key_layer % gradient &
+        + self % value_layer % gradient
+  end subroutine backward
+
+  pure module subroutine forward(self, input)
+    !! Self Attention forward propagation
+    !! Passes input three times into MultiHead Attention
+    !!
Input Shape: (sequence_length, model_dimension) + class(self_attention_layer), intent(in out) :: self + real, intent(in) :: input(:, :) + + call self % common_forward(input, input, input) + end subroutine forward + + module subroutine init(self, input_shape) + class(self_attention_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + call self % init_base(input_shape) + allocate(self % gradient(self % sequence_length, self % model_dimension)) + end subroutine init +end module nf_self_attention_layer diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1716dc8c..741e9930 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,6 +11,7 @@ foreach(execid flatten_layer insert_flatten reshape_layer + multihead_attention_layer dense_network get_set_network_params conv2d_network diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90 new file mode 100644 index 00000000..fdc6862d --- /dev/null +++ b/test/test_multihead_attention_layer.f90 @@ -0,0 +1,412 @@ +program test_multihead_attention_layer + use iso_fortran_env, only: stderr => error_unit + use nf_multihead_attention_layer, only: multihead_attention_layer + use nf_self_attention_layer, only: self_attention_layer + use nf_cross_attention_layer, only: cross_attention_layer + use nf_linear2d_layer, only: linear2d_layer + use nf_optimizers, only: sgd + implicit none + + logical :: ok = .true. + type(multihead_attention_layer) :: attention + real :: sample_input(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4]) + real :: split_heads_output(3, 2, 2) + real :: minput(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4]) + real :: output(3, 2, 2) + + attention = multihead_attention_layer(n_heads=2) + call attention % init_base([3, 4]) + call set_weights(attention) + + call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output) + call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok) + call test_multihead_attention_normalization(attention, ok) + call test_multihead_attention_scaled_dot_product_attention(attention, split_heads_output, ok) + call test_multihead_attention_combine_heads(attention, attention % sdpa, ok) + call test_multihead_attention_forward(attention, ok) + call test_multihead_attention_backward(attention, ok) + call test_multihead_attention_update_gradients(attention, ok) + call test_multihead_attention_forward_reallife_shape(ok) + call test_self_attention(ok) + call test_cross_attention(ok) + + if (ok) then + print '(a)', 'test_multihead_attention_layer: All tests passed.' + else + write(stderr, '(a)') 'test_multihead_attention_layer: One or more tests failed.' 
+ stop 1 + end if + +contains + function allclose(x, y) result(res) + real, intent(in) :: x(:) + real, intent(in) :: y(:) + logical :: res + + res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y))) + end function allclose + + subroutine set_weights(attention) + type(multihead_attention_layer), intent(in out) :: attention + attention % query_layer % weights = 0.1 + attention % key_layer % weights = 0.1 + attention % value_layer % weights = 0.1 + attention % output_layer % weights = 0.1 + attention % query_layer % biases = 0.11 + attention % key_layer % biases = 0.11 + attention % value_layer % biases = 0.11 + attention % output_layer % biases = 0.11 + end subroutine set_weights + + subroutine test_multihead_attention_split_heads(attention, input, ok, output) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: input(:, :) + logical, intent(in out) :: ok + real, intent(in out) :: output(3, 2, 2) + real :: output_shape(3) + real :: expected_shape(3) = [3, 2, 2] + real :: output_flat(12) + real :: expected_output_flat(12) = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12] + + output = attention % split_heads(input) + + output_shape = shape(output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'split_heads returned incorrect shape.. failed' + end if + output_flat = reshape(output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'split_heads returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_split_heads + + subroutine test_multihead_attention_create_attention_matrix(attention, input, ok) + type(multihead_attention_layer), intent(in out) :: attention + real, intent(in) :: input(:, :, :) + logical, intent(in out) :: ok + real :: attention_matrix_shape(3) + real, volatile :: attention_matrix_flat(18) + real :: expected_shape(3) = [3, 3, 2] + real :: expected_attention_matrix_flat(18) = [& + 0.09, 0.12, 0.15, 0.12, 0.17, 0.22,& + 0.15, 0.22, 0.29, 1.17, 0.519, 0.588,& + 0.519, 0.5021, 0.5732, 0.588, 0.5732, 0.6544& + ] + + call attention % create_attention_matrix(input, input) + + attention_matrix_shape = shape(attention % attention_matrix) + if (.not. all(attention_matrix_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'create_attention_matrix returned incorrect shape.. failed' + end if + attention_matrix_flat = reshape(attention % attention_matrix, shape(expected_attention_matrix_flat)) + if (.not. allclose(attention_matrix_flat, expected_attention_matrix_flat)) then + ok = .false. + write(stderr, '(a)') 'create_attention_matrix returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_create_attention_matrix + + subroutine test_multihead_attention_normalization(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real, volatile :: output_flat(18) + real :: expected_output_flat(18) = [& + 0.326287806, 0.321620107, 0.316976935, 0.333283335, 0.333194494, 0.333061278,& + 0.340428889, 0.345185429, 0.349961787, 0.435975075, 0.330339372, 0.329200655,& + 0.275134116, 0.326415271, 0.325773478, 0.288890868, 0.343245387, 0.345025837& + ] + + call attention % normalize_attention_matrix() + + output_flat = reshape(attention % attention_matrix, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'normalize_attention_matrix returned incorrect values.. 
failed' + end if + end subroutine test_multihead_attention_normalization + + subroutine test_multihead_attention_scaled_dot_product_attention(attention, value, ok) + type(multihead_attention_layer), intent(in out) :: attention + real, intent(in) :: value(:, :, :) + logical, intent(in out) :: ok + real, volatile :: output_flat(12) + real :: expected_output_flat(12) = [& + 0.101414114, 0.102356538, 0.103298485, 0.401414126, 0.402356565, 0.403298497,& + 0.685291648, 0.701290667, 0.701582491, 0.457309216, 0.374400556, 0.373518765& + ] + + call attention % scaled_dot_product_attention(value) + + output_flat = reshape(attention % sdpa, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'scaled_dot_product_attention returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_scaled_dot_product_attention + + subroutine test_multihead_attention_combine_heads(attention, scaled_dp_att, ok) + type(multihead_attention_layer), intent(in) :: attention + real, intent(in) :: scaled_dp_att(:, :, :) + logical, intent(in out) :: ok + real :: output(attention % sequence_length, attention % model_dimension) + real :: output_flat(12) + real :: expected_output_flat(12) = [& + 0.101414114, 0.102356538, 0.103298485, 0.685291648, 0.701290667, 0.701582491,& + 0.401414126, 0.402356565, 0.403298497, 0.457309216, 0.374400556, 0.373518765& + ] + + output = attention % combine_heads(scaled_dp_att) + + output_flat = reshape(output, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'combine_heads returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_combine_heads + + subroutine test_multihead_attention_forward(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) + real :: output(attention % sequence_length, attention % model_dimension) + real, volatile :: output_flat(12) + integer :: output_shape(2) + integer :: attn_weights_shape(3) + real, volatile :: attn_weights_flat(18) + integer :: expected_shape(2) = [3, 4] + real :: expected_output_flat(12) = [& + 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126,& + 0.982241452, 1.00407875, 1.00444126, 0.982241452, 1.00407875, 1.00444126& + ] + integer :: expected_attn_weights_shape(3) = [3, 3, 2] + real :: expected_attn_weights_flat(18) = [& + 7.89450705E-02, 2.28110179E-02, 2.18846574E-02, 0.447508544, 0.464612424, 0.464721352,& + 0.473546445, 0.512576580, 0.513393998, 7.89450705E-02, 2.28110179E-02, 2.18846574E-02,& + 0.447508544, 0.464612424, 0.464721352, 0.473546445, 0.512576580, 0.513393998& + ] + + call attention % common_forward(input, input, input) + + output_shape = shape(attention % output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect shape.. failed' + end if + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + + attn_weights_shape = shape(attention % attention_matrix) + if (.not. all(attn_weights_shape.eq.expected_attn_weights_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect attention weights shape.. 
failed' + end if + attn_weights_flat = reshape(attention % attention_matrix, shape(attn_weights_flat)) + if (.not. allclose(attn_weights_flat, expected_attn_weights_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect attention weights values.. failed' + end if + end subroutine test_multihead_attention_forward + + subroutine test_multihead_attention_forward_reallife_shape(ok) + logical, intent(in out) :: ok + real :: input(148, 512) + real :: output(148, 512) + integer :: output_shape(2) + integer :: expected_shape(2) = [148, 512] + type(multihead_attention_layer) :: attention + + call random_number(input) + + attention = multihead_attention_layer(n_heads=8) + call attention % init_base([148, 512]) + call set_weights(attention) + + call attention % common_forward(input, input, input) + + output_shape = shape(attention % output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect shape.. failed' + end if + end subroutine test_multihead_attention_forward_reallife_shape + + subroutine test_multihead_attention_backward(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]) + real :: gradient(3, 4) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4]) + real :: expected_output_flat(12) = [& + -2.29912549E-02, 0.381484956, 0.453185737,& + -2.29912549E-02, 0.381484956, 0.453185737,& + -2.29912549E-02, 0.381484956, 0.453185737,& + -2.29912549E-02, 0.381484956, 0.453185737& + ] + real :: expected_shape(2) = [3, 4] + real :: output(3, 4) + real, volatile :: output_flat(12) + real :: output_shape(2) + + call attention % common_backward(input, gradient) + + ! sample for Self Attention: sum of output gradients + ! FIXME: remove reshapes when linear2d situation is resolved + output = & + reshape(attention % query_layer % gradient, [attention % sequence_length, attention % model_dimension]) & + + reshape(attention % key_layer % gradient, [attention % sequence_length, attention % model_dimension]) & + + reshape(attention % value_layer % gradient, [attention % sequence_length, attention % model_dimension]) + + output_shape = shape(output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect shape.. failed' + end if + output_flat = reshape(output, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect values.. failed' + end if + end subroutine test_multihead_attention_backward + + subroutine test_multihead_attention_update_gradients(attention, ok) + type(multihead_attention_layer), intent(in out) :: attention + logical, intent(in out) :: ok + real :: parameters(80) + real :: expected_parameters(80) + real, volatile :: updated_output(12) + real :: expected_updated_output(12) = [& + 0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,& + -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626& + ] + type(sgd) :: optim + + if (attention % get_num_params() /= 80) then + ok = .false. + write(stderr, '(a)') 'incorrect number of parameters.. failed' + end if + + expected_parameters(1: 64) = 0.100000001 + expected_parameters(65: 80) = 0.109999999 + parameters = attention % get_params() + if (.not. 
all(parameters.eq.expected_parameters)) then + ok = .false. + write(stderr, '(a)') 'incorrect parameters.. failed' + end if + + optim = SGD(learning_rate=0.01) + call optim % minimize(parameters, attention % get_gradients()) + call attention % set_params(parameters) + + call attention % common_forward(& + reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),& + reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),& + reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])& + ) + + updated_output = reshape(attention % output, [12]) + if (.not. allclose(updated_output, expected_updated_output)) then + ok = .false. + write(stderr, '(a)') 'incorrect output after parameters update.. failed' + end if + end subroutine test_multihead_attention_update_gradients + + subroutine test_self_attention(ok) + logical, intent(in out) :: ok + type(self_attention_layer) :: attention + real :: input(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3]) + real :: output(2, 3) + real, volatile :: output_flat(6) + real :: expected_output_flat(6) = [& + 0.772716165, 0.577548742, 0.772716165, 0.577548742, 0.772716165, 0.577548742& + ] + real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3]) + real, volatile :: gradient_flat(6) + real :: expected_gradient_flat(6) = [& + 0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040& + ] + + attention = self_attention_layer(n_heads=1) + call attention % init([2, 3]) + attention % query_layer % weights = 0.1 + attention % key_layer % weights = 0.1 + attention % value_layer % weights = 0.1 + attention % output_layer % weights = 0.1 + attention % query_layer % biases = 0.11 + attention % key_layer % biases = 0.11 + attention % value_layer % biases = 0.11 + attention % output_layer % biases = 0.11 + + call attention % forward(input) + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + + call attention % backward(input, gradient) + gradient_flat = reshape(attention % gradient, shape(gradient_flat)) + if (.not. allclose(gradient_flat, expected_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect values.. 
failed' + end if + end subroutine test_self_attention + + subroutine test_cross_attention(ok) + logical, intent(in out) :: ok + type(cross_attention_layer) :: attention + real :: query(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3]) + real :: key_value(2, 3) = reshape([0.1, -.2, 0.3, 4., 15., 0.5], [2, 3]) + real :: input(2, 2, 3) + real :: output(2, 2, 3) + real, volatile :: output_flat(6) + real :: expected_output_flat(6) = [& + 0.600311756, 0.471662223, 0.600311756, 0.471662223, 0.600311756, 0.471662223& + ] + real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3]) + real, volatile :: query_gradient_flat(6) + real, volatile :: key_value_gradient_flat(6) + real :: expected_query_gradient_flat(6) = [& + 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245& + ] + real :: expected_key_value_gradient_flat(6) = [& + 0.303095698, 0.107004307, 0.303095698, 0.107004307, 0.303095698, 0.107004307& + ] + input(1, :, :) = query + input(2, :, :) = key_value + + attention = cross_attention_layer(n_heads=1) + call attention % init([2, 3]) + attention % query_layer % weights = 0.1 + attention % key_layer % weights = 0.1 + attention % value_layer % weights = 0.1 + attention % output_layer % weights = 0.1 + attention % query_layer % biases = 0.11 + attention % key_layer % biases = 0.11 + attention % value_layer % biases = 0.11 + attention % output_layer % biases = 0.11 + + call attention % forward(input) + output_flat = reshape(attention % output, shape(output_flat)) + if (.not. allclose(output_flat, expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + + call attention % backward(input, gradient) + query_gradient_flat = reshape(attention % gradient(1, :, :), shape(query_gradient_flat)) + if (.not. allclose(query_gradient_flat, expected_query_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect query values.. failed' + end if + key_value_gradient_flat = reshape(attention % gradient(2, :, :), shape(key_value_gradient_flat)) + if (.not. allclose(key_value_gradient_flat, expected_key_value_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect key-value values.. failed' + end if + end subroutine test_cross_attention +end program test_multihead_attention_layer
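A note on the backward pass in `common_backward`: the nested loop that fills the `jacobian` temporary builds, for each row of the attention matrix, the derivative of the softmax. In standard notation, for $s = \mathrm{softmax}(x)$,

$$ \frac{\partial s_i}{\partial x_j} = s_i\,(\delta_{ij} - s_j), $$

which is exactly the diagonal/off-diagonal split in the code; each row of `d_normalize` is then the upstream gradient `d_sdpa` multiplied by this Jacobian, with the $1/\sqrt{d_\mathrm{head}}$ factor folded in via `scaling_factor`.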