@@ -64,24 +64,21 @@ void main() {
64
64
65
65
FLOAT_T outval = FLOAT_T(0.0 );
66
66
67
- // Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
68
67
int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
69
- // Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2
70
- // tensor is transposed
71
- int qmat2_offset = out_tidx.x * qmat2_strides.y;
68
+ int qmat2_offset = out_tidx.x;
72
69
73
70
// TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
74
71
for (int i = 0 ; i < mat1_sizes.x; i++ ) {
75
72
const FLOAT_T mat1_val = t_mat1[mat1_offset];
76
- const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale ;
73
+ const FLOAT_T mat2_val = FLOAT_T( t_qmat2[qmat2_offset]) ;
77
74
78
75
outval += mat1_val * mat2_val;
79
76
80
77
mat1_offset++ ;
81
- qmat2_offset++ ;
78
+ qmat2_offset += qmat2_strides.y ;
82
79
}
83
80
84
- t_out[out_bufi] = outval;
81
+ t_out[out_bufi] = outval * scale ;
85
82
}
86
83
87
84
#else // USING_TEXTURE
@@ -97,25 +94,27 @@ void main() {
97
94
return ;
98
95
}
99
96
100
- const uint16_t qmat2_pos_y = out_pos.x * uint16_t( 4 ) ;
97
+ const uint16_t qmat2_pos_x = out_pos.x;
101
98
102
99
VEC4_T outtex = VEC4_T(0 );
103
100
104
101
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0 , 0 ));
105
102
103
+ VEC4_T mat1_tex;
104
+ VEC4_T mat2_tex[4 ];
106
105
for (
107
106
uint16_t i = uint16_t(0 ), x = uint16_t(0 );
108
107
i < uint16_t(mat1_sizes.x);
109
108
i += uint16_t(4 ), x++ )
110
109
{
111
- const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0 ));
112
- const VEC4_T sums = VEC4_T(
113
- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y , 0 ))),
114
- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1 ), 0 ))),
115
- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2 ), 0 ))),
116
- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3 ), 0 )) ));
117
-
118
- outtex += sums ;
110
+ mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0 ));
111
+
112
+ mat2_tex[ 0 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i , 0 ));
113
+ mat2_tex[ 1 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i + uint16_t(1 ), 0 ));
114
+ mat2_tex[ 2 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i + uint16_t(2 ), 0 ));
115
+ mat2_tex[ 3 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i + uint16_t(3 ), 0 ));
116
+
117
+ outtex += mat1_tex.x * mat2_tex[ 0 ] + mat1_tex.y * mat2_tex[ 1 ] + mat1_tex.z * mat2_tex[ 2 ] + mat1_tex.w * mat2_tex[ 3 ] ;
119
118
}
120
119
121
120
outtex *= scales;
0 commit comments