|
11 | 11 | #define PRECISION ${PRECISION}
|
12 | 12 |
|
13 | 13 | #define VEC4_T ${texel_type(DTYPE)}
|
| 14 | +#define T ${buffer_scalar_type(DTYPE)} |
14 | 15 |
|
15 | 16 | #define op(X, Y, A) ${OPERATOR}
|
16 | 17 |
|
| 18 | +${define_active_storage_type(STORAGE)} |
| 19 | +${define_required_extensions(DTYPE)} |
| 20 | + |
17 | 21 | layout(std430) buffer;
|
18 | 22 |
|
19 | 23 | ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
|
20 | 24 | ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
|
21 | 25 | ${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
|
22 | 26 |
|
| 27 | +$if STORAGE == "buffer": |
| 28 | + layout(push_constant) uniform restrict Block { |
| 29 | + ivec4 in_sizes; |
| 30 | + ivec4 other_sizes; |
| 31 | + ivec4 out_strides; |
| 32 | + ivec4 in_strides; |
| 33 | + ivec4 other_strides; |
| 34 | + int out_numel; |
| 35 | + float alpha; |
| 36 | + }; |
| 37 | +$else: |
| 38 | + layout(push_constant) uniform restrict Block { |
| 39 | + ivec4 out_sizes; |
| 40 | + ivec4 in_sizes; |
| 41 | + ivec4 other_sizes; |
| 42 | + ivec2 broadcast_params; |
| 43 | + float alpha; |
| 44 | + }; |
| 45 | + |
23 | 46 | #include "broadcasting_utils.h"
|
24 | 47 | #include "indexing_utils.h"
|
25 | 48 |
|
26 | 49 | layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
|
27 | 50 |
|
28 |
| -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} |
29 |
| -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); |
30 |
| -const lowp int packed_dim = unhash_packed_dim(out_layout); |
| 51 | +$if STORAGE == "buffer": |
| 52 | + ${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")} |
| 53 | + ${layout_declare_spec_const(C, "int", "in_packed_dim", "DEFAULT_LAYOUT")} |
| 54 | + ${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")} |
| 55 | +$else: |
| 56 | + ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} |
| 57 | + const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); |
| 58 | + const lowp int packed_dim = unhash_packed_dim(out_layout); |
31 | 59 |
|
32 |
| -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} |
33 |
| -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); |
| 60 | + ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} |
| 61 | + const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); |
34 | 62 |
|
35 |
| -${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")} |
36 |
| -const lowp ivec4 other_axis_map = unhash_axis_map(other_layout); |
| 63 | + ${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")} |
| 64 | + const lowp ivec4 other_axis_map = unhash_axis_map(other_layout); |
37 | 65 |
|
38 |
| -layout(push_constant) uniform restrict Block { |
39 |
| - ivec4 out_sizes; |
40 |
| - ivec4 in_sizes; |
41 |
| - ivec4 other_sizes; |
42 |
| - ivec2 broadcast_params; |
43 |
| - float alpha; |
44 |
| -}; |
| 66 | +#ifdef USING_BUFFER |
| 67 | + |
| 68 | +void main() { |
| 69 | + const int out_bufi = ivec3(gl_GlobalInvocationID).x; |
| 70 | + if (out_bufi >= out_numel) { |
| 71 | + return; |
| 72 | + } |
| 73 | + |
| 74 | + // Simple case; no broadcasting |
| 75 | + if (in_sizes == other_sizes) { |
| 76 | + t_out[out_bufi] = T(op(t_in[out_bufi], t_other[out_bufi], T(alpha))); |
| 77 | + return; |
| 78 | + } |
| 79 | + |
| 80 | + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim); |
| 81 | + const ivec4 in_tidx = min(out_tidx, in_sizes - 1); |
| 82 | + const ivec4 other_tidx = min(out_tidx, other_sizes - 1); |
| 83 | + |
| 84 | + const int in_bufi = tidx_to_bufi(in_tidx, in_strides); |
| 85 | + const int other_bufi = tidx_to_bufi(other_tidx, other_strides); |
| 86 | + |
| 87 | + t_out[out_bufi] = T(op(t_in[in_bufi], t_other[other_bufi], T(alpha))); |
| 88 | +} |
| 89 | + |
| 90 | +#else // USING_TEXTURE |
45 | 91 |
|
46 | 92 | void main() {
|
47 | 93 | const ivec3 lpos = ivec3(gl_GlobalInvocationID);
|
@@ -79,3 +125,5 @@ void main() {
|
79 | 125 | VEC4_T(op(in_texel, other_texel, alpha)),
|
80 | 126 | out_axis_map);
|
81 | 127 | }
|
| 128 | + |
| 129 | +#endif |
0 commit comments