
Embedding Layer #205


Merged: 15 commits, Mar 5, 2025
Changes from 1 commit
embedding_layer: update constructor and tests
OneAdder committed Feb 23, 2025
commit 73799bd5a4693b6be0e990a3db5b3f80134d6344
src/nf/nf_layer_constructors.f90 (2 additions, 1 deletion)

@@ -233,7 +233,7 @@ module function self_attention(num_heads) result(res)
       !! Resulting layer instance
   end function self_attention
 
-  module function embedding(sequence_length, vocab_size, model_dimension) result(res)
+  module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
     !! Embedding layer constructor.
     !!
     !! This layer is for inputting token indices from the dictionary to the network.
@@ -243,6 +243,7 @@ module function embedding(sequence_length, vocab_size, model_dimension) result(r
     !! `vocab_size`: length of token vocabulary
     !! `model_dimension`: size of target embeddings
     integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    integer, optional, intent(in) :: positional
     type(layer) :: res
   end function embedding
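For orientation, a minimal sketch of how a caller might use the updated constructor. The module names come from the test file below; the integer codes for `positional` (0, 1, 2) also appear there, but which code selects which encoding is not spelled out in this diff, so the mapping in the comments is an assumption.

program embedding_constructor_usage
  use nf_layer, only: layer
  use nf_layer_constructors, only: embedding
  implicit none

  type(layer) :: lyr

  ! Plain token-to-vector lookup, as before this change.
  lyr = embedding(sequence_length=3, vocab_size=5, model_dimension=4)

  ! New optional argument; assumed mapping (not confirmed by this diff):
  ! 1 = trigonometric encoding, 2 = absolute encoding.
  lyr = embedding(sequence_length=3, vocab_size=5, model_dimension=4, positional=1)
end program embedding_constructor_usage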
src/nf/nf_layer_constructors_submodule.f90 (3 additions, 2 deletions)

@@ -182,12 +182,13 @@ module function self_attention(num_heads) result(res)
   end function self_attention
 
 
-  module function embedding(sequence_length, vocab_size, model_dimension) result(res)
+  module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
     integer, intent(in) :: sequence_length, vocab_size, model_dimension
+    integer, optional, intent(in) :: positional
     type(layer) :: res
     type(embedding_layer) :: embedding_layer_instance
 
-    embedding_layer_instance = embedding_layer(vocab_size, model_dimension)
+    embedding_layer_instance = embedding_layer(vocab_size, model_dimension, positional)
     call embedding_layer_instance % init([sequence_length])
     res % name = 'embedding'
     res % layer_shape = [sequence_length, model_dimension]
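One Fortran detail this forwarding relies on: an absent optional argument may itself be passed to another procedure's optional dummy, so `positional` can be forwarded unconditionally and the `embedding_layer` constructor simply sees it as not present. A self-contained illustration of the rule (names here are illustrative, not from the diff):

program optional_forwarding_demo
  implicit none
  call outer()              ! prints: positional absent
  call outer(positional=1)  ! prints: positional = 1
contains
  subroutine outer(positional)
    integer, optional, intent(in) :: positional
    ! Passing an absent optional on to another optional dummy is legal
    ! Fortran; no present() guard is needed at the call site.
    call inner(positional)
  end subroutine outer

  subroutine inner(positional)
    integer, optional, intent(in) :: positional
    if (present(positional)) then
      print '(a, i0)', 'positional = ', positional
    else
      print '(a)', 'positional absent'
    end if
  end subroutine inner
end program optional_forwarding_demo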
test/test_embedding_layer.f90 (24 additions, 9 deletions)

@@ -1,13 +1,16 @@
 program test_embedding_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_embedding_layer, only: embedding_layer
+  use nf_layer, only: layer
+  use nf_layer_constructors, only: embedding_constructor => embedding
   implicit none
 
   logical :: ok = .true.
+  integer :: sample_input(3) = [2, 1, 3]
 
-  call test_simple(ok)
-  call test_positional_trigonometric(ok)
-  call test_positional_absolute(ok)
+  call test_simple(ok, sample_input)
+  call test_positional_trigonometric(ok, sample_input)
+  call test_positional_absolute(ok, sample_input)
 
   if (ok) then
     print '(a)', 'test_embedding_layer: All tests passed.'
@@ -17,10 +20,10 @@ program test_embedding_layer
   end if
 
 contains
-  subroutine test_simple(ok)
+  subroutine test_simple(ok, sample_input)
     logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)
 
-    integer :: sample_input(3) = [2, 1, 3]
     real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
     real :: output_flat(6)
     real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
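The expected values in test_simple are consistent with a plain row gather: each input token indexes a row of the layer's weight matrix. The weight values below are inferred from the expected outputs (the initialization itself sits in a part of the file this view truncates), so treat this as an illustrative sketch rather than the test's actual setup:

program check_embedding_lookup
  ! Assumed weights, chosen to reproduce expected_output_flat above.
  implicit none
  real :: weights(3, 2) = reshape([0.1, 0.3, 0.5, 0.2, 0.4, 0.6], [3, 2])
  integer :: input(3) = [2, 1, 3]
  real :: output(3, 2)
  integer :: i

  ! Forward pass of an embedding layer: gather rows by token index.
  do i = 1, 3
    output(i, :) = weights(input(i), :)
  end do

  print '(6f4.1)', output  ! column-major flatten: 0.3 0.1 0.5 0.4 0.2 0.6
end program check_embedding_lookup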
@@ -48,10 +51,10 @@ subroutine test_simple(ok)
     end if
   end subroutine test_simple
 
-  subroutine test_positional_trigonometric(ok)
+  subroutine test_positional_trigonometric(ok, sample_input)
     logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)
 
-    integer :: sample_input(3) = [2, 1, 3]
     real :: output_flat(12)
     real :: expected_output_flat(12) = reshape([&
       0.3, 0.941471, 1.4092975,&
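The expected constants here match the sinusoidal positional encoding of Vaswani et al. added on top of the lookup: at even model dimensions, position p contributes sin(p / 10000^(2i/d)), which for the first dimension (i = 0) reduces to sin(p), giving 0.1 + sin(1) = 0.941471 and 0.5 + sin(2) = 1.4092975. The layer implementation is not shown in this diff, so the formula is inferred from these constants; a quick numeric check:

program check_trig_encoding
  ! First-dimension lookup values inferred from test_simple.
  implicit none
  real :: lookup(3) = [0.3, 0.1, 0.5]
  integer :: pos

  do pos = 0, 2
    ! Even-dimension term with i = 0: sin(pos / 10000**0) = sin(pos)
    print '(f9.7)', lookup(pos + 1) + sin(real(pos))
  end do
  ! Prints approximately: 0.3000000, 0.9414710, 1.4092975
end program check_trig_encoding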
@@ -82,10 +85,10 @@ subroutine test_positional_trigonometric(ok)
     end if
   end subroutine test_positional_trigonometric
 
-  subroutine test_positional_absolute(ok)
+  subroutine test_positional_absolute(ok, sample_input)
     logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)
 
-    integer :: sample_input(3) = [2, 1, 3]
    real :: output_flat(12)
    real :: expected_output_flat(12) = reshape([&
      0.3, 1.1, 2.5,&
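These constants suggest that "absolute" positional encoding simply adds the 0-based position index to each embedding value: 0.3 + 0, 0.1 + 1, 0.5 + 2 along the first model dimension. Again the implementation is not part of this diff, so this is an inference; a minimal check under that assumption:

program check_absolute_encoding
  ! First-dimension lookup values inferred from test_simple.
  implicit none
  real :: lookup(3) = [0.3, 0.1, 0.5]
  integer :: pos

  do pos = 0, 2
    print '(f3.1)', lookup(pos + 1) + real(pos)  ! 0.3, 1.1, 2.5
  end do
end program check_absolute_encoding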
@@ -115,4 +118,16 @@ subroutine test_positional_absolute(ok)
       write(stderr, '(a)') 'absolute positional encoding returned incorrect values.. failed'
     end if
   end subroutine test_positional_absolute
+
+  subroutine test_embedding_constructor(ok, sample_input)
+    logical, intent(in out) :: ok
+    integer, intent(in) :: sample_input(:)
+
+    type(layer) :: embedding_constructed
+
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4)
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=0)
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=1)
+    embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=2)
+  end subroutine test_embedding_constructor
 end program test_embedding_layer