Commit 73799bd (parent: 074bcd1)

embedding_layer: update constructor and tests

3 files changed: 29 additions, 12 deletions

src/nf/nf_layer_constructors.f90 (2 additions, 1 deletion)

@@ -233,7 +233,7 @@ module function self_attention(num_heads) result(res)
       !! Resulting layer instance
   end function self_attention

-  module function embedding(sequence_length, vocab_size, model_dimension) result(res)
+  module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
     !! Embedding layer constructor.
     !!
     !! This layer is for inputting token indices from the dictionary to the network.
@@ -243,6 +243,7 @@ module function embedding(sequence_length, vocab_size, model_dimension) result(res)
     !! `vocab_size`: length of token vocabulary
     !! `model_dimension`: size of target embeddings
     integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    integer, optional, intent(in) :: positional
     type(layer) :: res
   end function embedding
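For reference, a minimal sketch of how the updated constructor might be called from user code. The argument names come from the interface above; the specific values, and the use of `positional`, simply mirror what the test file below exercises, so treat them as illustrative rather than documented settings:

    program embedding_example
      use nf_layer, only: layer
      use nf_layer_constructors, only: embedding
      implicit none

      type(layer) :: token_embedding, positional_embedding

      ! positional is optional; omitting it builds a plain token embedding
      token_embedding = embedding(sequence_length=3, vocab_size=5, model_dimension=4)

      ! passing positional requests a positional-encoding variant
      ! (the values 0, 1, and 2 are the ones exercised by the tests)
      positional_embedding = embedding(sequence_length=3, vocab_size=5, model_dimension=4, positional=1)
    end program embedding_example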

src/nf/nf_layer_constructors_submodule.f90 (3 additions, 2 deletions)

@@ -182,12 +182,13 @@ module function self_attention(num_heads) result(res)
   end function self_attention


-  module function embedding(sequence_length, vocab_size, model_dimension) result(res)
+  module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
     integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    integer, optional, intent(in) :: positional
     type(layer) :: res
     type(embedding_layer) :: embedding_layer_instance

-    embedding_layer_instance = embedding_layer(vocab_size, model_dimension)
+    embedding_layer_instance = embedding_layer(vocab_size, model_dimension, positional)
     call embedding_layer_instance % init([sequence_length])
     res % name = 'embedding'
     res % layer_shape = [sequence_length, model_dimension]
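The submodule forwards `positional` straight into the `embedding_layer` constructor. Assuming that constructor also declares `positional` as an optional dummy (its declaration is not part of this diff), the forwarding is safe even when the caller omits the argument: in Fortran, an absent optional actual argument passed on to an optional dummy is itself treated as absent. A minimal standalone sketch of that pattern, with illustrative names not taken from the library:

    module forward_optional_demo
      implicit none
    contains

      function inner(x, flag) result(y)
        integer, intent(in) :: x
        integer, optional, intent(in) :: flag
        integer :: y
        ! present() reports whether the original caller supplied flag
        if (present(flag)) then
          y = x + flag
        else
          y = x
        end if
      end function inner

      function outer(x, flag) result(y)
        integer, intent(in) :: x
        integer, optional, intent(in) :: flag
        integer :: y
        ! flag may be absent here; passing it along keeps it absent in inner
        y = inner(x, flag)
      end function outer

    end module forward_optional_demo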

test/test_embedding_layer.f90 (24 additions, 9 deletions)

@@ -1,13 +1,16 @@
 program test_embedding_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_embedding_layer, only: embedding_layer
+  use nf_layer, only: layer
+  use nf_layer_constructors, only: embedding_constructor => embedding
   implicit none

   logical :: ok = .true.
+  integer :: sample_input(3) = [2, 1, 3]

-  call test_simple(ok)
-  call test_positional_trigonometric(ok)
-  call test_positional_absolute(ok)
+  call test_simple(ok, sample_input)
+  call test_positional_trigonometric(ok, sample_input)
+  call test_positional_absolute(ok, sample_input)

   if (ok) then
     print '(a)', 'test_embedding_layer: All tests passed.'
@@ -17,10 +20,10 @@ program test_embedding_layer
   end if

 contains
-  subroutine test_simple(ok)
+  subroutine test_simple(ok, sample_input)
     logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)

-    integer :: sample_input(3) = [2, 1, 3]
     real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
     real :: output_flat(6)
     real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
@@ -48,10 +51,10 @@ subroutine test_simple(ok)
     end if
   end subroutine test_simple

-  subroutine test_positional_trigonometric(ok)
+  subroutine test_positional_trigonometric(ok, sample_input)
     logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)

-    integer :: sample_input(3) = [2, 1, 3]
     real :: output_flat(12)
     real :: expected_output_flat(12) = reshape([&
       0.3, 0.941471, 1.4092975,&
@@ -82,10 +85,10 @@ subroutine test_positional_trigonometric(ok)
     end if
   end subroutine test_positional_trigonometric

-  subroutine test_positional_absolute(ok)
+  subroutine test_positional_absolute(ok, sample_input)
     logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)

-    integer :: sample_input(3) = [2, 1, 3]
     real :: output_flat(12)
     real :: expected_output_flat(12) = reshape([&
       0.3, 1.1, 2.5,&
@@ -115,4 +118,16 @@ subroutine test_positional_absolute(ok)
       write(stderr, '(a)') 'absolute positional encoding returned incorrect values.. failed'
     end if
   end subroutine test_positional_absolute
+
+  subroutine test_embedding_constructor(ok, sample_input)
+    logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)
+
+    type(layer) :: embedding_constructed
+
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4)
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=0)
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=1)
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=2)
+  end subroutine test_embedding_constructor
 end program test_embedding_layer
