4
4
use nf_conv2d_layer, only: conv2d_layer
5
5
use nf_dense_layer, only: dense_layer
6
6
use nf_flatten_layer, only: flatten_layer
7
+ use nf_flatten2d_layer, only: flatten2d_layer
7
8
use nf_input1d_layer, only: input1d_layer
8
9
use nf_input2d_layer, only: input2d_layer
9
10
use nf_input3d_layer, only: input3d_layer
@@ -46,8 +47,16 @@ pure module subroutine backward_1d(self, previous, gradient)
46
47
call this_layer % backward(prev_layer % output, gradient)
47
48
type is (maxpool2d_layer)
48
49
call this_layer % backward(prev_layer % output, gradient)
49
- ! type is(linear2d_layer)
50
- ! call this_layer % backward(prev_layer % output, gradient)
50
+ end select
51
+
52
+ type is (flatten2d_layer)
53
+
54
+ ! Upstream layers permitted: linear2d_layer
55
+ select type (prev_layer = > previous % p)
56
+ type is (linear2d_layer)
57
+ call this_layer % backward(prev_layer % output, gradient)
58
+ type is (input2d_layer)
59
+ call this_layer % backward(prev_layer % output, gradient)
51
60
end select
52
61
53
62
end select
@@ -61,8 +70,6 @@ pure module subroutine backward_2d(self, previous, gradient)
61
70
class(layer), intent (in ) :: previous
62
71
real , intent (in ) :: gradient(:,:)
63
72
64
- ! Backward pass from a 2-d layer downstream currently implemented
65
- ! only for input2d and linear2d layers
66
73
select type (this_layer = > self % p)
67
74
68
75
type is (linear2d_layer)
@@ -193,8 +200,14 @@ pure module subroutine forward(self, input)
193
200
call this_layer % forward(prev_layer % output)
194
201
type is (reshape3d_layer)
195
202
call this_layer % forward(prev_layer % output)
196
- ! type is(linear2d_layer)
197
- ! call this_layer % forward(prev_layer % output)
203
+ end select
204
+
205
+ type is (flatten2d_layer)
206
+ select type (prev_layer = > input % p)
207
+ type is (linear2d_layer)
208
+ call this_layer % forward(prev_layer % output)
209
+ type is (input2d_layer)
210
+ call this_layer % forward(prev_layer % output)
198
211
end select
199
212
200
213
type is (reshape3d_layer)
@@ -237,6 +250,8 @@ pure module subroutine get_output_1d(self, output)
237
250
allocate (output, source= this_layer % output)
238
251
type is (flatten_layer)
239
252
allocate (output, source= this_layer % output)
253
+ type is (flatten2d_layer)
254
+ allocate (output, source= this_layer % output)
240
255
class default
241
256
error stop ' 1-d output can only be read from an input1d, dense, or flatten layer.'
242
257
@@ -308,9 +323,11 @@ impure elemental module subroutine init(self, input)
308
323
self % layer_shape = shape (this_layer % output)
309
324
type is (flatten_layer)
310
325
self % layer_shape = shape (this_layer % output)
326
+ type is (flatten2d_layer)
327
+ self % layer_shape = shape (this_layer % output)
311
328
end select
312
329
313
- self % input_layer_shape = input % layer_shape
330
+ self % input_layer_shape = input % layer_shape
314
331
self % initialized = .true.
315
332
316
333
end subroutine init
@@ -351,6 +368,8 @@ elemental module function get_num_params(self) result(num_params)
351
368
num_params = 0
352
369
type is (flatten_layer)
353
370
num_params = 0
371
+ type is (flatten2d_layer)
372
+ num_params = 0
354
373
type is (reshape3d_layer)
355
374
num_params = 0
356
375
type is (linear2d_layer)
@@ -380,6 +399,8 @@ module function get_params(self) result(params)
380
399
! No parameters to get.
381
400
type is (flatten_layer)
382
401
! No parameters to get.
402
+ type is (flatten2d_layer)
403
+ ! No parameters to get.
383
404
type is (reshape3d_layer)
384
405
! No parameters to get.
385
406
type is (linear2d_layer)
@@ -408,6 +429,8 @@ module function get_gradients(self) result(gradients)
408
429
type is (maxpool2d_layer)
409
430
! No gradients to get.
410
431
type is (flatten_layer)
432
+ ! No parameters to get.
433
+ type is (flatten2d_layer)
411
434
! No gradients to get.
412
435
type is (reshape3d_layer)
413
436
! No gradients to get.
@@ -473,6 +496,11 @@ module subroutine set_params(self, params)
473
496
write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
474
497
// ' on a zero-parameter layer; nothing to do.'
475
498
499
+ type is (flatten2d_layer)
500
+ ! No parameters to set.
501
+ write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
502
+ // ' on a zero-parameter layer; nothing to do.'
503
+
476
504
type is (reshape3d_layer)
477
505
! No parameters to set.
478
506
write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments