@@ -62,8 +62,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
62
62
#
63
63
set (SRCS ${ORIG_SRCS} )
64
64
set (CXX_SRCS ${ORIG_SRCS} )
65
- list (FILTER SRCS EXCLUDE REGEX "\. (cc)|(cpp)$" )
66
- list (FILTER CXX_SRCS INCLUDE REGEX "\. (cc)|(cpp)$" )
65
+ list (FILTER SRCS EXCLUDE REGEX "\. (cc)|(cpp)|(hip) $" )
66
+ list (FILTER CXX_SRCS INCLUDE REGEX "\. (cc)|(cpp)|(hip) $" )
67
67
68
68
#
69
69
# Generate ROCm/HIP source file names from CUDA file names.
@@ -80,7 +80,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
80
80
set (CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR} /csrc )
81
81
add_custom_target (
82
82
hipify${NAME}
83
- COMMAND ${CMAKE_SOURCE_DIR} /cmake/hipify.py -p ${CMAKE_SOURCE_DIR} /csrc -o ${CSRC_BUILD_DIR} ${SRCS}
83
+ COMMAND ${Python_EXECUTABLE} ${ CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} /csrc -o ${CSRC_BUILD_DIR} ${SRCS}
84
84
DEPENDS ${CMAKE_SOURCE_DIR} /cmake/hipify.py ${SRCS}
85
85
BYPRODUCTS ${HIP_SRCS}
86
86
COMMENT "Running hipify on ${NAME} extension source files." )
@@ -232,11 +232,26 @@ macro(set_gencode_flags_for_srcs)
232
232
"${multiValueArgs} " ${ARGN} )
233
233
234
234
foreach (_ARCH ${arg_CUDA_ARCHS} )
235
- string (REPLACE "." "" _ARCH "${_ARCH} " )
236
- set_gencode_flag_for_srcs (
237
- SRCS ${arg_SRCS}
238
- ARCH "compute_${_ARCH} "
239
- CODE "sm_${_ARCH} " )
235
+ # handle +PTX suffix: generate both sm and ptx codes if requested
236
+ string (FIND "${_ARCH} " "+PTX" _HAS_PTX )
237
+ if (NOT _HAS_PTX EQUAL -1 )
238
+ string (REPLACE "+PTX" "" _BASE_ARCH "${_ARCH} " )
239
+ string (REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH} " )
240
+ set_gencode_flag_for_srcs (
241
+ SRCS ${arg_SRCS}
242
+ ARCH "compute_${_STRIPPED_ARCH} "
243
+ CODE "sm_${_STRIPPED_ARCH} " )
244
+ set_gencode_flag_for_srcs (
245
+ SRCS ${arg_SRCS}
246
+ ARCH "compute_${_STRIPPED_ARCH} "
247
+ CODE "compute_${_STRIPPED_ARCH} " )
248
+ else ()
249
+ string (REPLACE "." "" _STRIPPED_ARCH "${_ARCH} " )
250
+ set_gencode_flag_for_srcs (
251
+ SRCS ${arg_SRCS}
252
+ ARCH "compute_${_STRIPPED_ARCH} "
253
+ CODE "sm_${_STRIPPED_ARCH} " )
254
+ endif ()
240
255
endforeach ()
241
256
242
257
if (${arg_BUILD_PTX_FOR_ARCH} )
@@ -255,15 +270,18 @@ endmacro()
255
270
#
256
271
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
257
272
# `<major>.<minor>[letter]` compute the "loose intersection" with the
258
- # `TGT_CUDA_ARCHS` list of gencodes.
273
+ # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
274
+ # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
275
+ # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
276
+ # architecture in `SRC_CUDA_ARCHS`.
259
277
# The loose intersection is defined as:
260
278
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
261
279
# where `<=` is the version comparison operator.
262
280
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
263
281
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
264
- # We have special handling for 9 .0a, if 9 .0a is in `SRC_CUDA_ARCHS` and 9 .0 is
265
- # in `TGT_CUDA_ARCHS` then we should remove 9 .0a from `SRC_CUDA_ARCHS` and add
266
- # 9 .0a to the result (and remove 9 .0 from TGT_CUDA_ARCHS).
282
+ # We have special handling for x .0a, if x .0a is in `SRC_CUDA_ARCHS` and x .0 is
283
+ # in `TGT_CUDA_ARCHS` then we should remove x .0a from `SRC_CUDA_ARCHS` and add
284
+ # x .0a to the result (and remove x .0 from TGT_CUDA_ARCHS).
267
285
# The result is stored in `OUT_CUDA_ARCHS`.
268
286
#
269
287
# Example:
@@ -272,36 +290,63 @@ endmacro()
272
290
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
273
291
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
274
292
#
293
+ # Example With PTX:
294
+ # SRC_CUDA_ARCHS="8.0+PTX"
295
+ # TGT_CUDA_ARCHS="9.0"
296
+ # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
297
+ # OUT_CUDA_ARCHS="8.0+PTX"
298
+ #
275
299
function (cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS )
276
- list (REMOVE_DUPLICATES SRC_CUDA_ARCHS )
277
- set (TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS} )
300
+ set (_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS} " )
301
+ set (_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS} )
302
+
303
+ # handle +PTX suffix: separate base arch for matching, record PTX requests
304
+ set (_PTX_ARCHS )
305
+ foreach (_arch ${_SRC_CUDA_ARCHS} )
306
+ if (_arch MATCHES "\\ +PTX$" )
307
+ string (REPLACE "+PTX" "" _base "${_arch} " )
308
+ list (APPEND _PTX_ARCHS "${_base} " )
309
+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch} " )
310
+ list (APPEND _SRC_CUDA_ARCHS "${_base} " )
311
+ endif ()
312
+ endforeach ()
313
+ list (REMOVE_DUPLICATES _PTX_ARCHS )
314
+ list (REMOVE_DUPLICATES _SRC_CUDA_ARCHS )
278
315
279
- # if 9 .0a is in SRC_CUDA_ARCHS and 9 .0 is in CUDA_ARCHS then we should
280
- # remove 9 .0a from SRC_CUDA_ARCHS and add 9 .0a to _CUDA_ARCHS
316
+ # if x .0a is in SRC_CUDA_ARCHS and x .0 is in CUDA_ARCHS then we should
317
+ # remove x .0a from SRC_CUDA_ARCHS and add x .0a to _CUDA_ARCHS
281
318
set (_CUDA_ARCHS )
282
- if ("9.0a" IN_LIST SRC_CUDA_ARCHS )
283
- list (REMOVE_ITEM SRC_CUDA_ARCHS "9.0a" )
284
- if ("9.0" IN_LIST TGT_CUDA_ARCHS_ )
285
- list (REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0" )
319
+ if ("9.0a" IN_LIST _SRC_CUDA_ARCHS )
320
+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a" )
321
+ if ("9.0" IN_LIST TGT_CUDA_ARCHS )
322
+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "9.0" )
286
323
set (_CUDA_ARCHS "9.0a" )
287
324
endif ()
288
325
endif ()
289
326
290
- list (SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING )
327
+ if ("10.0a" IN_LIST _SRC_CUDA_ARCHS )
328
+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a" )
329
+ if ("10.0" IN_LIST TGT_CUDA_ARCHS )
330
+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "10.0" )
331
+ set (_CUDA_ARCHS "10.0a" )
332
+ endif ()
333
+ endif ()
334
+
335
+ list (SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING )
291
336
292
337
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
293
338
# is less or equal to ARCH (but has the same major version since SASS binary
294
339
# compatibility is only forward compatible within the same major version).
295
- foreach (_ARCH ${TGT_CUDA_ARCHS_ } )
340
+ foreach (_ARCH ${_TGT_CUDA_ARCHS } )
296
341
set (_TMP_ARCH )
297
342
# Extract the major version of the target arch
298
343
string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" TGT_ARCH_MAJOR "${_ARCH} " )
299
- foreach (_SRC_ARCH ${SRC_CUDA_ARCHS } )
344
+ foreach (_SRC_ARCH ${_SRC_CUDA_ARCHS } )
300
345
# Extract the major version of the source arch
301
346
string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" SRC_ARCH_MAJOR "${_SRC_ARCH} " )
302
- # Check major- version match AND version -less-or-equal
347
+ # Check version-less-or-equal, and allow PTX arches to match across majors
303
348
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH )
304
- if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR )
349
+ if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR )
305
350
set (_TMP_ARCH "${_SRC_ARCH} " )
306
351
endif ()
307
352
else ()
@@ -317,6 +362,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
317
362
endforeach ()
318
363
319
364
list (REMOVE_DUPLICATES _CUDA_ARCHS )
365
+
366
+ # reapply +PTX suffix to architectures that requested PTX
367
+ set (_FINAL_ARCHS )
368
+ foreach (_arch ${_CUDA_ARCHS} )
369
+ if (_arch IN_LIST _PTX_ARCHS )
370
+ list (APPEND _FINAL_ARCHS "${_arch} +PTX" )
371
+ else ()
372
+ list (APPEND _FINAL_ARCHS "${_arch} " )
373
+ endif ()
374
+ endforeach ()
375
+ set (_CUDA_ARCHS ${_FINAL_ARCHS} )
376
+
320
377
set (${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE )
321
378
endfunction ()
322
379
0 commit comments