Skip to content

Commit 78411c4

Browse files
[SYCL-PTX] Add warp-reduce path in sub-group reduce (#3949)
PTX introduces new warp reduction instructions with sm_80. These changes adds a path in select libclc subgroup collective functions for using these warp reduction instructions when available. As an effect of these changes, future changes to libclc targeting nvptx can use __nvvm_reflect to differentiate between SM versions, allowing architecture-specific paths that will be determined after linking device code with libclc. Signed-off-by: Steffen Larsen <steffen.larsen@codeplay.com>
1 parent e975d8d commit 78411c4

File tree

3 files changed

+60
-34
lines changed

3 files changed

+60
-34
lines changed

libclc/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
311311
if( ${d} STREQUAL "none" OR ${ARCH} STREQUAL "spirv" OR ${ARCH} STREQUAL "spirv64" )
312312
# FIXME: Ideally we would not be tied to a specific PTX ISA version
313313
if( ${ARCH} STREQUAL nvptx OR ${ARCH} STREQUAL nvptx64 )
314-
set( flags "SHELL:-Xclang -target-feature" "SHELL:-Xclang +ptx64")
314+
# Disables NVVM reflection to defer to after linking
315+
set( flags "SHELL:-Xclang -target-feature" "SHELL:-Xclang +ptx72"
316+
"SHELL:-march=sm_86" "SHELL:-mllvm --nvvm-reflect-enable=false")
315317
endif()
316318
set( arch_suffix "${t}" )
317319
else()
@@ -327,12 +329,15 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
327329
set( t "spir64--" )
328330
endif()
329331
set( build_flags -O0 -finline-hint-functions )
330-
set( opt_flags )
332+
set( opt_flags -O3 )
331333
set( spvflags --spirv-max-version=1.1 )
332334
elseif( ${ARCH} STREQUAL "clspv" )
333335
set( t "spir--" )
334336
set( build_flags )
335337
set( opt_flags -O3 )
338+
elseif( ${ARCH} STREQUAL "nvptx" OR ${ARCH} STREQUAL "nvptx64" )
339+
set( build_flags )
340+
set( opt_flags -O3 "--nvvm-reflect-enable=false" )
336341
else()
337342
set( build_flags )
338343
set( opt_flags -O3 )
@@ -342,6 +347,7 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
342347
TRIPLE ${t}
343348
TARGET_ENV libspirv
344349
COMPILE_OPT ${flags}
350+
OPT_FLAGS ${opt_flags}
345351
FILES ${libspirv_files}
346352
ALIASES ${${d}_aliases}
347353
GENERATE_TARGET "generate_convert_spirv.cl" "generate_convert_core.cl"
@@ -351,6 +357,7 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
351357
TRIPLE ${t}
352358
TARGET_ENV clc
353359
COMPILE_OPT ${flags}
360+
OPT_FLAGS ${opt_flags}
354361
FILES ${lib_files}
355362
LIB_DEP libspirv-${arch_suffix}
356363
ALIASES ${${d}_aliases}

libclc/cmake/modules/AddLibclc.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ macro(add_libclc_builtin_set arch_suffix)
4343
cmake_parse_arguments(ARG
4444
""
4545
"TRIPLE;TARGET_ENV;LIB_DEP;PARENT_TARGET"
46-
"FILES;ALIASES;GENERATE_TARGET;COMPILE_OPT"
46+
"FILES;ALIASES;GENERATE_TARGET;COMPILE_OPT;OPT_FLAGS"
4747
${ARGN})
4848

4949
if (DEFINED ${ARG_LIB_DEP})
@@ -76,7 +76,7 @@ macro(add_libclc_builtin_set arch_suffix)
7676
# Add opt target
7777
set( builtins_opt_path "${LIBCLC_LIBRARY_OUTPUT_INTDIR}/builtins.opt.${obj_suffix}" )
7878
add_custom_command( OUTPUT "${builtins_opt_path}"
79-
COMMAND ${LLVM_OPT} -O3 -o
79+
COMMAND ${LLVM_OPT} ${ARG_OPT_FLAGS} -o
8080
"${builtins_opt_path}"
8181
"${LIBCLC_LIBRARY_OUTPUT_INTDIR}/builtins.link.${obj_suffix}"
8282
DEPENDS opt "builtins.link.${arch_suffix}" )

libclc/ptx-nvidiacl/libspirv/group/collectives.cl

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
1313
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
1414

15+
int __nvvm_reflect(const char __constant *);
16+
1517
// CLC helpers
1618
__local bool *
1719
__clc__get_group_scratch_bool() __asm("__clc__get_group_scratch_bool");
@@ -150,43 +152,58 @@ __clc__SubgroupBitwiseAny(uint op, bool predicate, bool *carry) {
150152
#define __CLC_OR(x, y) (x | y)
151153
#define __CLC_AND(x, y) (x & y)
152154

155+
#define __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
156+
uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
157+
/* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
158+
for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) { \
159+
TYPE contribution = __clc__SubgroupShuffleUp(x, o); \
160+
bool inactive = (sg_lid < o); \
161+
contribution = (inactive) ? IDENTITY : contribution; \
162+
x = OP(x, contribution); \
163+
} \
164+
/* For Reduce, broadcast result from highest active lane */ \
165+
TYPE result; \
166+
if (op == Reduce) { \
167+
result = __clc__SubgroupShuffle(x, __spirv_SubgroupSize() - 1); \
168+
*carry = result; \
169+
} /* For InclusiveScan, use results as computed */ \
170+
else if (op == InclusiveScan) { \
171+
result = x; \
172+
*carry = result; \
173+
} /* For ExclusiveScan, shift and prepend identity */ \
174+
else if (op == ExclusiveScan) { \
175+
*carry = x; \
176+
result = __clc__SubgroupShuffleUp(x, 1); \
177+
if (sg_lid == 0) { \
178+
result = IDENTITY; \
179+
} \
180+
} \
181+
return result;
182+
153183
#define __CLC_SUBGROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY) \
154184
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
155185
__clc__Subgroup, NAME)(uint op, TYPE x, TYPE * carry) { \
156-
uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
157-
/* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
158-
for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) { \
159-
TYPE contribution = __clc__SubgroupShuffleUp(x, o); \
160-
bool inactive = (sg_lid < o); \
161-
contribution = (inactive) ? IDENTITY : contribution; \
162-
x = OP(x, contribution); \
163-
} \
164-
/* For Reduce, broadcast result from highest active lane */ \
165-
TYPE result; \
166-
if (op == Reduce) { \
167-
result = __clc__SubgroupShuffle(x, __spirv_SubgroupSize() - 1); \
168-
*carry = result; \
169-
} /* For InclusiveScan, use results as computed */ \
170-
else if (op == InclusiveScan) { \
171-
result = x; \
186+
__CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
187+
}
188+
189+
#define __CLC_SUBGROUP_COLLECTIVE_REDUX(NAME, OP, REDUX_OP, TYPE, IDENTITY) \
190+
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
191+
__clc__Subgroup, NAME)(uint op, TYPE x, TYPE * carry) { \
192+
/* Fast path for warp reductions for sm_80+ */ \
193+
if (__nvvm_reflect("__CUDA_ARCH") >= 800 && op == Reduce) { \
194+
TYPE result = __nvvm_redux_sync_##REDUX_OP(x, __clc__membermask()); \
172195
*carry = result; \
173-
} /* For ExclusiveScan, shift and prepend identity */ \
174-
else if (op == ExclusiveScan) { \
175-
*carry = x; \
176-
result = __clc__SubgroupShuffleUp(x, 1); \
177-
if (sg_lid == 0) { \
178-
result = IDENTITY; \
179-
} \
196+
return result; \
180197
} \
181-
return result; \
198+
__CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
182199
}
183200

184201
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, char, 0)
185202
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, uchar, 0)
186203
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, short, 0)
187204
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, ushort, 0)
188-
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, int, 0)
189-
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, uint, 0)
205+
__CLC_SUBGROUP_COLLECTIVE_REDUX(IAdd, __CLC_ADD, add, int, 0)
206+
__CLC_SUBGROUP_COLLECTIVE_REDUX(IAdd, __CLC_ADD, add, uint, 0)
190207
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, long, 0)
191208
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, ulong, 0)
192209
__CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
@@ -197,8 +214,8 @@ __CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, char, CHAR_MAX)
197214
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uchar, UCHAR_MAX)
198215
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, short, SHRT_MAX)
199216
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, ushort, USHRT_MAX)
200-
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, int, INT_MAX)
201-
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uint, UINT_MAX)
217+
__CLC_SUBGROUP_COLLECTIVE_REDUX(SMin, __CLC_MIN, min, int, INT_MAX)
218+
__CLC_SUBGROUP_COLLECTIVE_REDUX(UMin, __CLC_MIN, umin, uint, UINT_MAX)
202219
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, long, LONG_MAX)
203220
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, ulong, ULONG_MAX)
204221
__CLC_SUBGROUP_COLLECTIVE(FMin, __CLC_MIN, half, HALF_MAX)
@@ -209,15 +226,17 @@ __CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, char, CHAR_MIN)
209226
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, uchar, 0)
210227
__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, short, SHRT_MIN)
211228
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, ushort, 0)
212-
__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, int, INT_MIN)
213-
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, uint, 0)
229+
__CLC_SUBGROUP_COLLECTIVE_REDUX(SMax, __CLC_MAX, max, int, INT_MIN)
230+
__CLC_SUBGROUP_COLLECTIVE_REDUX(UMax, __CLC_MAX, umax, uint, 0)
214231
__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, long, LONG_MIN)
215232
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, ulong, 0)
216233
__CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, half, -HALF_MAX)
217234
__CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, float, -FLT_MAX)
218235
__CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)
219236

237+
#undef __CLC_SUBGROUP_COLLECTIVE_BODY
220238
#undef __CLC_SUBGROUP_COLLECTIVE
239+
#undef __CLC_SUBGROUP_COLLECTIVE_REDUX
221240

222241
#define __CLC_GROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY) \
223242
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \

0 commit comments

Comments
 (0)