Skip to content

Commit bd850b9

Browse files
committed
linear2d_layer: make linear2d layer work with input2d and flatten2d
1 parent df84781 commit bd850b9

File tree

3 files changed

+50
-8
lines changed

3 files changed

+50
-8
lines changed

src/nf.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module nf
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
55
use nf_layer_constructors, only: &
6-
conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
6+
conv2d, dense, flatten, flatten2d, input, maxpool2d, reshape, linear2d
77
use nf_loss, only: mse, quadratic
88
use nf_metrics, only: corr, maxabs
99
use nf_network, only: network

src/nf/nf_layer_submodule.f90

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use nf_conv2d_layer, only: conv2d_layer
55
use nf_dense_layer, only: dense_layer
66
use nf_flatten_layer, only: flatten_layer
7+
use nf_flatten2d_layer, only: flatten2d_layer
78
use nf_input1d_layer, only: input1d_layer
89
use nf_input2d_layer, only: input2d_layer
910
use nf_input3d_layer, only: input3d_layer
@@ -46,8 +47,16 @@ pure module subroutine backward_1d(self, previous, gradient)
4647
call this_layer % backward(prev_layer % output, gradient)
4748
type is(maxpool2d_layer)
4849
call this_layer % backward(prev_layer % output, gradient)
49-
! type is(linear2d_layer)
50-
! call this_layer % backward(prev_layer % output, gradient)
50+
end select
51+
52+
type is(flatten2d_layer)
53+
54+
! Upstream layers permitted: linear2d_layer
55+
select type(prev_layer => previous % p)
56+
type is(linear2d_layer)
57+
call this_layer % backward(prev_layer % output, gradient)
58+
type is(input2d_layer)
59+
call this_layer % backward(prev_layer % output, gradient)
5160
end select
5261

5362
end select
@@ -61,8 +70,6 @@ pure module subroutine backward_2d(self, previous, gradient)
6170
class(layer), intent(in) :: previous
6271
real, intent(in) :: gradient(:,:)
6372

64-
! Backward pass from a 2-d layer downstream currently implemented
65-
! only for input2d and linear2d layers
6673
select type(this_layer => self % p)
6774

6875
type is(linear2d_layer)
@@ -193,8 +200,14 @@ pure module subroutine forward(self, input)
193200
call this_layer % forward(prev_layer % output)
194201
type is(reshape3d_layer)
195202
call this_layer % forward(prev_layer % output)
196-
! type is(linear2d_layer)
197-
! call this_layer % forward(prev_layer % output)
203+
end select
204+
205+
type is(flatten2d_layer)
206+
select type(prev_layer => input % p)
207+
type is(linear2d_layer)
208+
call this_layer % forward(prev_layer % output)
209+
type is(input2d_layer)
210+
call this_layer % forward(prev_layer % output)
198211
end select
199212

200213
type is(reshape3d_layer)
@@ -237,6 +250,8 @@ pure module subroutine get_output_1d(self, output)
237250
allocate(output, source=this_layer % output)
238251
type is(flatten_layer)
239252
allocate(output, source=this_layer % output)
253+
type is(flatten2d_layer)
254+
allocate(output, source=this_layer % output)
240255
class default
241256
error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
242257

@@ -308,9 +323,11 @@ impure elemental module subroutine init(self, input)
308323
self % layer_shape = shape(this_layer % output)
309324
type is(flatten_layer)
310325
self % layer_shape = shape(this_layer % output)
326+
type is(flatten2d_layer)
327+
self % layer_shape = shape(this_layer % output)
311328
end select
312329

313-
self % input_layer_shape = input % layer_shape
330+
self % input_layer_shape = input % layer_shape
314331
self % initialized = .true.
315332

316333
end subroutine init
@@ -351,6 +368,8 @@ elemental module function get_num_params(self) result(num_params)
351368
num_params = 0
352369
type is (flatten_layer)
353370
num_params = 0
371+
type is (flatten2d_layer)
372+
num_params = 0
354373
type is (reshape3d_layer)
355374
num_params = 0
356375
type is (linear2d_layer)
@@ -380,6 +399,8 @@ module function get_params(self) result(params)
380399
! No parameters to get.
381400
type is (flatten_layer)
382401
! No parameters to get.
402+
type is (flatten2d_layer)
403+
! No parameters to get.
383404
type is (reshape3d_layer)
384405
! No parameters to get.
385406
type is (linear2d_layer)
@@ -408,6 +429,8 @@ module function get_gradients(self) result(gradients)
408429
type is (maxpool2d_layer)
409430
! No gradients to get.
410431
type is (flatten_layer)
432+
! No parameters to get.
433+
type is (flatten2d_layer)
411434
! No gradients to get.
412435
type is (reshape3d_layer)
413436
! No gradients to get.
@@ -473,6 +496,11 @@ module subroutine set_params(self, params)
473496
write(stderr, '(a)') 'Warning: calling set_params() ' &
474497
// 'on a zero-parameter layer; nothing to do.'
475498

499+
type is (flatten2d_layer)
500+
! No parameters to set.
501+
write(stderr, '(a)') 'Warning: calling set_params() ' &
502+
// 'on a zero-parameter layer; nothing to do.'
503+
476504
type is (reshape3d_layer)
477505
! No parameters to set.
478506
write(stderr, '(a)') 'Warning: calling set_params() ' &

src/nf/nf_network_submodule.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use nf_conv2d_layer, only: conv2d_layer
44
use nf_dense_layer, only: dense_layer
55
use nf_flatten_layer, only: flatten_layer
6+
use nf_flatten2d_layer, only: flatten2d_layer
67
use nf_input1d_layer, only: input1d_layer
78
use nf_input2d_layer, only: input2d_layer
89
use nf_input3d_layer, only: input3d_layer
@@ -135,6 +136,11 @@ module subroutine backward(self, output, loss)
135136
self % layers(n - 1), &
136137
self % loss % derivative(output, this_layer % output) &
137138
)
139+
type is(flatten2d_layer)
140+
call self % layers(n) % backward( &
141+
self % layers(n - 1), &
142+
self % loss % derivative(output, this_layer % output) &
143+
)
138144
end select
139145
else
140146
! Hidden layer; take the gradient from the next layer
@@ -145,6 +151,8 @@ module subroutine backward(self, output, loss)
145151
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
146152
type is(flatten_layer)
147153
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
154+
type is(flatten2d_layer)
155+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
148156
type is(maxpool2d_layer)
149157
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
150158
type is(reshape3d_layer)
@@ -255,6 +263,8 @@ module function predict_1d(self, input) result(res)
255263
res = output_layer % output
256264
type is(flatten_layer)
257265
res = output_layer % output
266+
type is(flatten2d_layer)
267+
res = output_layer % output
258268
class default
259269
error stop 'network % output not implemented for this output layer'
260270
end select
@@ -275,6 +285,10 @@ module function predict_2d(self, input) result(res)
275285
select type(output_layer => self % layers(num_layers) % p)
276286
type is(dense_layer)
277287
res = output_layer % output
288+
type is(flatten2d_layer)
289+
res = output_layer % output
290+
class default
291+
error stop 'network % output not implemented for this output layer'
278292
end select
279293

280294
end function predict_2d

0 commit comments

Comments
 (0)