12
12
use nf_reshape_layer, only: reshape3d_layer
13
13
use nf_linear2d_layer, only: linear2d_layer
14
14
use nf_self_attention_layer, only: self_attention_layer
15
+ use nf_embedding_layer, only: embedding_layer
15
16
use nf_optimizers, only: optimizer_base_type
16
17
17
18
contains
@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
60
61
call this_layer % backward(prev_layer % output, gradient)
61
62
type is (self_attention_layer)
62
63
call this_layer % backward(prev_layer % output, gradient)
64
+ type is (embedding_layer)
65
+ call this_layer % backward(prev_layer % output, gradient)
63
66
end select
64
67
65
68
end select
@@ -80,6 +83,8 @@ pure module subroutine backward_2d(self, previous, gradient)
80
83
select type (prev_layer = > previous % p)
81
84
type is (input2d_layer)
82
85
call this_layer % backward(prev_layer % output, gradient)
86
+ type is (embedding_layer)
87
+ call this_layer % backward(prev_layer % output, gradient)
83
88
type is (linear2d_layer)
84
89
call this_layer % backward(prev_layer % output, gradient)
85
90
type is (self_attention_layer)
@@ -91,6 +96,8 @@ pure module subroutine backward_2d(self, previous, gradient)
91
96
select type (prev_layer = > previous % p)
92
97
type is (input2d_layer)
93
98
call this_layer % backward(prev_layer % output, gradient)
99
+ type is (embedding_layer)
100
+ call this_layer % backward(prev_layer % output, gradient)
94
101
type is (linear2d_layer)
95
102
call this_layer % backward(prev_layer % output, gradient)
96
103
type is (self_attention_layer)
@@ -254,6 +261,8 @@ module subroutine forward(self, input)
254
261
select type (prev_layer = > input % p)
255
262
type is (input2d_layer)
256
263
call this_layer % forward(prev_layer % output)
264
+ type is (embedding_layer)
265
+ call this_layer % forward(prev_layer % output)
257
266
type is (linear2d_layer)
258
267
call this_layer % forward(prev_layer % output)
259
268
type is (self_attention_layer)
@@ -266,6 +275,8 @@ module subroutine forward(self, input)
266
275
select type (prev_layer = > input % p)
267
276
type is (input2d_layer)
268
277
call this_layer % forward(prev_layer % output)
278
+ type is (embedding_layer)
279
+ call this_layer % forward(prev_layer % output)
269
280
type is (linear2d_layer)
270
281
call this_layer % forward(prev_layer % output)
271
282
type is (self_attention_layer)
@@ -307,6 +318,8 @@ pure module subroutine get_output_2d(self, output)
307
318
308
319
type is (input2d_layer)
309
320
allocate (output, source= this_layer % output)
321
+ type is (embedding_layer)
322
+ allocate (output, source= this_layer % output)
310
323
type is (linear2d_layer)
311
324
allocate (output, source= this_layer % output)
312
325
type is (self_attention_layer)
@@ -425,6 +438,8 @@ elemental module function get_num_params(self) result(num_params)
425
438
num_params = this_layer % get_num_params()
426
439
type is (self_attention_layer)
427
440
num_params = this_layer % get_num_params()
441
+ type is (embedding_layer)
442
+ num_params = this_layer % get_num_params()
428
443
class default
429
444
error stop ' Unknown layer type.'
430
445
end select
@@ -458,6 +473,8 @@ module function get_params(self) result(params)
458
473
params = this_layer % get_params()
459
474
type is (self_attention_layer)
460
475
params = this_layer % get_params()
476
+ type is (embedding_layer)
477
+ params = this_layer % get_params()
461
478
class default
462
479
error stop ' Unknown layer type.'
463
480
end select
@@ -491,6 +508,8 @@ module function get_gradients(self) result(gradients)
491
508
gradients = this_layer % get_gradients()
492
509
type is (self_attention_layer)
493
510
gradients = this_layer % get_gradients()
511
+ type is (embedding_layer)
512
+ gradients = this_layer % get_gradients()
494
513
class default
495
514
error stop ' Unknown layer type.'
496
515
end select
@@ -548,6 +567,8 @@ module subroutine set_params(self, params)
548
567
549
568
type is (self_attention_layer)
550
569
call this_layer % set_params(params)
570
+ type is (embedding_layer)
571
+ call this_layer % set_params(params)
551
572
552
573
type is (maxpool2d_layer)
553
574
! No parameters to set.
0 commit comments