@@ -4,37 +4,80 @@ program test_embedding_layer
  implicit none

  logical :: ok = .true.
-  integer :: sample_input(3) = [2, 1, 3]
-  real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
-  real :: output_flat(6)
-  real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
-  real :: dw_flat(8)
-  real :: expected_dw_flat(8) = reshape([0.2, 0.1, 0.3, 0., 0.6, 0.4, 0.6, 0.], [8])
-  type(embedding_layer) :: embedding
-
-  embedding = embedding_layer(vocab_size=4, model_dimension=2)
-  call embedding % init([3])
-  embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])
-
-  call embedding % forward(sample_input)
-
-  output_flat = reshape(embedding % output, [6])
-  if (.not. all(output_flat .eq. expected_output_flat)) then
-    ok = .false.
-    write(stderr, '(a)') 'forward returned incorrect values.. failed'
-  end if

-  call embedding % backward(sample_input, sample_gradient)
-  dw_flat = reshape(embedding % dw, shape(dw_flat))
-  if (.not. all(dw_flat .eq. expected_dw_flat)) then
-    ok = .false.
-    write(stderr, '(a)') 'backward returned incorrect dw values.. failed'
-  end if
+  call test_simple(ok)
+  call test_positional(ok)

  if (ok) then
    print '(a)', 'test_embedding_layer: All tests passed.'
  else
    write(stderr, '(a)') 'test_embedding_layer: One or more tests failed.'
    stop 1
  end if
+
+contains
+  subroutine test_simple(ok)
+    logical, intent(in out) :: ok
+
+    integer :: sample_input(3) = [2, 1, 3]
+    real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
+    real :: output_flat(6)
+    real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
+    real :: dw_flat(8)
+    real :: expected_dw_flat(8) = reshape([0.2, 0.1, 0.3, 0., 0.6, 0.4, 0.6, 0.], [8])
+    type(embedding_layer) :: embedding
+
+    embedding = embedding_layer(vocab_size=4, model_dimension=2)
+    call embedding % init([3])
+    embedding % weights = reshape([0.1, 0.3, 0.5, 0.7, 0.2, 0.4, 0.6, 0.8], [4, 2])
+
+    call embedding % forward(sample_input)
+
+    output_flat = reshape(embedding % output, [6])
+    if (.not. all(output_flat .eq. expected_output_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'forward returned incorrect values.. failed'
+    end if
+
+    call embedding % backward(sample_input, sample_gradient)
+    dw_flat = reshape(embedding % dw, shape(dw_flat))
+    if (.not. all(dw_flat .eq. expected_dw_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect dw values.. failed'
+    end if
+  end subroutine test_simple
+
+  subroutine test_positional(ok)
+    logical, intent(in out) :: ok
+
+    integer :: sample_input(3) = [2, 1, 3]
+    real :: output_flat(12)
+    real :: expected_output_flat(12) = reshape([&
+        0.3, 0.941471, 1.4092975,&
+        1.3, 0.64030236, 0.08385316,&
+        0.3, 0.10999984, 0.51999867,&
+        1.3, 1.09995, 1.4998&
+    ], [12])
+    type(embedding_layer) :: embedding
+
+    real :: theta
+    integer :: i, pos
+
+    embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=.true.)
+    call embedding % init([3])
+    embedding % weights = reshape([&
+        0.1, 0.3, 0.5, 0.7, 0.2,&
+        0.1, 0.3, 0.5, 0.7, 0.2,&
+        0.1, 0.3, 0.5, 0.7, 0.2,&
+        0.1, 0.3, 0.5, 0.7, 0.2&
+    ], [5, 4])
+
+    call embedding % forward(sample_input)
+
+    output_flat = reshape(embedding % output, [12])
+    if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
+      ok = .false.
+      write(stderr, '(a)') 'positional encoding returned incorrect values.. failed'
+    end if
+  end subroutine test_positional
end program test_embedding_layer
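Note on the expected values (an explanatory sketch, not part of the patch): in test_simple, forward gathers weights(token, :) for each input token, so tokens [2, 1, 3] pick rows 2, 1 and 3 of the 4x2 weight matrix, and backward scatter-adds each gradient row into dw(token, :), which is why the row for the unused token 4 in expected_dw_flat stays zero. The expected_output_flat values in test_positional are consistent with the standard sinusoidal positional encoding added on top of the embedding lookup, with positions counted from zero. The standalone program below (the name and loop structure are illustrative, and the encoding formula is an assumption inferred from the expected numbers) reproduces them:

program check_positional_expected
  implicit none
  integer, parameter :: seq_len = 3, d_model = 4
  integer :: sample_input(seq_len) = [2, 1, 3]
  real :: weights(5, d_model), expected(seq_len, d_model), theta
  integer :: pos, i

  ! Same weights as in test_positional: every column is [0.1, 0.3, 0.5, 0.7, 0.2].
  weights = spread([0.1, 0.3, 0.5, 0.7, 0.2], dim=2, ncopies=d_model)

  do pos = 1, seq_len
    do i = 1, d_model
      ! Columns pair up as (sin, cos) sharing the frequency exponent 2*k/d_model
      ! of the k-th pair; positions are counted from zero.
      theta = real(pos - 1) / 10000. ** (real(2 * ((i - 1) / 2)) / real(d_model))
      if (mod(i, 2) == 1) then
        expected(pos, i) = weights(sample_input(pos), i) + sin(theta)
      else
        expected(pos, i) = weights(sample_input(pos), i) + cos(theta)
      end if
    end do
  end do

  ! Prints the values in the same column-major order as expected_output_flat:
  ! 0.3, 0.941471, 1.4092975, then 1.3, 0.64030236, 0.08385316, and so on.
  print '(3f12.7)', expected
end program check_positional_expected

Because these values are sums involving sin and cos, test_positional compares them with a combined absolute/relative tolerance (1e-06 + 1e-05 * |expected|) rather than the exact equality used in test_simple.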