@@ -64,20 +64,22 @@ void main() {
64
64
// +--------+--------+
65
65
// | pos[2] | pos[3] |
66
66
// +--------+--------+
67
- ivec3 pos[TILE_SIZE_X * TILE_SIZE_Y];
67
+ int pos[TILE_SIZE_X * TILE_SIZE_Y * 2 ];
68
68
for (int y = 0 , i = 0 ; y < TILE_SIZE_Y; ++ y) {
69
69
for (int x = 0 ; x < TILE_SIZE_X; ++ x) {
70
- pos[i] = ivec3 (gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y, gpos.z);
70
+ pos[i * 2 ] = gpos.x * TILE_SIZE_X + x;
71
+ pos[i * 2 + 1 ] = gpos.y * TILE_SIZE_Y + y;
71
72
i++ ;
72
73
}
73
74
}
74
75
75
76
// Compute the index of the input texture that needs to be loaded for each
76
77
// output position. Note that negative indices can be produced indicating that
77
78
// the top-left element is in a region added by padding.
78
- ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
79
+ int ipos[TILE_SIZE_X * TILE_SIZE_Y * 2 ];
79
80
for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
80
- ipos[i] = pos[i].xy * stride - padding;
81
+ ipos[i * 2 ] = pos[i * 2 ] * stride.x - padding.x;
82
+ ipos[i * 2 + 1 ] = pos[i * 2 + 1 ] * stride.y - padding.y;
81
83
}
82
84
83
85
// Final output array where each element is a tensor value.
@@ -112,7 +114,7 @@ void main() {
112
114
}
113
115
114
116
for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
115
- const vec4 in_tex = texelFetch(t_in, ivec3 (ipos[i], z4), 0 );
117
+ const vec4 in_tex = texelFetch(t_in, ivec3 (ipos[i * 2 ], ipos[i * 2 + 1 ], z4), 0 );
116
118
// Load the input texel into an array
117
119
float tex_values[4 ];
118
120
tex_values[0 ] = in_tex.x;
@@ -163,8 +165,9 @@ void main() {
163
165
}
164
166
165
167
for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
166
- if (all (lessThan (pos[i], out_limits.xyz))) {
167
- imageStore(t_out, pos[i], op(vec4 (sum[i * 4 ], sum[i * 4 + 1 ], sum[i * 4 + 2 ], sum[i * 4 + 3 ]), out_min, out_max));
168
+ const ivec3 pos_l = ivec3 (pos[i * 2 ], pos[i * 2 + 1 ], gpos.z);
169
+ if (all (lessThan (pos_l, out_limits.xyz))) {
170
+ imageStore(t_out, pos_l, op(vec4 (sum[i * 4 ], sum[i * 4 + 1 ], sum[i * 4 + 2 ], sum[i * 4 + 3 ]), out_min, out_max));
168
171
}
169
172
}
170
173
}
0 commit comments