Generic flatten (2d and 3d) #202

Merged
4 commits merged on Feb 16, 2025
README.md (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
| Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✅ | ✅ |
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) |
| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
- | Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
+ | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
| Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ |

(*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset.
fpm.toml (3 changes: 3 additions & 0 deletions)
@@ -4,3 +4,6 @@ license = "MIT"
author = "Milan Curcic"
maintainer = "mcurcic@miami.edu"
copyright = "Copyright 2018-2025, neural-fortran contributors"
+
+ [preprocess]
+ [preprocess.cpp]
src/nf/nf_flatten_layer.f90 (13 changes: 7 additions & 6 deletions)
@@ -18,7 +18,8 @@ module nf_flatten_layer
integer, allocatable :: input_shape(:)
integer :: output_size

- real, allocatable :: gradient(:,:,:)
+ real, allocatable :: gradient_2d(:,:)
+ real, allocatable :: gradient_3d(:,:,:)

Collaborator:
Yep, I thought about that but decided not to make the code even less SOLID

Collaborator:
But here we have a choice between SOLID and less boilerplate; I think I agree that the second one is better.

Member Author:
Yes, and, most importantly for me, this approach allows for a unified API (only one flatten() for the user).
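
For illustration, a minimal sketch of that unified API from the user's side; this is only a sketch, assuming the `nf` module exports `network`, `input`, `flatten`, and `dense` as in the test below, and the layer sizes here are made up:

```fortran
program flatten_unified_api_sketch
  ! Hypothetical sketch: the same flatten() constructor is used after
  ! either a 2-d or a 3-d input layer; the rank is resolved internally.
  use nf, only: dense, flatten, input, network
  implicit none
  type(network) :: net_2d, net_3d

  ! flatten() after a 2-d input (e.g. 2 rows x 3 columns)
  net_2d = network([ &
    input(2, 3), &
    flatten(), &
    dense(4) &
  ])

  ! the same flatten() after a 3-d input (e.g. a 1 x 28 x 28 image)
  net_3d = network([ &
    input(1, 28, 28), &
    flatten(), &
    dense(10) &
  ])

end program flatten_unified_api_sketch
```

Before this change, `flatten` only accepted 3-d upstream layers, so the first network above would not have been possible without an extra reshape.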

real, allocatable :: output(:)

contains
@@ -40,23 +41,23 @@ end function flatten_layer_cons
interface

pure module subroutine backward(self, input, gradient)
- !! Apply the backward pass to the flatten layer.
- !! This is a reshape operation from 1-d gradient to 3-d input.
+ !! Apply the backward pass to the flatten layer for 2D and 3D input.
+ !! This is a reshape operation from 1-d gradient to 2-d and 3-d input.
class(flatten_layer), intent(in out) :: self
!! Flatten layer instance
- real, intent(in) :: input(:,:,:)
+ real, intent(in) :: input(..)
!! Input from the previous layer
real, intent(in) :: gradient(:)
!! Gradient from the next layer
end subroutine backward

pure module subroutine forward(self, input)
- !! Propagate forward the layer.
+ !! Propagate forward the layer for 2D or 3D input.
!! Calling this subroutine updates the values of a few data components
!! of `flatten_layer` that are needed for the backward pass.
class(flatten_layer), intent(in out) :: self
!! Dense layer instance
- real, intent(in) :: input(:,:,:)
+ real, intent(in) :: input(..)
!! Input from the previous layer
end subroutine forward

src/nf/nf_flatten_layer_submodule.f90 (31 changes: 25 additions & 6 deletions)
@@ -17,16 +17,30 @@ end function flatten_layer_cons

pure module subroutine backward(self, input, gradient)
class(flatten_layer), intent(in out) :: self
- real, intent(in) :: input(:,:,:)
+ real, intent(in) :: input(..)
real, intent(in) :: gradient(:)
- self % gradient = reshape(gradient, shape(input))
+ select rank(input)
+ rank(2)
+ self % gradient_2d = reshape(gradient, shape(input))
+ rank(3)
+ self % gradient_3d = reshape(gradient, shape(input))
+ rank default
+ error stop "Unsupported rank of input"
+ end select
end subroutine backward


pure module subroutine forward(self, input)
class(flatten_layer), intent(in out) :: self
- real, intent(in) :: input(:,:,:)
- self % output = pack(input, .true.)
+ real, intent(in) :: input(..)
+ select rank(input)
+ rank(2)
+ self % output = pack(input, .true.)
+ rank(3)
+ self % output = pack(input, .true.)
+ rank default
+ error stop "Unsupported rank of input"
+ end select
end subroutine forward
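
As an aside, the `input(..)` dummies above are Fortran 2018 assumed-rank arguments resolved with `select rank`. A tiny standalone sketch of that pattern, separate from neural-fortran (module and routine names here are made up for illustration):

```fortran
module flatten_demo_m
  ! Standalone illustration of the assumed-rank + select rank pattern.
  implicit none
contains
  pure function flatten_any(x) result(y)
    real, intent(in) :: x(..)   ! assumed-rank dummy (Fortran 2018)
    real, allocatable :: y(:)
    select rank(x)
      rank(2)
        y = pack(x, .true.)     ! here x is known to be rank 2
      rank(3)
        y = pack(x, .true.)     ! here x is known to be rank 3
      rank default
        error stop "flatten_any: unsupported rank"
    end select
  end function flatten_any
end module flatten_demo_m

program flatten_demo
  use flatten_demo_m, only: flatten_any
  implicit none
  real :: a(2,3), b(2,2,2)
  a = 1
  b = 2
  print *, size(flatten_any(a))   ! prints 6
  print *, size(flatten_any(b))   ! prints 8
end program flatten_demo
```

Note that `error stop` inside a `pure` procedure also requires Fortran 2018, matching its use in the submodule above.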


@@ -37,8 +51,13 @@ module subroutine init(self, input_shape)
self % input_shape = input_shape
self % output_size = product(input_shape)

- allocate(self % gradient(input_shape(1), input_shape(2), input_shape(3)))
- self % gradient = 0
+ if (size(input_shape) == 2) then
+ allocate(self % gradient_2d(input_shape(1), input_shape(2)))
+ self % gradient_2d = 0
+ else if (size(input_shape) == 3) then
+ allocate(self % gradient_3d(input_shape(1), input_shape(2), input_shape(3)))
+ self % gradient_3d = 0
+ end if

allocate(self % output(self % output_size))
self % output = 0
src/nf/nf_layer_submodule.f90 (8 changes: 6 additions & 2 deletions)
@@ -37,8 +37,10 @@ pure module subroutine backward_1d(self, previous, gradient)

type is(flatten_layer)

- ! Upstream layers permitted: input3d, conv2d, maxpool2d
+ ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
select type(prev_layer => previous % p)
+ type is(input2d_layer)
+ call this_layer % backward(prev_layer % output, gradient)
type is(input3d_layer)
call this_layer % backward(prev_layer % output, gradient)
type is(conv2d_layer)
@@ -168,8 +170,10 @@ pure module subroutine forward(self, input)

type is(flatten_layer)

- ! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
+ ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d, reshape3d
select type(prev_layer => input % p)
+ type is(input2d_layer)
+ call this_layer % forward(prev_layer % output)
type is(input3d_layer)
call this_layer % forward(prev_layer % output)
type is(conv2d_layer)
src/nf/nf_network_submodule.f90 (10 changes: 9 additions & 1 deletion)
@@ -135,12 +135,20 @@ module subroutine backward(self, output, loss)
select type(next_layer => self % layers(n + 1) % p)
type is(dense_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+
type is(conv2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+
type is(flatten_layer)
- call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+ if (size(self % layers(n) % layer_shape) == 2) then
+ call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
+ else
+ call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
+ end if
+
type is(maxpool2d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+
type is(reshape3d_layer)
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
end select
test/test_flatten_layer.f90 (43 changes: 40 additions & 3 deletions)
@@ -3,16 +3,18 @@ program test_flatten_layer
use iso_fortran_env, only: stderr => error_unit
use nf, only: dense, flatten, input, layer, network
use nf_flatten_layer, only: flatten_layer
+ use nf_input2d_layer, only: input2d_layer
use nf_input3d_layer, only: input3d_layer

implicit none

type(layer) :: test_layer, input_layer
type(network) :: net
- real, allocatable :: gradient(:,:,:)
+ real, allocatable :: gradient_3d(:,:,:), gradient_2d(:,:)
real, allocatable :: output(:)
logical :: ok = .true.

+ ! Test 3D input
test_layer = flatten()

if (.not. test_layer % name == 'flatten') then
@@ -59,14 +61,49 @@ program test_flatten_layer
call test_layer % backward(input_layer, real([1, 2, 3, 4]))

select type(this_layer => test_layer % p); type is(flatten_layer)
- gradient = this_layer % gradient
+ gradient_3d = this_layer % gradient_3d
end select

- if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
+ if (.not. all(gradient_3d == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
ok = .false.
write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
end if

+ ! Test 2D input
+ test_layer = flatten()
+ input_layer = input(2, 3)
+ call test_layer % init(input_layer)
+
+ if (.not. all(test_layer % layer_shape == [6])) then
+ ok = .false.
+ write(stderr, '(a)') 'flatten layer has an incorrect output shape for 2D input.. failed'
+ end if
+
+ ! Test forward pass - reshaping from 2-d to 1-d
+ select type(this_layer => input_layer % p); type is(input2d_layer)
+ call this_layer % set(reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))
+ end select
+
+ call test_layer % forward(input_layer)
+ call test_layer % get_output(output)
+
+ if (.not. all(output == [1, 2, 3, 4, 5, 6])) then
+ ok = .false.
+ write(stderr, '(a)') 'flatten layer correctly propagates forward for 2D input.. failed'
+ end if
+
+ ! Test backward pass - reshaping from 1-d to 2-d
+ call test_layer % backward(input_layer, real([1, 2, 3, 4, 5, 6]))
+
+ select type(this_layer => test_layer % p); type is(flatten_layer)
+ gradient_2d = this_layer % gradient_2d
+ end select
+
+ if (.not. all(gradient_2d == reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))) then
+ ok = .false.
+ write(stderr, '(a)') 'flatten layer correctly propagates backward for 2D input.. failed'
+ end if
+
net = network([ &
input(1, 28, 28), &
flatten(), &