Embedding Layer #205

Merged: 15 commits, Mar 5, 2025
Changes from 1 commit

embedding_layer: add absolute positional encoding
OneAdder committed Feb 23, 2025
commit 074bcd1edd70569082e730edd743570366aca51e
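
For context, a minimal usage sketch of the new integer positional flag. The demo program below is hypothetical and not part of this PR; the constructor call, init, weights, and forward mirror test_positional_absolute further down, and the module name is taken from the file being changed. A value of 0 (the default) means no positional encoding, 1 selects the existing trigonometric encoding, and 2 selects the new absolute encoding.

program embedding_positional_demo
  ! Hypothetical demo, not part of this PR; mirrors test_positional_absolute.
  use nf_embedding_layer, only: embedding_layer
  implicit none

  type(embedding_layer) :: embedding
  integer :: tokens(3) = [2, 1, 3]

  ! positional = 2 -> absolute encoding; 1 -> trigonometric; omit it for a plain lookup
  embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=2)
  call embedding % init([3])   ! input_shape = [sequence_length]
  embedding % weights = reshape([ &
      0.1, 0.3, 0.5, 0.7, 0.2, &
      0.1, 0.3, 0.5, 0.7, 0.2, &
      0.1, 0.3, 0.5, 0.7, 0.2, &
      0.1, 0.3, 0.5, 0.7, 0.2 &
      ], [5, 4])

  call embedding % forward(tokens)
  print *, embedding % output   ! (sequence_length, model_dimension) table
end program embedding_positional_demo
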
17 changes: 12 additions & 5 deletions src/nf/nf_embedding_layer.f90
@@ -15,7 +15,7 @@ module nf_embedding_layer
!! This layer converts them into a table of shape
!! (`sequence_length`, `model_dimension`)
integer :: sequence_length, vocab_size, model_dimension
logical :: positional
integer :: positional

real, allocatable :: weights(:, :)
real, allocatable :: output(:, :)
@@ -25,7 +25,8 @@

procedure :: backward
procedure :: forward
procedure :: positional_encoding
procedure :: positional_trigonometric
procedure :: positional_absolute
procedure :: init
procedure :: get_num_params
procedure :: get_params
@@ -37,7 +38,7 @@
interface embedding_layer
module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
integer, intent(in) :: vocab_size, model_dimension
logical, optional :: positional
integer, optional :: positional
type(embedding_layer) :: res
end function embedding_layer_cons
end interface embedding_layer
@@ -57,11 +58,17 @@ pure module subroutine backward(self, input, gradient)
real, intent(in) :: gradient(:, :)
end subroutine backward

pure module subroutine positional_encoding(self, pos)
pure module subroutine positional_trigonometric(self, pos)
!! Sum embedding with positional info (trigonometric, not trainable)
class(embedding_layer), intent(in out) :: self
integer, intent(in) :: pos
end subroutine positional_encoding
end subroutine positional_trigonometric

pure module subroutine positional_absolute(self, pos)
!! Sum embedding with absolute position
class(embedding_layer), intent(in out) :: self
integer, intent(in) :: pos
end subroutine positional_absolute

module subroutine init(self, input_shape)
class(embedding_layer), intent(in out) :: self
28 changes: 22 additions & 6 deletions src/nf/nf_embedding_layer_submodule.f90
@@ -1,16 +1,20 @@
#define NONE 0
#define TRIGONOMETRIC 1
#define ABSOLUTE 2

submodule(nf_embedding_layer) nf_embedding_layer_submodule
use nf_base_layer, only: base_layer
implicit none
contains
module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
integer, intent(in) :: vocab_size, model_dimension
logical, optional :: positional
integer, optional :: positional
type(embedding_layer) :: res

res % vocab_size = vocab_size
res % model_dimension = model_dimension
if (.not. present(positional)) then
res % positional = .false.
res % positional = NONE
else
res % positional = positional
end if
@@ -46,8 +50,10 @@ pure module subroutine forward(self, input)

self % output(i, :) = self % weights(index, :)

if (self % positional) then
call self % positional_encoding(i)
if (self % positional == TRIGONOMETRIC) then
call self % positional_trigonometric(i)
elseif (self % positional == ABSOLUTE) then
call self % positional_absolute(i)
end if
end do
end subroutine forward
@@ -63,7 +69,7 @@ pure module subroutine backward(self, input, gradient)
end do
end subroutine backward

pure module subroutine positional_encoding(self, pos)
pure module subroutine positional_trigonometric(self, pos)
class(embedding_layer), intent(in out) :: self
integer, intent(in) :: pos
integer :: i
@@ -74,7 +80,17 @@ pure module subroutine positional_encoding(self, pos)
self % output(pos, 2 * i - 1) = self % output(pos, 2 * i - 1) + sin(theta)
self % output(pos, 2 * i) = self % output(pos, 2 * i) + cos(theta)
end do
end subroutine positional_encoding
end subroutine positional_trigonometric
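
For orientation: the trigonometric variant adds a sine term to each odd output column and a cosine term to each even one, presumably over i = 1 .. model_dimension / 2. The line computing theta is collapsed in this view; the sketch below assumes the usual Transformer-style scaling for it, which is an assumption and not read from the diff:

  ! assumed definition of theta; the actual line is collapsed in this hunk
  theta = (pos - 1) / 10000.0 ** (2.0 * (i - 1) / self % model_dimension)
  ! visible in the diff above
  self % output(pos, 2 * i - 1) = self % output(pos, 2 * i - 1) + sin(theta)
  self % output(pos, 2 * i) = self % output(pos, 2 * i) + cos(theta)
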

pure module subroutine positional_absolute(self, pos)
class(embedding_layer), intent(in out) :: self
integer, intent(in) :: pos
integer :: i

do concurrent(i = 1: self % model_dimension)
self % output(pos, i) = self % output(pos, i) + pos - 1
end do
end subroutine positional_absolute
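
A quick worked example of the absolute variant, using the numbers from test_positional_absolute further down: the input tokens are [2, 1, 3] and every column of weights is [0.1, 0.3, 0.5, 0.7, 0.2], so the forward pass looks up 0.3, 0.1, and 0.5 and then adds pos - 1 to each row:

  output(1, :) = 0.3 + 0 = 0.3   ! token 2 at position 1
  output(2, :) = 0.1 + 1 = 1.1   ! token 1 at position 2
  output(3, :) = 0.5 + 2 = 2.5   ! token 3 at position 3

which matches the expected_output_flat pattern in that test.
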

pure module function get_num_params(self) result(num_params)
class(embedding_layer), intent(in) :: self
45 changes: 40 additions & 5 deletions test/test_embedding_layer.f90
@@ -6,7 +6,8 @@ program test_embedding_layer
logical :: ok = .true.

call test_simple(ok)
call test_positional(ok)
call test_positional_trigonometric(ok)
call test_positional_absolute(ok)

if (ok) then
print '(a)', 'test_embedding_layer: All tests passed.'
@@ -47,7 +48,7 @@ subroutine test_simple(ok)
end if
end subroutine test_simple

subroutine test_positional(ok)
subroutine test_positional_trigonometric(ok)
logical, intent(in out) :: ok

integer :: sample_input(3) = [2, 1, 3]
@@ -63,7 +64,7 @@
real :: theta
integer :: i, pos

embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=.true.)
embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=1)
call embedding % init([3])
embedding % weights = reshape([&
0.1, 0.3, 0.5, 0.7, 0.2,&
@@ -77,7 +78,41 @@
output_flat = reshape(embedding % output, [12])
if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
ok = .false.
write(stderr, '(a)') 'positional encoding returned incorrect values.. failed'
write(stderr, '(a)') 'trigonometric positional encoding returned incorrect values.. failed'
end if
end subroutine test_positional
end subroutine test_positional_trigonometric

subroutine test_positional_absolute(ok)
logical, intent(in out) :: ok

integer :: sample_input(3) = [2, 1, 3]
real :: output_flat(12)
real :: expected_output_flat(12) = reshape([&
0.3, 1.1, 2.5,&
0.3, 1.1, 2.5,&
0.3, 1.1, 2.5,&
0.3, 1.1, 2.5&
], [12])
type(embedding_layer) :: embedding

real :: theta
integer :: i, pos

embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=2)
call embedding % init([3])
embedding % weights = reshape([&
0.1, 0.3, 0.5, 0.7, 0.2,&
0.1, 0.3, 0.5, 0.7, 0.2,&
0.1, 0.3, 0.5, 0.7, 0.2,&
0.1, 0.3, 0.5, 0.7, 0.2&
], [5, 4])

call embedding % forward(sample_input)

output_flat = reshape(embedding % output, [12])
if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
ok = .false.
write(stderr, '(a)') 'absolute positional encoding returned incorrect values.. failed'
end if
end subroutine test_positional_absolute
end program test_embedding_layer