Generic flatten (2d and 3d) (#202)
* Generic flatten() with 2-d and 3-d inputs

* Explicitly enable preprocessing for fpm builds

* Update README

* generic-flatten: use assumed-rank instead of generics

---------

Co-authored-by: Mikhail Voronov <[email protected]>
milancurcic and OneAdder authored Feb 16, 2025
1 parent a28a9be commit 4ad75bc

Showing 7 changed files with 91 additions and 19 deletions.
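The core technique, named in the last commit-message bullet, is a Fortran 2018 assumed-rank dummy argument (`input(..)`) dispatched at run time with `select rank`, in place of rank-specific generics. Below is a minimal, self-contained sketch of that pattern; `assumed_rank_demo` and `flatten_any` are illustrative names, not part of the library:

program assumed_rank_demo
  ! Demonstrates the assumed-rank + select rank pattern adopted by this
  ! commit: one procedure accepts arrays of any rank and branches on it.
  implicit none
  real :: a2(2, 3), a3(1, 2, 2)
  a2 = 1.0
  a3 = 2.0
  print *, flatten_any(a2)  ! 6 elements
  print *, flatten_any(a3)  ! 4 elements
contains
  pure function flatten_any(x) result(y)
    real, intent(in) :: x(..)  ! assumed-rank dummy argument (F2018)
    real, allocatable :: y(:)
    select rank (x)
      rank (2)
        y = pack(x, .true.)
      rank (3)
        y = pack(x, .true.)
      rank default
        error stop "Unsupported rank of input"
    end select
  end function flatten_any
end program assumed_rank_demo

The same rank(2)/rank(3)/rank default structure appears in `nf_flatten_layer_submodule.f90` below.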
README.md (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 | Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✅ | ✅ |
 | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) |
 | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
-| Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
+| Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
 | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ |

 (*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset.
fpm.toml (3 changes: 3 additions & 0 deletions)

@@ -4,3 +4,6 @@ license = "MIT"
 author = "Milan Curcic"
 maintainer = "[email protected]"
 copyright = "Copyright 2018-2025, neural-fortran contributors"
+
+[preprocess]
+[preprocess.cpp]
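In fpm, the `[preprocess.cpp]` table turns on the C preprocessor for the package sources, which the commit message notes is now enabled explicitly. A hypothetical illustration of the kind of directive this allows (not code from this commit; outside fpm it would compile with, e.g., `gfortran -cpp`):

program preprocess_demo
  ! Hypothetical example: with cpp preprocessing enabled, compile-time
  ! macros such as DEBUG can gate alternative code paths.
  implicit none
#ifdef DEBUG
  print '(a)', 'debug build'
#else
  print '(a)', 'release build'
#endif
end program preprocess_demo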
src/nf/nf_flatten_layer.f90 (13 changes: 7 additions & 6 deletions)

@@ -18,7 +18,8 @@ module nf_flatten_layer
     integer, allocatable :: input_shape(:)
     integer :: output_size

-    real, allocatable :: gradient(:,:,:)
+    real, allocatable :: gradient_2d(:,:)
+    real, allocatable :: gradient_3d(:,:,:)
     real, allocatable :: output(:)

   contains
@@ -40,23 +41,23 @@ end function flatten_layer_cons
   interface

     pure module subroutine backward(self, input, gradient)
-      !! Apply the backward pass to the flatten layer.
-      !! This is a reshape operation from 1-d gradient to 3-d input.
+      !! Apply the backward pass to the flatten layer for 2D and 3D input.
+      !! This is a reshape operation from 1-d gradient to 2-d and 3-d input.
       class(flatten_layer), intent(in out) :: self
         !! Flatten layer instance
-      real, intent(in) :: input(:,:,:)
+      real, intent(in) :: input(..)
         !! Input from the previous layer
       real, intent(in) :: gradient(:)
         !! Gradient from the next layer
     end subroutine backward

     pure module subroutine forward(self, input)
-      !! Propagate forward the layer.
+      !! Propagate forward the layer for 2D or 3D input.
       !! Calling this subroutine updates the values of a few data components
       !! of `flatten_layer` that are needed for the backward pass.
       class(flatten_layer), intent(in out) :: self
         !! Dense layer instance
-      real, intent(in) :: input(:,:,:)
+      real, intent(in) :: input(..)
         !! Input from the previous layer
     end subroutine forward
src/nf/nf_flatten_layer_submodule.f90 (31 changes: 25 additions & 6 deletions)

@@ -17,16 +17,30 @@ end function flatten_layer_cons

   pure module subroutine backward(self, input, gradient)
     class(flatten_layer), intent(in out) :: self
-    real, intent(in) :: input(:,:,:)
+    real, intent(in) :: input(..)
     real, intent(in) :: gradient(:)
-    self % gradient = reshape(gradient, shape(input))
+    select rank(input)
+      rank(2)
+        self % gradient_2d = reshape(gradient, shape(input))
+      rank(3)
+        self % gradient_3d = reshape(gradient, shape(input))
+      rank default
+        error stop "Unsupported rank of input"
+    end select
   end subroutine backward


   pure module subroutine forward(self, input)
     class(flatten_layer), intent(in out) :: self
-    real, intent(in) :: input(:,:,:)
-    self % output = pack(input, .true.)
+    real, intent(in) :: input(..)
+    select rank(input)
+      rank(2)
+        self % output = pack(input, .true.)
+      rank(3)
+        self % output = pack(input, .true.)
+      rank default
+        error stop "Unsupported rank of input"
+    end select
   end subroutine forward


@@ -37,8 +51,13 @@ module subroutine init(self, input_shape)
     self % input_shape = input_shape
     self % output_size = product(input_shape)

-    allocate(self % gradient(input_shape(1), input_shape(2), input_shape(3)))
-    self % gradient = 0
+    if (size(input_shape) == 2) then
+      allocate(self % gradient_2d(input_shape(1), input_shape(2)))
+      self % gradient_2d = 0
+    else if (size(input_shape) == 3) then
+      allocate(self % gradient_3d(input_shape(1), input_shape(2), input_shape(3)))
+      self % gradient_3d = 0
+    end if

     allocate(self % output(self % output_size))
     self % output = 0
src/nf/nf_layer_submodule.f90 (8 changes: 6 additions & 2 deletions)

@@ -37,8 +37,10 @@ pure module subroutine backward_1d(self, previous, gradient)

       type is(flatten_layer)

-        ! Upstream layers permitted: input3d, conv2d, maxpool2d
+        ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
         select type(prev_layer => previous % p)
+          type is(input2d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
           type is(input3d_layer)
             call this_layer % backward(prev_layer % output, gradient)
           type is(conv2d_layer)
@@ -168,8 +170,10 @@ pure module subroutine forward(self, input)

       type is(flatten_layer)

-        ! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
+        ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d, reshape3d
         select type(prev_layer => input % p)
+          type is(input2d_layer)
+            call this_layer % forward(prev_layer % output)
           type is(input3d_layer)
             call this_layer % forward(prev_layer % output)
           type is(conv2d_layer)
src/nf/nf_network_submodule.f90 (10 changes: 9 additions & 1 deletion)

@@ -135,12 +135,20 @@ module subroutine backward(self, output, loss)
         select type(next_layer => self % layers(n + 1) % p)
           type is(dense_layer)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+
           type is(conv2d_layer)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+
           type is(flatten_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+            if (size(self % layers(n) % layer_shape) == 2) then
+              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
+            else
+              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
+            end if
+
           type is(maxpool2d_layer)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+
           type is(reshape3d_layer)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
         end select
test/test_flatten_layer.f90 (43 changes: 40 additions & 3 deletions)

@@ -3,16 +3,18 @@ program test_flatten_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf, only: dense, flatten, input, layer, network
   use nf_flatten_layer, only: flatten_layer
+  use nf_input2d_layer, only: input2d_layer
   use nf_input3d_layer, only: input3d_layer

   implicit none

   type(layer) :: test_layer, input_layer
   type(network) :: net
-  real, allocatable :: gradient(:,:,:)
+  real, allocatable :: gradient_3d(:,:,:), gradient_2d(:,:)
   real, allocatable :: output(:)
   logical :: ok = .true.

+  ! Test 3D input
   test_layer = flatten()

   if (.not. test_layer % name == 'flatten') then
@@ -59,14 +61,49 @@
   call test_layer % backward(input_layer, real([1, 2, 3, 4]))

   select type(this_layer => test_layer % p); type is(flatten_layer)
-    gradient = this_layer % gradient
+    gradient_3d = this_layer % gradient_3d
   end select

-  if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
+  if (.not. all(gradient_3d == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
     ok = .false.
     write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
   end if

+  ! Test 2D input
+  test_layer = flatten()
+  input_layer = input(2, 3)
+  call test_layer % init(input_layer)
+
+  if (.not. all(test_layer % layer_shape == [6])) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer has an incorrect output shape for 2D input.. failed'
+  end if
+
+  ! Test forward pass - reshaping from 2-d to 1-d
+  select type(this_layer => input_layer % p); type is(input2d_layer)
+    call this_layer % set(reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))
+  end select
+
+  call test_layer % forward(input_layer)
+  call test_layer % get_output(output)
+
+  if (.not. all(output == [1, 2, 3, 4, 5, 6])) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer correctly propagates forward for 2D input.. failed'
+  end if
+
+  ! Test backward pass - reshaping from 1-d to 2-d
+  call test_layer % backward(input_layer, real([1, 2, 3, 4, 5, 6]))
+
+  select type(this_layer => test_layer % p); type is(flatten_layer)
+    gradient_2d = this_layer % gradient_2d
+  end select
+
+  if (.not. all(gradient_2d == reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer correctly propagates backward for 2D input.. failed'
+  end if
+
   net = network([ &
     input(1, 28, 28), &
     flatten(), &
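Taken together, these changes let a 2-d input layer feed `flatten` directly. A hedged end-to-end sketch modeled on the test above (it assumes the `nf` library is available; the layer sizes are arbitrary):

program flatten_2d_demo
  use nf, only: dense, flatten, input, network
  implicit none
  type(network) :: net
  ! A 2-d input layer may now sit upstream of flatten.
  net = network([ &
    input(2, 3), &
    flatten(), &
    dense(4) &
  ])
  print '(a)', 'network with 2-d input and flatten constructed'
end program flatten_2d_demo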
