multihead_attention: tests for updated parameters
OneAdder committed Feb 14, 2025
1 parent bcda13d commit 86cd7c0
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions test/test_multihead_attention_layer.f90
@@ -2,6 +2,7 @@ program test_multihead_attention_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_multihead_attention_layer, only: multihead_attention_layer
  use nf_linear2d_layer, only: linear2d_layer
+ use nf_optimizers, only: sgd
  implicit none

  logical :: ok = .true.
@@ -21,6 +22,7 @@ program test_multihead_attention_layer
  call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
  call test_multihead_attention_forward(attention, ok)
  call test_multihead_attention_backward(attention, ok)
+ call test_multihead_attention_update_gradients(attention, ok)
  ! call test_multihead_attention_forward_reallife_shape(ok)

contains
@@ -239,4 +241,46 @@ subroutine test_multihead_attention_backward(attention, ok)
      write(stderr, '(a)') 'backward returned incorrect values.. failed'
    end if
  end subroutine test_multihead_attention_backward
+
+ subroutine test_multihead_attention_update_gradients(attention, ok)
+   type(multihead_attention_layer), intent(in out) :: attention
+   logical, intent(in out) :: ok
+   real :: parameters(80)
+   real :: expected_parameters(80)
+   real :: updated_output(12)
+   real :: expected_updated_output(12) = [&
+       0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,&
+       -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626&
+   ]
+   type(sgd) :: optim
+
+   if (attention % get_num_params() /= 80) then
+     ok = .false.
+     write(stderr, '(a)') 'incorrect number of parameters.. failed'
+   end if
+
+   expected_parameters(1: 64) = 0.100000001
+   expected_parameters(65: 80) = 0.109999999
+   parameters = attention % get_params()
+   if (.not. all(parameters.eq.expected_parameters)) then
+     ok = .false.
+     write(stderr, '(a)') 'incorrect parameters.. failed'
+   end if
+
+   optim = SGD(learning_rate=0.01)
+   call optim % minimize(parameters, attention % get_gradients())
+   call attention % set_params(parameters)
+
+   call attention % forward(&
+       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+       reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])&
+   )
+
+   updated_output = reshape(attention % output, [12])
+   if (.not. all(updated_output.eq.expected_updated_output)) then
+     ok = .false.
+     write(stderr, '(a)') 'incorrect output after parameters update.. failed'
+   end if
+ end subroutine test_multihead_attention_update_gradients
end program test_multihead_attention_layer
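For context, the update round trip that the new test exercises boils down to the pattern below. This is a minimal sketch, not part of the commit: it assumes an already-constructed `attention` layer whose forward and backward passes have run, so that `get_gradients()` returns populated values, and it reuses the `sgd` optimizer from `nf_optimizers` exactly as the test does.

  ! Sketch: flatten parameters, take one SGD step, write them back.
  ! Assumes `attention` holds fresh gradients from a prior backward pass.
  block
    real, allocatable :: params(:)
    type(sgd) :: optim
    optim = sgd(learning_rate=0.01)
    params = attention % get_params()                           ! all 80 weights and biases, flattened
    call optim % minimize(params, attention % get_gradients())  ! params := params - learning_rate * gradients
    call attention % set_params(params)                         ! push updated values back into the layer
  end block

The test then reruns `forward` on the same input and compares the output against hard-coded constants with exact equality; that only stays green while the arithmetic is bit-reproducible, so a tolerance-based comparison would be the more portable check.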
