
Commit 6bfea21

embedding_layer: positional encoding
1 parent 4cdd2e5 commit 6bfea21

File tree: 2 files changed (+34 −2 lines)


src/nf/nf_embedding_layer.f90

Lines changed: 9 additions & 1 deletion
@@ -15,6 +15,7 @@ module nf_embedding_layer
     !! This layer converts them into a table of shape
     !! (`sequence_length`, `model_dimension`)
     integer :: sequence_length, vocab_size, model_dimension
+    logical :: positional

     real, allocatable :: weights(:, :)
     real, allocatable :: output(:, :)
@@ -24,6 +25,7 @@ module nf_embedding_layer

     procedure :: backward
     procedure :: forward
+    procedure :: positional_encoding
     procedure :: init
     procedure :: get_num_params
     procedure :: get_params
@@ -33,8 +35,9 @@ module nf_embedding_layer
   end type embedding_layer

   interface embedding_layer
-    module function embedding_layer_cons(vocab_size, model_dimension) result(res)
+    module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
       integer, intent(in) :: vocab_size, model_dimension
+      logical, optional :: positional
       type(embedding_layer) :: res
     end function embedding_layer_cons
   end interface embedding_layer
@@ -54,6 +57,11 @@ pure module subroutine backward(self, input, gradient)
       real, intent(in) :: gradient(:, :)
     end subroutine backward

+    pure module subroutine positional_encoding(self, pos)
+      class(embedding_layer), intent(in out) :: self
+      integer, intent(in) :: pos
+    end subroutine positional_encoding
+
     module subroutine init(self, input_shape)
       class(embedding_layer), intent(in out) :: self
       integer, intent(in) :: input_shape(:)
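
With this change the embedding_layer constructor takes an optional positional flag; when omitted it defaults to .false. (see the submodule diff below). A minimal, hypothetical construction sketch using the generic interface shown above (the variable name and the vocab_size / model_dimension values are illustrative only, not part of this commit):

    type(embedding_layer) :: emb
    ! Enable sinusoidal positional encoding via the new optional argument:
    emb = embedding_layer(vocab_size=10000, model_dimension=512, positional=.true.)
    ! Omitting the flag preserves the previous behavior (no positional encoding):
    emb = embedding_layer(vocab_size=10000, model_dimension=512)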

src/nf/nf_embedding_layer_submodule.f90

Lines changed: 25 additions & 1 deletion
@@ -2,12 +2,18 @@
   use nf_base_layer, only: base_layer
   implicit none
 contains
-  module function embedding_layer_cons(vocab_size, model_dimension) result(res)
+  module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
     integer, intent(in) :: vocab_size, model_dimension
+    logical, optional :: positional
     type(embedding_layer) :: res

     res % vocab_size = vocab_size
     res % model_dimension = model_dimension
+    if (.not. present(positional)) then
+      res % positional = .false.
+    else
+      res % positional = positional
+    end if
   end function embedding_layer_cons

   module subroutine init(self, input_shape)
@@ -37,7 +43,12 @@ pure module subroutine forward(self, input)
       elseif (index == 0) then
         index = 1
       end if
+
       self % output(i, :) = self % weights(index, :)
+
+      if (self % positional) then
+        call self % positional_encoding(i)
+      end if
     end do
   end subroutine forward

@@ -52,6 +63,19 @@ pure module subroutine backward(self, input, gradient)
     end do
   end subroutine backward

+  pure module subroutine positional_encoding(self, pos)
+    class(embedding_layer), intent(in out) :: self
+    integer, intent(in) :: pos
+    integer :: i
+    real :: theta
+
+    do concurrent(i = 1: floor(real(self % model_dimension) / 2))
+      theta = (pos - 1) / 10000 ** (real(2 * (i-1)) / self % model_dimension)
+      self % output(pos, 2 * i - 1) = self % output(pos, 2 * i - 1) + sin(theta)
+      self % output(pos, 2 * i) = self % output(pos, 2 * i) + cos(theta)
+    end do
+  end subroutine positional_encoding
+
   pure module function get_num_params(self) result(num_params)
     class(embedding_layer), intent(in) :: self
     integer :: num_params
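
The positional_encoding subroutine added above implements the standard sinusoidal scheme: for zero-based position p = pos - 1 and column pair (2i-1, 2i), theta = p / 10000**(2(i-1)/model_dimension), and sin(theta) / cos(theta) are added to the odd / even output columns. A small standalone sketch of the same computation, useful for checking the table it produces (the program and variable names are hypothetical, not part of this commit):

    program pe_sketch
      ! Builds the same sinusoidal table that positional_encoding adds to
      ! self % output, for a small sequence and model dimension.
      implicit none
      integer, parameter :: seq_len = 4, d_model = 8
      real :: pe(seq_len, d_model), theta
      integer :: pos, i

      pe = 0.
      do pos = 1, seq_len
        do i = 1, d_model / 2
          theta = (pos - 1) / 10000 ** (real(2 * (i - 1)) / d_model)
          pe(pos, 2 * i - 1) = sin(theta)   ! odd columns: sine
          pe(pos, 2 * i)     = cos(theta)   ! even columns: cosine
        end do
      end do
      print *, pe(1, :)   ! first position: 0, 1, 0, 1, ...
    end program pe_sketch

Note that the do concurrent loop covers floor(model_dimension / 2) column pairs, so for an odd model_dimension the last output column receives no positional term.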
