Skip to content

Commit f1b414c

Browse files
committed
embedding_layer: update tests
1 parent 6bfea21 commit f1b414c

File tree

1 file changed

+68
-25
lines changed

1 file changed

+68
-25
lines changed

test/test_embedding_layer.f90

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,80 @@ program test_embedding_layer
44
implicit none
55

66
logical :: ok = .true.
7-
integer :: sample_input(3) = [2, 1, 3]
8-
real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
9-
real :: output_flat(6)
10-
real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
11-
real :: dw_flat(8)
12-
real :: expected_dw_flat(8) = reshape([0.2, 0.1, 0.3, 0., 0.6, 0.4, 0.6, 0.], [8])
13-
type(embedding_layer) :: embedding
14-
15-
embedding = embedding_layer(vocab_size=4, model_dimension=2)
16-
call embedding % init([3])
17-
embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])
18-
19-
call embedding % forward(sample_input)
20-
21-
output_flat = reshape(embedding % output, [6])
22-
if (.not. all(output_flat.eq.expected_output_flat)) then
23-
ok = .false.
24-
write(stderr, '(a)') 'forward returned incorrect values.. failed'
25-
end if
267

27-
call embedding % backward(sample_input, sample_gradient)
28-
dw_flat = reshape(embedding % dw, shape(dw_flat))
29-
if (.not. all(dw_flat.eq.expected_dw_flat)) then
30-
ok = .false.
31-
write(stderr, '(a)') 'backward returned incorrect dw values.. failed'
32-
end if
8+
call test_simple(ok)
9+
call test_positional(ok)
3310

3411
if (ok) then
3512
print '(a)', 'test_embedding_layer: All tests passed.'
3613
else
3714
write(stderr, '(a)') 'test_embedding_layer: One or more tests failed.'
3815
stop 1
3916
end if
17+
18+
contains
19+
subroutine test_simple(ok)
20+
logical, intent(in out) :: ok
21+
22+
integer :: sample_input(3) = [2, 1, 3]
23+
real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
24+
real :: output_flat(6)
25+
real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
26+
real :: dw_flat(8)
27+
real :: expected_dw_flat(8) = reshape([0.2, 0.1, 0.3, 0., 0.6, 0.4, 0.6, 0.], [8])
28+
type(embedding_layer) :: embedding
29+
30+
embedding = embedding_layer(vocab_size=4, model_dimension=2)
31+
call embedding % init([3])
32+
embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])
33+
34+
call embedding % forward(sample_input)
35+
36+
output_flat = reshape(embedding % output, [6])
37+
if (.not. all(output_flat.eq.expected_output_flat)) then
38+
ok = .false.
39+
write(stderr, '(a)') 'forward returned incorrect values.. failed'
40+
end if
41+
42+
call embedding % backward(sample_input, sample_gradient)
43+
dw_flat = reshape(embedding % dw, shape(dw_flat))
44+
if (.not. all(dw_flat.eq.expected_dw_flat)) then
45+
ok = .false.
46+
write(stderr, '(a)') 'backward returned incorrect dw values.. failed'
47+
end if
48+
end subroutine test_simple
49+
50+
subroutine test_positional(ok)
51+
logical, intent(in out) :: ok
52+
53+
integer :: sample_input(3) = [2, 1, 3]
54+
real :: output_flat(12)
55+
real :: expected_output_flat(12) = reshape([&
56+
0.3, 0.941471, 1.4092975,&
57+
1.3, 0.64030236, 0.08385316,&
58+
0.3, 0.10999984, 0.51999867,&
59+
1.3, 1.09995, 1.4998&
60+
], [12])
61+
type(embedding_layer) :: embedding
62+
63+
real :: theta
64+
integer :: i, pos
65+
66+
embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=.true.)
67+
call embedding % init([3])
68+
embedding % weights = reshape([&
69+
0.1, 0.3, 0.5, 0.7, 0.2,&
70+
0.1, 0.3, 0.5, 0.7, 0.2,&
71+
0.1, 0.3, 0.5, 0.7, 0.2,&
72+
0.1, 0.3, 0.5, 0.7, 0.2&
73+
], [5, 4])
74+
75+
call embedding % forward(sample_input)
76+
77+
output_flat = reshape(embedding % output, [12])
78+
if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
79+
ok = .false.
80+
write(stderr, '(a)') 'positional encoding returned incorrect values.. failed'
81+
end if
82+
end subroutine test_positional
4083
end program test_embedding_layer

0 commit comments

Comments
 (0)