14
14
15
15
#define TILE_SIZE ${TILE_SIZE}
16
16
17
+ #define BATCH_SIZE_Y ${BATCH_SIZE_Y}
18
+
17
19
#define op(X, A, B) ${OPERATOR}
18
20
19
21
#include "indexing_utils.h"
@@ -39,12 +41,20 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
39
41
* output at a single output location.
40
42
*/
41
43
void main() {
42
- const u16vec3 pos = u16vec3(
44
+ // the output height divided by the batch size (rounded up) is used to decode the 3d position
45
+ // since the work size is calculated as x * ((y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y) * z
46
+ const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1 ) / BATCH_SIZE_Y;
47
+
48
+ u16vec3 pos = u16vec3(
43
49
gl_GlobalInvocationID.x % out_limits.x,
44
- (gl_GlobalInvocationID.x / out_limits.x) % out_limits.y ,
45
- gl_GlobalInvocationID.x / (out_limits.x * out_limits.y ));
50
+ (( gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled) ,
51
+ gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled ));
46
52
47
- if (any (greaterThanEqual (pos, out_limits))) {
53
+ // scale pos.y by the batch size so that it is the y of the top pixel processed by this invocation
54
+ pos.y *= uint16_t(BATCH_SIZE_Y);
55
+
56
+ // do not process if top pixel does not fit within the output range
57
+ if (any (greaterThanEqual (u16vec3(pos.x, pos.y, pos.z), out_limits))) {
48
58
return ;
49
59
}
50
60
@@ -57,18 +67,47 @@ void main() {
57
67
const u16vec2 start = ipos;
58
68
const u16vec2 end = ipos + u16vec2(overlay_region.xy);
59
69
60
- VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
70
+ // sum outputs
71
+ VEC4_T sum[BATCH_SIZE_Y];
72
+
73
+ sum[0 ] = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
74
+ for (int i = 1 ; i < BATCH_SIZE_Y; i++ ) {
75
+ sum[i] = sum[0 ];
76
+ }
77
+
78
+ // array to store input texels
79
+ VEC4_T in_texels[TILE_SIZE];
80
+
81
+ // array to store the kernel line fetched in the previous y iteration
82
+ VEC4_T prev_kernel_line[TILE_SIZE];
83
+
61
84
uint16_t kx = uint16_t(0 );
62
- for (uint16_t y = start.y, i = uint16_t(0 ); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++ ) {
85
+ for (uint16_t y = start.y, i = uint16_t(0 ); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1 ); y += uint16_t(dilation.y), i++ ) {
63
86
for (uint16_t x = start.x, j = uint16_t(0 ); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++ ) {
64
- // The weight kernel was rearranged such that every NxN filter is
65
- // flattened to fit in one row. Each filter was then stacked on top of
66
- // each other vertically.
67
- const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0 );
68
- sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0 ), sum);
69
- kx++ ;
87
+ in_texels[int (j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0 );
88
+ }
89
+
90
+ // from the 2nd iteration onwards, accumulate the dot product into the 2nd sum (sum[1]),
91
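+ // based on the kernel line fetched in the previous iteration and the input texels from this iteration; NOTE(review): only sum[0] and sum[1] are ever accumulated, so this is only correct for BATCH_SIZE_Y == 2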
92
+ if (i > uint16_t(0 )) {
93
+ for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ ) {
94
+ sum[1 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[1 ]);
95
+ }
96
+ }
97
+
98
+ // accumulate the dot product into the 1st sum (sum[0]) only for the first TILE_SIZE rows
99
+ if (i < uint16_t(TILE_SIZE)) {
100
+ for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ , kx++ ) {
101
+ prev_kernel_line[int (j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0 );
102
+ sum[0 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[0 ]);
103
+ }
70
104
}
71
105
}
72
106
73
- imageStore(t_out, pos, op(sum, out_min, out_max));
107
+ for (int i = 0 ; i < BATCH_SIZE_Y; i++ ) {
108
+ if (any (greaterThanEqual (u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
109
+ continue ;
110
+ }
111
+ imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
112
+ }
74
113
}
0 commit comments