Skip to content

Insert a flatten layer if a dense layer follows a 3-d layer #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2018-2022 neural-fortran contributors
Copyright (c) 2018-2023 neural-fortran contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
1 change: 0 additions & 1 deletion example/cnn_mnist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ program cnn_mnist
maxpool2d(pool_size=2), &
conv2d(filters=16, kernel_size=3, activation='relu'), &
maxpool2d(pool_size=2), &
flatten(), &
dense(10, activation='softmax') &
])

Expand Down
4 changes: 2 additions & 2 deletions fpm.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
name = "neural-fortran"
version = "0.10.0"
version = "0.11.0"
license = "MIT"
author = "Milan Curcic"
maintainer = "milancurcic@hey.com"
copyright = "Copyright 2018-2022, neural-fortran contributors"
copyright = "Copyright 2018-2023, neural-fortran contributors"

[build]
external-modules = "hdf5"
Expand Down
29 changes: 28 additions & 1 deletion src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,37 @@ module function network_from_layers(layers) result(res)

res % layers = layers

! If connecting a 3-d output layer to a 1-d input layer without a flatten
! layer in between, insert a flatten layer.
n = 2
do while (n <= size(res % layers))
select type(this_layer => res % layers(n) % p)
type is(dense_layer)
select type(prev_layer => res % layers(n-1) % p)
type is(input3d_layer)
res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
n = n + 1
type is(conv2d_layer)
res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
n = n + 1
type is(maxpool2d_layer)
res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
n = n + 1
type is(reshape3d_layer)
res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
n = n + 1
class default
n = n + 1
end select
class default
n = n + 1
end select
end do

! Loop over each layer in order and call the init methods.
! This will allocate the data internal to each layer (e.g. weights, biases)
! according to the size of the previous layer.
do n = 2, size(layers)
do n = 2, size(res % layers)
call res % layers(n) % init(res % layers(n - 1))
end do

Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ foreach(execid
conv2d_layer
maxpool2d_layer
flatten_layer
insert_flatten
reshape_layer
dense_network
get_set_network_params
Expand Down
64 changes: 64 additions & 0 deletions test/test_insert_flatten.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
program test_insert_flatten

  ! Verify that network() automatically inserts a flatten layer whenever a
  ! dense layer directly follows a 3-d layer (input3d, conv2d, maxpool2d,
  ! or reshape3d).

  use iso_fortran_env, only: stderr => error_unit
  use nf, only: network, input, conv2d, maxpool2d, flatten, dense, reshape

  implicit none

  type(network) :: net
  logical :: passed

  passed = .true.

  ! dense directly after a 3-d input layer
  net = network([ &
    input([3, 32, 32]), &
    dense(10) &
  ])

  if (net % layers(2) % name /= 'flatten') then
    passed = .false.
    write(stderr, '(a)') 'flatten layer inserted after input3d.. failed'
  end if

  ! dense directly after a conv2d layer
  net = network([ &
    input([3, 32, 32]), &
    conv2d(filters=1, kernel_size=3), &
    dense(10) &
  ])

  if (net % layers(3) % name /= 'flatten') then
    passed = .false.
    write(stderr, '(a)') 'flatten layer inserted after conv2d.. failed'
  end if

  ! dense directly after a maxpool2d layer
  net = network([ &
    input([3, 32, 32]), &
    conv2d(filters=1, kernel_size=3), &
    maxpool2d(pool_size=2, stride=2), &
    dense(10) &
  ])

  if (net % layers(4) % name /= 'flatten') then
    passed = .false.
    write(stderr, '(a)') 'flatten layer inserted after maxpool2d.. failed'
  end if

  ! dense directly after a reshape3d layer
  net = network([ &
    input(4), &
    reshape([1, 2, 2]), &
    dense(4) &
  ])

  if (net % layers(3) % name /= 'flatten') then
    passed = .false.
    write(stderr, '(a)') 'flatten layer inserted after reshape.. failed'
  end if

  if (passed) then
    print '(a)', 'test_insert_flatten: All tests passed.'
  else
    write(stderr, '(a)') 'test_insert_flatten: One or more tests failed.'
    stop 1
  end if

end program test_insert_flatten