
Commit 4cdd2e5

embedding_layer: plumbing
1 parent: 48efd07

5 files changed: +64 −14 lines

src/nf.f90 — 2 additions, 1 deletion

@@ -11,7 +11,8 @@ module nf
     linear2d, &
     maxpool2d, &
     reshape, &
-    self_attention
+    self_attention, &
+    embedding
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network
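This re-export lets user code pull the new constructor straight from the top-level module with `use nf, only: embedding`, alongside the existing layer constructors.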

src/nf/nf_layer_constructors.f90 — 16 additions, 10 deletions

@@ -17,7 +17,8 @@ module nf_layer_constructors
     linear2d, &
     maxpool2d, &
     reshape, &
-    self_attention
+    self_attention, &
+    embedding

   interface input

@@ -222,15 +223,20 @@ module function linear2d(out_features) result(res)
       !! Resulting layer instance
   end function linear2d

-  module function self_attention(num_heads) result(res)
-    !! Rank-2 (sequence_length, out_features) self attention constructor.
-    !! sequence_length and model_dimension are determined at layer initialization, based on the
-    !! output shape of the previous layer.
-    integer, intent(in) :: num_heads
-    !! Number of attention heads
-    type(layer) :: res
-    !! Resulting layer instance
-  end function self_attention
+  module function self_attention(num_heads) result(res)
+    !! Rank-2 (sequence_length, out_features) self attention constructor.
+    !! sequence_length and model_dimension are determined at layer initialization, based on the
+    !! output shape of the previous layer.
+    integer, intent(in) :: num_heads
+    !! Number of attention heads
+    type(layer) :: res
+    !! Resulting layer instance
+  end function self_attention
+
+  module function embedding(sequence_length, vocab_size, model_dimension) result(res)
+    integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    type(layer) :: res
+  end function embedding

 end interface
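Unlike the other rank-2 constructors, which infer sequence_length and model_dimension from the previous layer's output shape, `embedding` takes all three dimensions explicitly, since it is meant to sit at the head of the network. A minimal usage sketch, assuming the published `network` and `linear2d` APIs; the dimensions below are made up for illustration:

  program embedding_usage
    use nf, only: network, embedding, linear2d
    implicit none
    type(network) :: net

    ! Hypothetical sizes: 128-token sequences over a 10000-word vocabulary,
    ! each token embedded into a 64-dimensional vector.
    net = network([ &
      embedding(sequence_length=128, vocab_size=10000, model_dimension=64), &
      linear2d(out_features=64) &
    ])
  end program embedding_usage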

src/nf/nf_layer_constructors_submodule.f90 — 18 additions, 0 deletions

@@ -12,6 +12,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_embedding_layer, only: embedding_layer
   use nf_activation, only: activation_function, relu, sigmoid

   implicit none

@@ -171,6 +172,7 @@ module function linear2d(out_features) result(res)

   end function linear2d

+
   module function self_attention(num_heads) result(res)
     integer, intent(in) :: num_heads
     type(layer) :: res

@@ -179,4 +181,20 @@ module function self_attention(num_heads) result(res)
     allocate(res % p, source=self_attention_layer(num_heads))
   end function self_attention

+
+  module function embedding(sequence_length, vocab_size, model_dimension) result(res)
+    integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    type(layer) :: res
+    type(embedding_layer) :: embedding_layer_instance
+
+    embedding_layer_instance = embedding_layer(vocab_size, model_dimension)
+    call embedding_layer_instance % init([sequence_length])
+    res % name = 'embedding'
+    res % layer_shape = [sequence_length, model_dimension]
+    res % input_layer_shape = [integer ::]
+    allocate(res % p, source=embedding_layer_instance)
+    res % initialized = .true.
+
+  end function embedding
+
 end submodule nf_layer_constructors_submodule
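Unlike most constructors in this submodule, which leave initialization to `network`, `embedding` builds and initializes the concrete layer on the spot, sets an empty `input_layer_shape`, and marks the wrapper `initialized = .true.`, so the network wiring treats it as a ready, input-like first layer. The `embedding_layer` type itself is not part of this plumbing-only diff; inferred from the call sites in this commit, its interface presumably looks roughly like the following hypothetical sketch (not the actual declaration):

  ! Hypothetical sketch inferred from call sites; the real type lives in
  ! nf_embedding_layer, which this commit does not include.
  type, extends(base_layer) :: embedding_layer
    integer :: vocab_size, model_dimension
    real, allocatable :: weights(:,:)  ! assumed (vocab_size, model_dimension)
    real, allocatable :: output(:,:)   ! (sequence_length, model_dimension)
  contains
    procedure :: init        ! called as init([sequence_length])
    procedure :: forward     ! takes integer token indices
    procedure :: backward
    procedure :: get_num_params
    procedure :: get_params
    procedure :: get_gradients
    procedure :: set_params
  end type embedding_layer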

src/nf/nf_layer_submodule.f90 — 21 additions, 0 deletions

@@ -12,6 +12,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_embedding_layer, only: embedding_layer
   use nf_optimizers, only: optimizer_base_type

 contains

@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
       call this_layer % backward(prev_layer % output, gradient)
     type is(self_attention_layer)
       call this_layer % backward(prev_layer % output, gradient)
+    type is(embedding_layer)
+      call this_layer % backward(prev_layer % output, gradient)
     end select

   end select

@@ -80,6 +83,8 @@ pure module subroutine backward_2d(self, previous, gradient)
     select type(prev_layer => previous % p)
     type is(input2d_layer)
       call this_layer % backward(prev_layer % output, gradient)
+    type is(embedding_layer)
+      call this_layer % backward(prev_layer % output, gradient)
     type is(linear2d_layer)
       call this_layer % backward(prev_layer % output, gradient)
     type is(self_attention_layer)

@@ -91,6 +96,8 @@ pure module subroutine backward_2d(self, previous, gradient)
     select type(prev_layer => previous % p)
     type is(input2d_layer)
       call this_layer % backward(prev_layer % output, gradient)
+    type is(embedding_layer)
+      call this_layer % backward(prev_layer % output, gradient)
     type is(linear2d_layer)
      call this_layer % backward(prev_layer % output, gradient)
     type is(self_attention_layer)

@@ -254,6 +261,8 @@ module subroutine forward(self, input)
     select type(prev_layer => input % p)
     type is(input2d_layer)
       call this_layer % forward(prev_layer % output)
+    type is(embedding_layer)
+      call this_layer % forward(prev_layer % output)
     type is(linear2d_layer)
       call this_layer % forward(prev_layer % output)
     type is(self_attention_layer)

@@ -266,6 +275,8 @@ module subroutine forward(self, input)
     select type(prev_layer => input % p)
     type is(input2d_layer)
       call this_layer % forward(prev_layer % output)
+    type is(embedding_layer)
+      call this_layer % forward(prev_layer % output)
     type is(linear2d_layer)
       call this_layer % forward(prev_layer % output)
     type is(self_attention_layer)

@@ -307,6 +318,8 @@ pure module subroutine get_output_2d(self, output)

     type is(input2d_layer)
       allocate(output, source=this_layer % output)
+    type is(embedding_layer)
+      allocate(output, source=this_layer % output)
     type is(linear2d_layer)
       allocate(output, source=this_layer % output)
     type is(self_attention_layer)

@@ -425,6 +438,8 @@ elemental module function get_num_params(self) result(num_params)
       num_params = this_layer % get_num_params()
     type is (self_attention_layer)
       num_params = this_layer % get_num_params()
+    type is (embedding_layer)
+      num_params = this_layer % get_num_params()
     class default
       error stop 'Unknown layer type.'
     end select

@@ -458,6 +473,8 @@ module function get_params(self) result(params)
       params = this_layer % get_params()
     type is (self_attention_layer)
       params = this_layer % get_params()
+    type is (embedding_layer)
+      params = this_layer % get_params()
     class default
       error stop 'Unknown layer type.'
     end select

@@ -491,6 +508,8 @@ module function get_gradients(self) result(gradients)
       gradients = this_layer % get_gradients()
     type is (self_attention_layer)
       gradients = this_layer % get_gradients()
+    type is (embedding_layer)
+      gradients = this_layer % get_gradients()
     class default
       error stop 'Unknown layer type.'
     end select

@@ -548,6 +567,8 @@ module subroutine set_params(self, params)

     type is (self_attention_layer)
       call this_layer % set_params(params)
+    type is (embedding_layer)
+      call this_layer % set_params(params)

     type is (maxpool2d_layer)
       ! No parameters to set.
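Because this submodule dispatches on concrete layer types with `select type` rather than through a polymorphic interface, plumbing in a new layer means touching every dispatch site: both backward paths, both forward paths, the rank-2 output getter, and the four parameter/gradient accessors. A branch missed here surfaces only at runtime as 'Unknown layer type.', not as a compile error.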

src/nf/nf_network_submodule.f90 — 7 additions, 3 deletions

@@ -11,6 +11,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_embedding_layer, only: embedding_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic

@@ -46,7 +47,7 @@ module function network_from_layers(layers) result(res)
     error stop 'Error: A network must have at least 2 layers.'

   ! The first layer must be an input layer
-  if (.not. layers(1) % name == 'input') &
+  if (.not. layers(1) % name == 'input' .and. .not. layers(1) % name == 'embedding') &
     error stop 'Error: First layer in the network must be an input layer.'

   !TODO Ensure that the layers are in allowed sequence:

@@ -207,8 +208,11 @@ module subroutine forward_1d(self, input)
   integer :: n

   ! Set the input array into the input layer
-  select type(input_layer => self % layers(1) % p); type is(input1d_layer)
-    call input_layer % set(input)
+  select type(input_layer => self % layers(1) % p)
+  type is(input1d_layer)
+    call input_layer % set(input)
+  type is(embedding_layer)
+    call input_layer % forward(nint(input))
   end select

   do n = 2, size(self % layers)
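Note the `nint(input)` conversion: `forward_1d` receives a real array through the generic forward API, so token indices travel as reals and are rounded back to integers before the embedding lookup. A hedged sketch of feeding such a network (`net` as in the constructor example above; the index values are invented):

  block
    integer :: i
    real :: tokens(128)
    ! Hypothetical token indices in [1, 10000], passed as reals;
    ! forward_1d applies nint() before the embedding lookup.
    tokens = [(real(mod(i, 10000) + 1), i = 1, 128)]
    call net % forward(tokens)
  end block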
