From c316ee1ed1c318a0f74fb63c8aec04dda2674aa7 Mon Sep 17 00:00:00 2001 From: Michael Voronov Date: Mon, 17 Feb 2025 22:31:52 +0400 Subject: [PATCH] Linear2d layer (#197) * linear2d_layer forward implementation * implement backward * introduce concurrency, outtroduce stupidity * fix style * add parameters api to linear2d_layer * add constructor for linear2d_layer * add integration for linear2d layer * set usage rules for linear2d_layer * add linear2d_layer to public api * update tests for linear2d layer * remove extra comment * remove rubbish * move linear2d layer logic into submodule * update cmake for linear2d_layer * update tests for linear2d_layer * update linear2d_layer tests * update linear2d_layer tests for batch last * make linear2d_layer with batch as last dimension (performance) * linear2d_layer: fix gradient updates * linear2d_layer: make it 2d * linear2d_layer: forgot a file * linear2d_layer: temporarily remove api * Don't expose the concrete layer type via nf * Report success to stdout * Include linear2d test in cmake * Add Linear2d to README * Plumbing of linear2d with input2d and linear2d * linear2d_layer: add flatten2d layer * linear2d_layer: make linear2d layer work with input2d and flatten2d * update cmake * linear2d_layer: use flatten layer instead of flatten2d * linear2d_layer: remove flatten2d layer * linear2d_layer: remove public api * linear2d_layer: update cmakelists * linear2d_layer: workaround cpu imprecision to make ci happy * Add linear2d example * linear2d_layer: remove redundant constructor args * linear2d_layer: make example converge * linear2d_layer: make weighs init with normal distribution * linear2d_layer: add loss stopping and more iterations * linear2d_layer: update tests * Tidy up * Require passing only out_features to linear2d(); tidy up * Remove linear2d example --------- Co-authored-by: milancurcic --- CMakeLists.txt | 2 + README.md | 1 + src/nf.f90 | 2 +- src/nf/nf_layer_constructors.f90 | 12 +- src/nf/nf_layer_constructors_submodule.f90 | 12 ++ src/nf/nf_layer_submodule.f90 | 53 +++++- src/nf/nf_linear2d_layer.f90 | 77 +++++++++ src/nf/nf_linear2d_layer_submodule.f90 | 136 ++++++++++++++++ src/nf/nf_network_submodule.f90 | 13 +- test/CMakeLists.txt | 1 + test/test_linear2d_layer.f90 | 177 +++++++++++++++++++++ 11 files changed, 476 insertions(+), 10 deletions(-) create mode 100644 src/nf/nf_linear2d_layer.f90 create mode 100644 src/nf/nf_linear2d_layer_submodule.f90 create mode 100644 test/test_linear2d_layer.f90 diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a0a1be4..fc2ddfcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,8 @@ add_library(neural-fortran src/nf/nf_layer_constructors_submodule.f90 src/nf/nf_layer.f90 src/nf/nf_layer_submodule.f90 + src/nf/nf_linear2d_layer.f90 + src/nf/nf_linear2d_layer_submodule.f90 src/nf/nf_loss.f90 src/nf/nf_loss_submodule.f90 src/nf/nf_maxpool2d_layer.f90 diff --git a/README.md b/README.md index d2cff5b1..ebf7704d 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). 
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) | | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ | +| Linear (2-d) | `linear2d` | `input2d` | 2 | ✅ | ✅ | | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ | (*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset. diff --git a/src/nf.f90 b/src/nf.f90 index b97d9e62..e9b027c1 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape + conv2d, dense, flatten, input, maxpool2d, reshape, linear2d use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index ea1c08df..2983ddcd 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -8,7 +8,7 @@ module nf_layer_constructors implicit none private - public :: conv2d, dense, flatten, input, maxpool2d, reshape + public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d interface input @@ -185,6 +185,16 @@ module function reshape(output_shape) result(res) !! Resulting layer instance end function reshape + module function linear2d(out_features) result(res) + !! Rank-2 (sequence_length, out_features) linear layer constructor. + !! sequence_length is determined at layer initialization, based on the + !! output shape of the previous layer. + integer, intent(in) :: out_features + !! Number of output features + type(layer) :: res + !! 
Resulting layer instance + end function linear2d + end interface end module nf_layer_constructors diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 4c5994ee..ae7d05dc 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -9,6 +9,7 @@ use nf_input3d_layer, only: input3d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer + use nf_linear2d_layer, only: linear2d_layer use nf_activation, only: activation_function, relu, sigmoid implicit none @@ -71,6 +72,7 @@ module function flatten() result(res) end function flatten + module function input1d(layer_size) result(res) integer, intent(in) :: layer_size type(layer) :: res @@ -148,4 +150,14 @@ module function reshape(output_shape) result(res) end function reshape + + module function linear2d(out_features) result(res) + integer, intent(in) :: out_features + type(layer) :: res + + res % name = 'linear2d' + allocate(res % p, source=linear2d_layer(out_features)) + + end function linear2d + end submodule nf_layer_constructors_submodule diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index ab8d5b5d..22eabe9e 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -9,6 +9,7 @@ use nf_input3d_layer, only: input3d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer + use nf_linear2d_layer, only: linear2d_layer use nf_optimizers, only: optimizer_base_type contains @@ -47,6 +48,8 @@ pure module subroutine backward_1d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(maxpool2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(linear2d_layer) + call this_layer % backward(prev_layer % output, gradient) end select end select @@ -60,9 +63,19 @@ pure module subroutine backward_2d(self, previous, gradient) class(layer), intent(in) :: previous real, intent(in) :: gradient(:,:) - ! Backward pass from a 2-d layer downstream currently implemented - ! only for dense and flatten layers - ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199 + select type(this_layer => self % p) + + type is(linear2d_layer) + + select type(prev_layer => previous % p) + type is(input2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(linear2d_layer) + call this_layer % backward(prev_layer % output, gradient) + end select + + end select + end subroutine backward_2d @@ -182,6 +195,8 @@ pure module subroutine forward(self, input) call this_layer % forward(prev_layer % output) type is(reshape3d_layer) call this_layer % forward(prev_layer % output) + type is(linear2d_layer) + call this_layer % forward(prev_layer % output) end select type is(reshape3d_layer) @@ -196,6 +211,16 @@ pure module subroutine forward(self, input) call this_layer % forward(prev_layer % output) end select + type is(linear2d_layer) + + ! 
Upstream layers permitted: input2d, linear2d + select type(prev_layer => input % p) + type is(input2d_layer) + call this_layer % forward(prev_layer % output) + type is(linear2d_layer) + call this_layer % forward(prev_layer % output) + end select + end select end subroutine forward @@ -231,8 +256,10 @@ pure module subroutine get_output_2d(self, output) type is(input2d_layer) allocate(output, source=this_layer % output) + type is(linear2d_layer) + allocate(output, source=this_layer % output) class default - error stop '1-d output can only be read from an input1d, dense, or flatten layer.' + error stop '2-d output can only be read from an input2d or linear2d layer.' end select @@ -274,7 +301,7 @@ impure elemental module subroutine init(self, input) call this_layer % init(input % layer_shape) end select - ! The shape of conv2d, maxpool2d, or flatten layers is not known + ! The shape of linear2d, conv2d, maxpool2d, or flatten layers is not known ! until we receive an input layer. select type(this_layer => self % p) type is(conv2d_layer) @@ -283,9 +310,11 @@ impure elemental module subroutine init(self, input) self % layer_shape = shape(this_layer % output) type is(flatten_layer) self % layer_shape = shape(this_layer % output) + type is(linear2d_layer) + self % layer_shape = shape(this_layer % output) end select - self % input_layer_shape = input % layer_shape + self % input_layer_shape = input % layer_shape self % initialized = .true. end subroutine init @@ -328,6 +357,8 @@ elemental module function get_num_params(self) result(num_params) num_params = 0 type is (reshape3d_layer) num_params = 0 + type is (linear2d_layer) + num_params = this_layer % get_num_params() class default error stop 'Unknown layer type.' end select @@ -355,6 +386,8 @@ module function get_params(self) result(params) ! No parameters to get. type is (reshape3d_layer) ! No parameters to get. + type is (linear2d_layer) + params = this_layer % get_params() class default error stop 'Unknown layer type.' end select @@ -379,9 +412,11 @@ module function get_gradients(self) result(gradients) type is (maxpool2d_layer) ! No gradients to get. type is (flatten_layer) - ! No gradients to get. + ! No parameters to get. type is (reshape3d_layer) ! No gradients to get. + type is (linear2d_layer) + gradients = this_layer % get_gradients() class default error stop 'Unknown layer type.' end select @@ -429,6 +464,9 @@ module subroutine set_params(self, params) type is (conv2d_layer) call this_layer % set_params(params) + type is (linear2d_layer) + call this_layer % set_params(params) + type is (maxpool2d_layer) ! No parameters to set. write(stderr, '(a)') 'Warning: calling set_params() ' & @@ -446,6 +484,7 @@ module subroutine set_params(self, params) class default error stop 'Unknown layer type.' + end select end subroutine set_params diff --git a/src/nf/nf_linear2d_layer.f90 b/src/nf/nf_linear2d_layer.f90 new file mode 100644 index 00000000..f785a14c --- /dev/null +++ b/src/nf/nf_linear2d_layer.f90 @@ -0,0 +1,77 @@ +module nf_linear2d_layer + + use nf_activation, only: activation_function + use nf_base_layer, only: base_layer + + implicit none + + private + public :: linear2d_layer + + type, extends(base_layer) :: linear2d_layer + integer :: sequence_length, in_features, out_features, batch_size + + real, allocatable :: weights(:,:) + real, allocatable :: biases(:) + real, allocatable :: output(:,:) + real, allocatable :: gradient(:,:) ! input gradient + real, allocatable :: dw(:,:) ! weight gradients + real, allocatable :: db(:) ! 
bias gradients + + contains + + procedure :: backward + procedure :: forward + procedure :: init + procedure :: get_num_params + procedure :: get_params + procedure :: get_gradients + procedure :: set_params + + end type linear2d_layer + + interface linear2d_layer + module function linear2d_layer_cons(out_features) result(res) + integer, intent(in) :: out_features + type(linear2d_layer) :: res + end function linear2d_layer_cons + end interface linear2d_layer + + interface + pure module subroutine forward(self, input) + class(linear2d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + end subroutine forward + + pure module subroutine backward(self, input, gradient) + class(linear2d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:,:) + end subroutine backward + + module subroutine init(self, input_shape) + class(linear2d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + end subroutine init + + pure module function get_num_params(self) result(num_params) + class(linear2d_layer), intent(in) :: self + integer :: num_params + end function get_num_params + + module function get_params(self) result(params) + class(linear2d_layer), intent(in), target :: self + real, allocatable :: params(:) + end function get_params + + module function get_gradients(self) result(gradients) + class(linear2d_layer), intent(in), target :: self + real, allocatable :: gradients(:) + end function get_gradients + + module subroutine set_params(self, params) + class(linear2d_layer), intent(in out) :: self + real, intent(in), target :: params(:) + end subroutine set_params + end interface +end module nf_linear2d_layer diff --git a/src/nf/nf_linear2d_layer_submodule.f90 b/src/nf/nf_linear2d_layer_submodule.f90 new file mode 100644 index 00000000..0dfe7e27 --- /dev/null +++ b/src/nf/nf_linear2d_layer_submodule.f90 @@ -0,0 +1,136 @@ +submodule(nf_linear2d_layer) nf_linear2d_layer_submodule + use nf_base_layer, only: base_layer + use nf_random, only: random_normal + implicit none + +contains + + module function linear2d_layer_cons(out_features) result(res) + integer, intent(in) :: out_features + type(linear2d_layer) :: res + + res % out_features = out_features + + end function linear2d_layer_cons + + + module subroutine init(self, input_shape) + class(linear2d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + if (size(input_shape) /= 2) then + error stop "linear2d layer requires 2D input." 
+ end if + self % sequence_length = input_shape(1) + self % in_features = input_shape(2) + + allocate(self % output(self % sequence_length, self % out_features)) + allocate(self % gradient(self % sequence_length, self % in_features)) + + allocate(self % weights(self % in_features, self % out_features)) + call random_normal(self % weights) + + allocate(self % biases(self % out_features)) + call random_normal(self % biases) + + allocate(self % dw(self % in_features, self % out_features)) + self % dw = 0 + + allocate(self % db(self % out_features)) + self % db = 0 + + end subroutine init + + + pure module subroutine forward(self, input) + class(linear2d_layer), intent(in out) :: self + real, intent(in) :: input(:, :) + integer :: i + + self % output(:,:) = matmul(input(:,:), self % weights) + do concurrent(i = 1:self % sequence_length) + self % output(i,:) = self % output(i,:) + self % biases + end do + + end subroutine forward + + + pure module subroutine backward(self, input, gradient) + class(linear2d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:,:) + real :: db(self % out_features) + real :: dw(self % in_features, self % out_features) + integer :: i + + self % dw = self % dw + matmul(transpose(input(:,:)), gradient(:,:)) + self % db = self % db + sum(gradient(:,:), 1) + self % gradient(:,:) = matmul(gradient(:,:), transpose(self % weights)) + end subroutine backward + + + pure module function get_num_params(self) result(num_params) + class(linear2d_layer), intent(in) :: self + integer :: num_params + + ! Number of weights times number of biases + num_params = self % in_features * self % out_features + self % out_features + + end function get_num_params + + + module function get_params(self) result(params) + class(linear2d_layer), intent(in), target :: self + real, allocatable :: params(:) + + real, pointer :: w_(:) => null() + + w_(1: product(shape(self % weights))) => self % weights + + params = [ & + w_, & + self % biases & + ] + + end function get_params + + + module function get_gradients(self) result(gradients) + class(linear2d_layer), intent(in), target :: self + real, allocatable :: gradients(:) + + real, pointer :: dw_(:) => null() + + dw_(1: product(shape(self % dw))) => self % dw + + gradients = [ & + dw_, & + self % db & + ] + + end function get_gradients + + + module subroutine set_params(self, params) + class(linear2d_layer), intent(in out) :: self + real, intent(in), target :: params(:) + + real, pointer :: p_(:,:) => null() + + ! check if the number of parameters is correct + if (size(params) /= self % get_num_params()) then + error stop 'Error: number of parameters does not match' + end if + + associate(n => self % in_features * self % out_features) + ! reshape the weights + p_(1:self % in_features, 1:self % out_features) => params(1 : n) + self % weights = p_ + + ! 
reshape the biases + self % biases = params(n + 1 : n + self % out_features) + end associate + + end subroutine set_params + +end submodule nf_linear2d_layer_submodule \ No newline at end of file diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index e90d92d9..c2a9c903 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -8,6 +8,7 @@ use nf_input3d_layer, only: input3d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer + use nf_linear2d_layer, only: linear2d_layer use nf_layer, only: layer use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape use nf_loss, only: quadratic @@ -129,6 +130,11 @@ module subroutine backward(self, output, loss) self % layers(n - 1), & self % loss % derivative(output, this_layer % output) & ) + type is(flatten_layer) + call self % layers(n) % backward( & + self % layers(n - 1), & + self % loss % derivative(output, this_layer % output) & + ) end select else ! Hidden layer; take the gradient from the next layer @@ -145,12 +151,13 @@ module subroutine backward(self, output, loss) else call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d) end if - type is(maxpool2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(reshape3d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(linear2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) end select end if @@ -275,6 +282,10 @@ module function predict_2d(self, input) result(res) select type(output_layer => self % layers(num_layers) % p) type is(dense_layer) res = output_layer % output + type is(flatten_layer) + res = output_layer % output + class default + error stop 'network % output not implemented for this output layer' end select end function predict_2d diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 35954894..12236416 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -2,6 +2,7 @@ foreach(execid input1d_layer input2d_layer input3d_layer + linear2d_layer parametric_activation dense_layer conv2d_layer diff --git a/test/test_linear2d_layer.f90 b/test/test_linear2d_layer.f90 new file mode 100644 index 00000000..28b99bf0 --- /dev/null +++ b/test/test_linear2d_layer.f90 @@ -0,0 +1,177 @@ +program test_linear2d_layer + use iso_fortran_env, only: stderr => error_unit + use nf_linear2d_layer, only: linear2d_layer + implicit none + + logical :: ok = .true. + real :: sample_input(3, 4) = reshape(& + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],& + [3, 4]) + real :: sample_gradient(3, 1) = reshape([2., 2., 3.], [3, 1]) + type(linear2d_layer) :: linear + + linear = linear2d_layer(out_features=1) + call linear % init([3, 4]) + linear % weights = 0.1 + linear % biases = 0.11 + + call test_linear2d_layer_forward(linear, ok, sample_input) + call test_linear2d_layer_backward(linear, ok, sample_input, sample_gradient) + call test_linear2d_layer_gradient_updates(ok) + + if (ok) then + print '(a)', 'test_linear2d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_linear2d_layer: One or more tests failed.' 
+ stop 1 + end if + +contains + + subroutine test_linear2d_layer_forward(linear, ok, input) + type(linear2d_layer), intent(in out) :: linear + logical, intent(in out) :: ok + real, intent(in) :: input(3, 4) + real :: output_shape(2) + real :: output_flat(3) + real :: expected_shape(2) = [3, 1] + real :: expected_output_flat(3) = [0.17, 0.17, 0.17] + + call linear % forward(input) + + output_shape = shape(linear % output) + if (.not. all(output_shape.eq.expected_shape)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect shape.. failed' + end if + output_flat = reshape(linear % output, shape(output_flat)) + if (.not. all(output_flat.eq.expected_output_flat)) then + ok = .false. + write(stderr, '(a)') 'forward returned incorrect values.. failed' + end if + end subroutine test_linear2d_layer_forward + + subroutine test_linear2d_layer_backward(linear, ok, input, gradient) + type(linear2d_layer), intent(in out) :: linear + logical, intent(in out) :: ok + real, intent(in) :: input(3, 4) + real, intent(in) :: gradient(3, 1) + real :: gradient_shape(2) + real :: dw_shape(2) + real :: db_shape(1) + real :: gradient_flat(12) + integer :: dw_flat(4) ! cpu imprecision workaround + real :: expected_gradient_shape(2) = [3, 4] + real :: expected_dw_shape(2) = [4, 1] + real :: expected_db_shape(1) = [1] + real :: expected_gradient_flat(12) = [& + 0.2, 0.2, 0.3, 0.2,& + 0.2, 0.3, 0.2, 0.2,& + 0.3, 0.2, 0.2, 0.3& + ] + integer :: expected_dw_flat(4) = [7, 7, 14, 14] ! cpu imprecision workaround + real :: expected_db(1) = [7] + + call linear % backward(input, gradient) + + gradient_shape = shape(linear % gradient) + if (.not. all(gradient_shape.eq.expected_gradient_shape)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect gradient shape.. failed' + end if + dw_shape = shape(linear % dw) + if (.not. all(dw_shape.eq.expected_dw_shape)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect dw shape.. failed' + end if + db_shape = shape(linear % db) + if (.not. all(db_shape.eq.expected_db_shape)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect db shape.. failed' + end if + + gradient_flat = reshape(linear % gradient, shape(gradient_flat)) + if (.not. all(gradient_flat.eq.expected_gradient_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect gradient values.. failed' + end if + dw_flat = nint(reshape(linear % dw, shape(dw_flat)) * 10) + if (.not. all(dw_flat.eq.expected_dw_flat)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect dw values.. failed' + end if + if (.not. all(linear % db.eq.expected_db)) then + ok = .false. + write(stderr, '(a)') 'backward returned incorrect db values.. 
failed' + end if + end subroutine test_linear2d_layer_backward + + subroutine test_linear2d_layer_gradient_updates(ok) + logical, intent(in out) :: ok + real :: input(3, 4) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4]) + real :: gradient(3, 2) = reshape([0.0, 10., 0.2, 3., 0.4, 1.], [3, 2]) + type(linear2d_layer) :: linear + + integer :: num_parameters + real :: parameters(10) + real :: expected_parameters(10) = [& + 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001,& + 0.109999999, 0.109999999& + ] + real :: gradients(10) + real :: expected_gradients(10) = [& + 1.03999996, 4.09999990, 7.15999985, 1.12400007, 0.240000010, 1.56000006, 2.88000011, 2.86399961,& + 10.1999998, 4.40000010& + ] + real :: updated_parameters(10) + real :: updated_weights(8) + real :: updated_biases(2) + real :: expected_weights(8) = [& + 0.203999996, 0.509999990, 0.816000044, 0.212400019, 0.124000005, 0.256000012, 0.388000011, 0.386399955& + ] + real :: expected_biases(2) = [1.13000000, 0.550000012] + + integer :: i + + linear = linear2d_layer(out_features=2) + call linear % init([3, 4]) + linear % weights = 0.1 + linear % biases = 0.11 + call linear % forward(input) + call linear % backward(input, gradient) + + num_parameters = linear % get_num_params() + if (num_parameters /= 10) then + ok = .false. + write(stderr, '(a)') 'incorrect number of parameters.. failed' + end if + + parameters = linear % get_params() + if (.not. all(parameters.eq.expected_parameters)) then + ok = .false. + write(stderr, '(a)') 'incorrect parameters.. failed' + end if + + gradients = linear % get_gradients() + if (.not. all(gradients.eq.expected_gradients)) then + ok = .false. + write(stderr, '(a)') 'incorrect gradients.. failed' + end if + + do i = 1, num_parameters + updated_parameters(i) = parameters(i) + 0.1 * gradients(i) + end do + call linear % set_params(updated_parameters) + updated_weights = reshape(linear % weights, shape(expected_weights)) + if (.not. all(updated_weights.eq.expected_weights)) then + ok = .false. + write(stderr, '(a)') 'incorrect updated weights.. failed' + end if + updated_biases = linear % biases + if (.not. all(updated_biases.eq.expected_biases)) then + ok = .false. + write(stderr, '(a)') 'incorrect updated biases.. failed' + end if + end subroutine test_linear2d_layer_gradient_updates + +end program test_linear2d_layer \ No newline at end of file
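
Usage notes (not part of the patch):

The forward pass computes matmul(input, weights) and adds the bias vector to each of the sequence_length rows; the backward pass accumulates dw = matmul(transpose(input), gradient) and db = sum(gradient, 1), and stores the input gradient matmul(gradient, transpose(weights)). A minimal sketch of driving the new layer type directly, mirroring test/test_linear2d_layer.f90 (the program name and the random input data are illustrative only, not from the patch):

program linear2d_demo
  ! Construct the concrete layer type, initialize it with
  ! [sequence_length, in_features], and run one forward/backward pass.
  use nf_linear2d_layer, only: linear2d_layer
  implicit none

  type(linear2d_layer) :: linear
  real :: x(3, 4)   ! sequence_length = 3, in_features = 4
  real :: dy(3, 2)  ! upstream gradient; same shape as the layer output

  linear = linear2d_layer(out_features=2)
  call linear % init([3, 4])

  call random_number(x)
  dy = 1.

  call linear % forward(x)       ! linear % output has shape [3, 2]
  call linear % backward(x, dy)  ! accumulates dw and db, sets linear % gradient

  print *, 'output shape:   ', shape(linear % output)
  print *, 'num parameters: ', linear % get_num_params()  ! 4*2 + 2 = 10

end program linear2d_demo

At the network level, per the plumbing added in nf_layer_submodule.f90, a linear2d layer accepts input from an input2d layer or another linear2d layer, its rank-2 output can be read back via get_output_2d, and it can feed a flatten layer ahead of downstream dense layers.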