12
12
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
13
13
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
14
14
15
+ int __nvvm_reflect (const char __constant * );
16
+
15
17
// CLC helpers
16
18
__local bool *
17
19
__clc__get_group_scratch_bool () __asm("__clc__get_group_scratch_bool" );
@@ -150,43 +152,58 @@ __clc__SubgroupBitwiseAny(uint op, bool predicate, bool *carry) {
150
152
#define __CLC_OR (x , y ) (x | y)
151
153
#define __CLC_AND (x , y ) (x & y)
152
154
155
+ #define __CLC_SUBGROUP_COLLECTIVE_BODY (OP , TYPE , IDENTITY ) \
156
+ uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
157
+ /* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
158
+ for (int o = 1 ; o < __spirv_SubgroupMaxSize (); o *= 2 ) { \
159
+ TYPE contribution = __clc__SubgroupShuffleUp (x , o ); \
160
+ bool inactive = (sg_lid < o ); \
161
+ contribution = (inactive ) ? IDENTITY : contribution ; \
162
+ x = OP (x , contribution ); \
163
+ } \
164
+ /* For Reduce, broadcast result from highest active lane */ \
165
+ TYPE result ; \
166
+ if (op == Reduce ) { \
167
+ result = __clc__SubgroupShuffle (x , __spirv_SubgroupSize () - 1 ); \
168
+ * carry = result ; \
169
+ } /* For InclusiveScan, use results as computed */ \
170
+ else if (op == InclusiveScan ) { \
171
+ result = x ; \
172
+ * carry = result ; \
173
+ } /* For ExclusiveScan, shift and prepend identity */ \
174
+ else if (op == ExclusiveScan ) { \
175
+ * carry = x ; \
176
+ result = __clc__SubgroupShuffleUp (x , 1 ); \
177
+ if (sg_lid == 0 ) { \
178
+ result = IDENTITY ; \
179
+ } \
180
+ } \
181
+ return result ;
182
+
153
183
#define __CLC_SUBGROUP_COLLECTIVE (NAME , OP , TYPE , IDENTITY ) \
154
184
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
155
185
__clc__Subgroup, NAME)(uint op, TYPE x, TYPE * carry) { \
156
- uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
157
- /* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
158
- for (int o = 1 ; o < __spirv_SubgroupMaxSize (); o *= 2 ) { \
159
- TYPE contribution = __clc__SubgroupShuffleUp (x , o ); \
160
- bool inactive = (sg_lid < o ); \
161
- contribution = (inactive ) ? IDENTITY : contribution ; \
162
- x = OP (x , contribution ); \
163
- } \
164
- /* For Reduce, broadcast result from highest active lane */ \
165
- TYPE result ; \
166
- if (op == Reduce ) { \
167
- result = __clc__SubgroupShuffle (x , __spirv_SubgroupSize () - 1 ); \
168
- * carry = result ; \
169
- } /* For InclusiveScan, use results as computed */ \
170
- else if (op == InclusiveScan ) { \
171
- result = x ; \
186
+ __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
187
+ }
188
+
189
+ #define __CLC_SUBGROUP_COLLECTIVE_REDUX (NAME , OP , REDUX_OP , TYPE , IDENTITY ) \
190
+ _CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
191
+ __clc__Subgroup, NAME)(uint op, TYPE x, TYPE * carry) { \
192
+ /* Fast path for warp reductions for sm_80+ */ \
193
+ if (__nvvm_reflect ("__CUDA_ARCH" ) >= 800 && op == Reduce ) { \
194
+ TYPE result = __nvvm_redux_sync_ ##REDUX_OP (x, __clc__membermask()); \
172
195
*carry = result; \
173
- } /* For ExclusiveScan, shift and prepend identity */ \
174
- else if (op == ExclusiveScan ) { \
175
- * carry = x ; \
176
- result = __clc__SubgroupShuffleUp (x , 1 ); \
177
- if (sg_lid == 0 ) { \
178
- result = IDENTITY ; \
179
- } \
196
+ return result; \
180
197
} \
181
- return result ; \
198
+ __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
182
199
}
183
200
184
201
__CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , char , 0 )
185
202
__CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , uchar , 0 )
186
203
__CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , short , 0 )
187
204
__CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , ushort , 0 )
188
- __CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , int , 0 )
189
- __CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , uint , 0 )
205
+ __CLC_SUBGROUP_COLLECTIVE_REDUX (IAdd , __CLC_ADD , add , int , 0 )
206
+ __CLC_SUBGROUP_COLLECTIVE_REDUX (IAdd , __CLC_ADD , add , uint , 0 )
190
207
__CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , long , 0 )
191
208
__CLC_SUBGROUP_COLLECTIVE (IAdd , __CLC_ADD , ulong , 0 )
192
209
__CLC_SUBGROUP_COLLECTIVE (FAdd , __CLC_ADD , half , 0 )
@@ -197,8 +214,8 @@ __CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, char, CHAR_MAX)
197
214
__CLC_SUBGROUP_COLLECTIVE (UMin , __CLC_MIN , uchar , UCHAR_MAX )
198
215
__CLC_SUBGROUP_COLLECTIVE (SMin , __CLC_MIN , short , SHRT_MAX )
199
216
__CLC_SUBGROUP_COLLECTIVE (UMin , __CLC_MIN , ushort , USHRT_MAX )
200
- __CLC_SUBGROUP_COLLECTIVE (SMin , __CLC_MIN , int , INT_MAX )
201
- __CLC_SUBGROUP_COLLECTIVE (UMin , __CLC_MIN , uint , UINT_MAX )
217
+ __CLC_SUBGROUP_COLLECTIVE_REDUX (SMin , __CLC_MIN , min , int , INT_MAX )
218
+ __CLC_SUBGROUP_COLLECTIVE_REDUX (UMin , __CLC_MIN , umin , uint , UINT_MAX )
202
219
__CLC_SUBGROUP_COLLECTIVE (SMin , __CLC_MIN , long , LONG_MAX )
203
220
__CLC_SUBGROUP_COLLECTIVE (UMin , __CLC_MIN , ulong , ULONG_MAX )
204
221
__CLC_SUBGROUP_COLLECTIVE (FMin , __CLC_MIN , half , HALF_MAX )
@@ -209,15 +226,17 @@ __CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, char, CHAR_MIN)
209
226
__CLC_SUBGROUP_COLLECTIVE (UMax , __CLC_MAX , uchar , 0 )
210
227
__CLC_SUBGROUP_COLLECTIVE (SMax , __CLC_MAX , short , SHRT_MIN )
211
228
__CLC_SUBGROUP_COLLECTIVE (UMax , __CLC_MAX , ushort , 0 )
212
- __CLC_SUBGROUP_COLLECTIVE (SMax , __CLC_MAX , int , INT_MIN )
213
- __CLC_SUBGROUP_COLLECTIVE (UMax , __CLC_MAX , uint , 0 )
229
+ __CLC_SUBGROUP_COLLECTIVE_REDUX (SMax , __CLC_MAX , max , int , INT_MIN )
230
+ __CLC_SUBGROUP_COLLECTIVE_REDUX (UMax , __CLC_MAX , umax , uint , 0 )
214
231
__CLC_SUBGROUP_COLLECTIVE (SMax , __CLC_MAX , long , LONG_MIN )
215
232
__CLC_SUBGROUP_COLLECTIVE (UMax , __CLC_MAX , ulong , 0 )
216
233
__CLC_SUBGROUP_COLLECTIVE (FMax , __CLC_MAX , half , - HALF_MAX )
217
234
__CLC_SUBGROUP_COLLECTIVE (FMax , __CLC_MAX , float , - FLT_MAX )
218
235
__CLC_SUBGROUP_COLLECTIVE (FMax , __CLC_MAX , double , - DBL_MAX )
219
236
237
+ #undef __CLC_SUBGROUP_COLLECTIVE_BODY
220
238
#undef __CLC_SUBGROUP_COLLECTIVE
239
+ #undef __CLC_SUBGROUP_COLLECTIVE_REDUX
221
240
222
241
#define __CLC_GROUP_COLLECTIVE (NAME , OP , TYPE , IDENTITY ) \
223
242
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
0 commit comments