diff --git a/CHANGELOG.md b/CHANGELOG.md index 245527e1a3..a418edc011 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # NVIDIA CUTLASS Changelog +## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25) + +- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) +- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) +- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and +[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). +- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). +- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) leveraging 2:4 structured sparsity and [support for LLM friendly tile sizes](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu). +- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. +- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). +- Support for residual add (beta != 0) in convolution kernels. +- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). +- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). +- Better support for MSVC as a host compiler. +- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. +- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. + ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e764c89f0..71523d7f69 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,7 +92,7 @@ if(CUTLASS_NATIVE_CUDA) else() list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17) endif() - + if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE) endif() @@ -134,6 +134,16 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests") set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest") + +set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.") +if (CUTLASS_USE_PACKED_TUPLE) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_USE_PACKED_TUPLE=1") + message(STATUS "Make cute::tuple be the new standard-layout tuple type") +elseif() + message(STATUS "Use the original cute::tuple implementation that is _not_ standard-layout") +endif() + ################################################################################ set(CUTLASS_NVCC_ARCHS_SUPPORTED "") @@ -216,7 +226,7 @@ if (${CUTLASS_NVCC_VERBOSE}) endif() # -# CUTLASS NAMESPACE +# CUTLASS NAMESPACE # set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS") @@ -234,15 +244,15 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code. set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH") -if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS) +if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS) # If a kernel filter file is specified, we want to generate and then # filter on the entire kernel set, not the default kernel - # (sub)set. The user may overried CUTLASS_LIBRRARY_KERNELS, in which + # (sub)set. The user may have overridden CUTLASS_LIBRRARY_KERNELS, in which # case the resulting kernel set will be the intersection of the two # options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS. set(CUTLASS_LIBRARY_KERNELS_INIT "*") -else() - set(CUTLASS_LIBRARY_KERNELS_INIT "") +else() + set(CUTLASS_LIBRARY_KERNELS_INIT "") endif() if (KERNEL_FILTER_FILE) @@ -256,9 +266,10 @@ if(KERNEL_FILTER_FILE) message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}") endif() -set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.") -set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.") -set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.") +set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.") +set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") +set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") +set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") ################################################################################ @@ -330,6 +341,11 @@ if (CUTLASS_ENABLE_TENSOR_CORE_MMA) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() +set(CUTLASS_PROFILER_DISABLE_REFERENCE OFF CACHE BOOL "Disable compilation of reference kernels in the CUTLASS profiler.") +if (CUTLASS_PROFILER_DISABLE_REFERENCE) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_PROFILER_DISABLE_REFERENCE=1) +endif() + @@ -398,8 +414,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) endif() - # There are numerous Clang versions that can work with each CUDA toolkit and the - # the checks are not very useful so we are turning them off and using testing to + # There are numerous Clang versions that can work with each CUDA toolkit and the + # the checks are not very useful so we are turning them off and using testing to # ensure the various combinations work properly. list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}) @@ -425,7 +441,14 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") link_libraries(nvidia::cuda_driver) endif() -# Support for 128-bit integers if using NVIDIA C++ compiler +# Known gcc 8.1-8.3 SFINAE issue (fixed in gcc 8.4), check https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87748 +# Also see https://github.com/NVIDIA/nccl/issues/835 for nvtx3.hpp +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") +endif() + +# Support for 128-bit integers if using NVIDIA C++ compiler if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ") endif() @@ -433,24 +456,24 @@ endif() if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) # CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this # property for CMake 3.18+, so we request the NEW behavior for correct compatibility. - # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 + # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 cmake_policy(SET CMP0104 NEW) endif() if (MSVC) - + # MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard - # because MSVC is not a completely compliant implementation. This option forces MSVC to use the + # because MSVC is not a completely compliant implementation. This option forces MSVC to use the # appropriate value given the requested --std option. This fixes a compilation issue mismatch # between GCC/Clang and MSVC. # # error : a constexpr function cannot have a nonliteral return type "dim3" - # + # # See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") - + endif() # Some tests require this build option in order to link. @@ -488,7 +511,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET) list(JOIN CODES "," CODES_STR) list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) endforeach() - + if (NOT __SM_ARCHS) if (CUDA_COMPILER MATCHES "[Cc]lang") target_compile_options( @@ -523,7 +546,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET) endfunction() -# Cache the flags so they are available when the function below is called anywhere globally. +# Cache the flags so they are available when the function below is called anywhere globally. set(__CUTLASS_CUDA_FLAGS ${CUTLASS_CUDA_FLAGS} CACHE INTERNAL "") set(__CUTLASS_CUDA_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} CACHE INTERNAL "") @@ -694,6 +717,7 @@ if(NOT WIN32) "-Wl,-rpath,'$ORIGIN/../lib'" "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'" "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'" + ${CMAKE_DL_LIBS} ) endif() @@ -757,24 +781,24 @@ set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.co set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "") function(cutlass_add_executable_tests NAME TARGET) -# -# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the +# +# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the # or the / after installation. -# +# # NAME: The base name for the test. Can be run with `make ` or `ctest -R 'c'`. # TARGET: The target corresponding to the executable under test. # DISABLE_EXECUTABLE_INSTALL_RULE: An option, if given, that disables creating an install rule for TARGET. # DEPENDS: A list of targets or files on which this test is dependent. # DEPENDEES: A list of targets which should depend on this test. # TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments -# to pass to the test executable. A unique test is generated for each set of +# to pass to the test executable. A unique test is generated for each set of # options given. If this option is not used, a single test with no arguments is generated. -# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for +# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for # generating the full variable name to be referenced. # RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed # test results to speed up test runtime. -# TEST_SETS_SUPPORTED: A list of test set names these tests support. -# +# TEST_SETS_SUPPORTED: A list of test set names these tests support. +# set(options DISABLE_EXECUTABLE_INSTALL_RULE) set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX) @@ -806,9 +830,9 @@ function(cutlass_add_executable_tests NAME TARGET) endif() if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS) - + # file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}) - + install( TARGETS ${TARGET} RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR} @@ -822,7 +846,7 @@ function(cutlass_add_executable_tests NAME TARGET) ) endif() - + endif() if (NOT __TEST_COMMAND_OPTIONS) @@ -856,10 +880,10 @@ function(cutlass_add_executable_tests NAME TARGET) string(TOLOWER "${NAME}" TEST_NAME) endif() - # The following rigmarole is needed to deal with spaces and possible quotes in + # The following rigmarole is needed to deal with spaces and possible quotes in # command line arguments. The options are passed "by reference" as the actual # variable names holding the real options. We then expand these in a way that - # preserves any quotes. Note, they have to be in this order for it to work for + # preserves any quotes. Note, they have to be in this order for it to work for # all the use cases below. set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}}) @@ -889,7 +913,7 @@ function(cutlass_add_executable_tests NAME TARGET) endforeach() # To run the tests from an install package with tests enabled, we need to generate test files - # that don't rely on the current directory structure in build. + # that don't rely on the current directory structure in build. set(TEST_NAME c${NAME}) set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) @@ -906,14 +930,14 @@ function(cutlass_add_executable_tests NAME TARGET) # The following line imports the tests for immediate run via `make test`. include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake) - + set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "") if (CUTLASS_INSTALL_TESTS) - file(GENERATE - OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" - INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" + file(GENERATE + OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" + INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" ) install( @@ -971,19 +995,19 @@ endif() include(CMakePackageConfigHelpers) write_basic_package_version_file( - ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake COMPATIBILITY AnyNewerVersion) configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in - ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake @ONLY ) install( - FILES - ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake - ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake + FILES + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/ ) diff --git a/README.md b/README.md index 865ffb76ff..9ac15f4165 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.5 +# CUTLASS 3.5.1 -_CUTLASS 3.5 - April 2024_ +_CUTLASS 3.5.1 - July 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -41,9 +41,30 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. + # What's New in CUTLASS 3.5 -CUTLASS 3.5 is an update to CUTLASS adding: +CUTLASS 3.5.1 is an update to CUTLASS adding: + +- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu). +- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) +- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and +[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). +- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). +- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) leveraging 2:4 structured sparsity and [support for LLM friendly tile sizes](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu). +- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. +- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). +- Support for residual add (beta != 0) in convolution kernels. +- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). +- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). +- Better support for MSVC as a host compiler. +- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. +- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. +- NOTICE: + + Upcoming CUTLASS 3.6 release will include a breaking refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` API to bring it in line with `gemm::GemmUniversal`. After this, the 3.x convolution API will no longer be considered as a beta API. + + Upcoming CUTLASS 3.6 release will include a breaking refactor to the Hopper TMA pointer array batched epilogue in order to support grouped GEMMs. + +CUTLASS 3.5.0 is an update to CUTLASS adding: - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp). + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). @@ -61,6 +82,7 @@ CUTLASS 3.5 is an update to CUTLASS adding: - Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. - Fixes to greatly reduce build warnings. - Updates and bugfixes from the community (thanks!) +- CUTLASS 3.5.1 is a minor update to CUTLASS containing small bug fixes and improvements, including fixes for FlashAttention-2 builds. Minimum requirements: diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu index b559eb8aff..23c2d9f45f 100644 --- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu +++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu @@ -162,7 +162,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes ? using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index a24fd0e4e2..34f682deb0 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -161,7 +161,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/examples/12_gemm_bias_relu/gemm_bias_relu.cu index 97a21f2681..bca8e0ac74 100644 --- a/examples/12_gemm_bias_relu/gemm_bias_relu.cu +++ b/examples/12_gemm_bias_relu/gemm_bias_relu.cu @@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // Define the epilogue operation as LinearCombinationRelu. This is approximately equal to // diff --git a/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h index eb04105f86..2206bac0e6 100644 --- a/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h @@ -231,7 +231,7 @@ struct B2bFusedGroupedGemmRun host_tensor_ref_D1.at(i).sync_device(); ref_A0.at(i) = (host_tensor_A0.at(i).device_ref()); - ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());; + ref_B0.at(i) = (host_tensor_B0.at(i).device_ref()); ref_C0.at(i) = (host_tensor_C0.at(i).device_ref()); if (alpha0 == ElementCompute(0)) //per-channel scale ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref()); @@ -340,7 +340,7 @@ struct B2bFusedGroupedGemmRun std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; for (int i = 0; i < problem_count; ++i) { - host_tensor_D1.at(i).sync_host();; + host_tensor_D1.at(i).sync_host(); // // Verify diff --git a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu index 8b5b1b77cb..99d3cdb178 100644 --- a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu +++ b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu @@ -194,7 +194,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt index d41d263a2f..02d3205889 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt +++ b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt @@ -33,6 +33,11 @@ cutlass_example_add_executable( ampere_sparse_tensorop_gemm.cu ) +cutlass_example_add_executable( + 15_ampere_sparse_tensorop_gemm_universal + ampere_sparse_tensorop_gemm_universal.cu + ) + cutlass_example_add_executable( 15_ampere_sparse_tensorop_gemm_with_visitor ampere_sparse_tensorop_gemm_with_visitor.cu diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu index 9c1663972f..e92b717caa 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu @@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_universal.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_universal.cu new file mode 100644 index 0000000000..dcab5ac144 --- /dev/null +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_universal.cu @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere +architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. + +Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of +meta data is different for every data types. CUTLASS templates can automatically infer it based on +input A and B. Check code below. + +Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers +efficiently. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = int32_t; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A +using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B +using ElementOutput = int32_t; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. Row Major for +// Matrix A, Column Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::RowMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// This code section describes the epilogue part of the kernel +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +using Gemm = cutlass::gemm::device::GemmSparseUniversal; + +// Data type and layout of meta data matrix E can be inferred from template Gemm. +using ElementInputE = typename Gemm::ElementE; +using LayoutInputE = cutlass::layout::RowMajor; +using ReorderedLayoutInputE = typename Gemm::LayoutE; + +// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h +// 50% Sparsity on Ampere +constexpr int kSparse = Gemm::kSparse; +// How many elements of A are covered per ElementE +constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; +// The size of individual meta data +constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + +int run() { + + const int length_m = 512; + const int length_n = 512; + const int length_k = 1024; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) + cutlass::HostTensor tensor_a_uncompressed( + problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing + + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. + cutlass::HostTensor tensor_e( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + // Same size as the above. The above one needs to be reordered and stored in this one. + cutlass::HostTensor tensor_e_reordered( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(2), + ElementInputA(-2), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(2), + ElementInputB(-2), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(2), + ElementOutput(-2), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_e.host_view(), + 1, + kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core + // instructions. + cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_e_reordered.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 1 partitions + int split_k_slices = 2; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, // <- problem size of matrix multiplication + split_k_slices,// <- k-dimension split factor + {alpha, beta}, // <- tuple of alpha and beta + tensor_a.device_data(), // <- reference to matrix A on device + tensor_b.device_data(), // <- reference to matrix B on device + tensor_c.device_data(), // <- reference to matrix C on device + tensor_d.device_data(), // <- reference to matrix D on device + tensor_e_reordered.device_data(), // <- reference to matrix E on device + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_a.layout().stride(0), + tensor_b.layout().stride(0), + tensor_c.layout().stride(0), + tensor_d.layout().stride(0), + tensor_e_reordered.layout().stride(0) + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. + cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), + tensor_e.host_ref(), problem_size.m(), problem_size.k()); + + // Create instantiation for host reference gemm kernel + cutlass::reference::host::Gemm + gemm_host; + + // Launch host reference gemm kernel + gemm_host(problem_size, + alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + beta, + tensor_c.host_ref(), + tensor_ref_d.host_ref()); + + // Copy output data from CUTLASS host for comparison + tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + + bool notSupported = false; + + // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.1. + // + // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. + + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major * 10 + props.minor < 80) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + return run(); +} diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu index 612a666cdf..90aa44528e 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu @@ -94,7 +94,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 1 // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -using Operator = cutlass::arch::OpMultiplyAdd; +using Operator = cutlass::arch::OpMultiplyAddSaturate; // Number of pipelines you want to use constexpr int NumStages = 3; diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 23fd6c183b..4e5fca1a03 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -138,7 +138,7 @@ using Gemm = typename cutlass::gemm::device::GemmWithKReduction< >; // Below is the reduction kernel used in the case of parallel split-k -using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; +using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>; using ReduceOp = cutlass::reduction::thread::ReduceAdd< ElementAccumulator, @@ -154,7 +154,7 @@ using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK; -using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;; +using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>; // This code section describes the epilogue part of the kernel, we use default value using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu index 1ecd38ee9b..9e561cb6a2 100644 --- a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu +++ b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu @@ -258,7 +258,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu index 18375f6dd3..0a995bf929 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu @@ -221,7 +221,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu index 4863ed93e7..22cb3286eb 100644 --- a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu +++ b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu @@ -193,7 +193,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index 50af76a122..885aacebc1 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -215,7 +215,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8 // 16, 8, 16 -> Ampere // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // Define the epilogue operation as LinearCombination. This is approximately equal to // diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py index c0268c72f0..71e94259ff 100644 --- a/examples/40_cutlass_py/conv2d.py +++ b/examples/40_cutlass_py/conv2d.py @@ -118,7 +118,7 @@ conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=cc, tile_description=tile_description, - A=A, B=B, C=C, stride_support=StrideSupport.Strided, + A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor ) diff --git a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h index eecd860062..3e41274349 100644 --- a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h +++ b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -1799,7 +1799,7 @@ struct B2bGemm< if (rowIdx == 1) { lse_prefetched[colIdx] = accum_n < lse_extent ? lse[accum_n] - : platform::numeric_limits::infinity(); + : cutlass::platform::numeric_limits::infinity(); } accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); ++colIdx; @@ -1938,7 +1938,7 @@ struct B2bGemm< if (rowIdx == 1) { lse_prefetched[colIdx] = accum_n < lse_extent ? lse[accum_n] - : platform::numeric_limits::infinity(); + : cutlass::platform::numeric_limits::infinity(); } accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); ++colIdx; diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index 45b0f51008..f26f4da37d 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -35,19 +35,23 @@ This example demonstrate a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0 APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: - 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) which are more efficient than the Ampere tensor core instructions. - 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and convert them implicitly to TF32. 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + 4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. + Examples: - $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 + $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2 */ #include @@ -63,6 +67,7 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" #include "cutlass/util/command_line.h" #include "cutlass/util/distribution.h" @@ -105,7 +110,7 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder +using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -175,6 +180,8 @@ cutlass::DeviceAllocation block_ /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + // Command line options parsing struct Options { @@ -183,12 +190,16 @@ struct Options { float alpha, beta; int iterations; int m, n, k; + RasterOrderOptions raster; + int swizzle; Options(): help(false), m(5120), n(4096), k(4096), alpha(1.f), beta(0.f), - iterations(1000) + iterations(1000), + raster(RasterOrderOptions::Heuristic), + swizzle(1) { } // Parses the command line @@ -206,6 +217,21 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); } /// Prints the usage statement. @@ -220,6 +246,8 @@ struct Options { << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out @@ -294,10 +322,10 @@ bool initialize_block( /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{})); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{})); - stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{})); - stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{})); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); block_A.reset(options.m * options.k); block_B.reset(options.k * options.n); @@ -320,6 +348,10 @@ typename Gemm::Arguments args_from_options(const Options &options) {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + return arguments; } @@ -408,7 +440,17 @@ int run(Options &options) result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS: " << result.gflops << std::endl; } @@ -441,7 +483,6 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index 1700ac5693..1e820ddb47 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -538,9 +538,8 @@ int main(int argc, char const **args) { std::cout << "This example requires a GPU of NVIDIA's Hopper Architecture or " << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; - return 0; + return 0; } - // // Parse options // diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu index fdb12c1f33..a736e5ce31 100644 --- a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -354,9 +354,8 @@ int main(int argc, char const **args) { std::cout << "This example requires a GPU of NVIDIA's Hopper Architecture or " << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; - return 0; + return 0; } - // // Parse options // diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index 52a8c19c27..884f3535d0 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -627,7 +627,6 @@ int main(int argc, const char ** argv) { std::cerr << "This example requires a device with compute capability 90 or higher.\n"; notSupported = true; } - if (notSupported) { return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems } diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 04777d8a80..57053b0f9a 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -166,7 +166,7 @@ class GemmGather to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -181,8 +181,7 @@ class GemmGather }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); diff --git a/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp index 323e960602..dc9c0df804 100644 --- a/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp +++ b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp @@ -119,7 +119,7 @@ class EpilogueGatherScatter { } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu index 20a282b06c..d24c5f294a 100644 --- a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -750,7 +750,6 @@ int main(int argc, char const **argv) std::cerr << "This example requires a device with compute capability 90 or higher.\n"; notSupported = true; } - if (notSupported) { return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems } diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu index 3d3be226c3..726f6d222a 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -47,9 +47,13 @@ 4. This example shows all important fusions used by FP8 gemm kernels, i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor. + 5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. + Examples: - $ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 + $ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2 */ #include @@ -63,6 +67,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" @@ -214,6 +219,8 @@ cutlass::HostTensor reference_abs_max_aux; /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + /// Result structure struct Result { @@ -273,7 +280,7 @@ bool initialize_tensor( } /// Initialize operands to be used in the GEMM and reference GEMM -void initialize(const Options &options) { +void initialize(const Options &options) { stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); @@ -346,7 +353,7 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) +typename Gemm::Arguments args_from_options(const Options &options) { typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, @@ -392,10 +399,14 @@ typename Gemm::Arguments args_from_options(const Options &options) fusion_args.amax_D_ptr = abs_max_D.device_data(); } + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + return arguments; } -bool verify(const Options &options) { +bool verify(const Options &options) { // // Compute reference output // @@ -468,7 +479,7 @@ bool verify(const Options &options) { /// Execute a given example GEMM computation template -int run(Options &options) +int run(Options &options) { initialize(options); @@ -518,7 +529,17 @@ int run(Options &options) result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS: " << result.gflops << std::endl; } @@ -551,12 +572,11 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // - Options options; + Options options; options.parse(argc, args); diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp index 94d72356ff..96d8794d8e 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp +++ b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp @@ -30,6 +30,7 @@ **************************************************************************************************/ // Command line options parsing +template struct Options { bool help = false; @@ -41,6 +42,8 @@ struct Options { bool save_amax = true; int iterations = 1000; int m = 1024, n = 512, k = 1024, l = 1; + RasterOrderOptions raster; + int swizzle; // Parses the command line void parse(int argc, char const **args) { @@ -66,6 +69,21 @@ struct Options { cmd.get_cmd_line_argument("save_aux", save_aux, true); cmd.get_cmd_line_argument("save_amax", save_amax, true); cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); } /// Prints the usage statement. @@ -89,6 +107,8 @@ struct Options { << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index cfab823b4a..28baae260c 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -687,7 +687,6 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index 465c0a41f2..7a191ce2d8 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -99,7 +99,7 @@ using TileShape = Shape<_256,_128,_64>; // T using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -492,7 +492,6 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // diff --git a/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt index 0a4e69566a..1f59ceb8a1 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt +++ b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt @@ -30,10 +30,10 @@ set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes -set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=1) # Default problem sizes set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes -set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Default problem sizes w/ Epilogue Op test set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index 2a737e1e98..f94679568a 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -32,7 +32,7 @@ /*! \file \brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. - This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA + This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA warp-specialized cooperative kernel. For this example all scheduling work is performed on the device. The new feature showcased in this example is on-the-fly modification of TMA descriptors @@ -42,7 +42,7 @@ $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10 - The above example command makes all 10 groups to be sized at the given m, n, k sizes. + The above example command makes all 10 groups to be sized at the given m, n, k sizes. Skipping any of the problem dimensions randomizes it across the different groups. Same applies for alpha and beta values that are randomized across the different groups. @@ -117,7 +117,7 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // A using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size +using TileShape = Shape<_256,_128,_128>; // Threadblock-level tile size using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch @@ -163,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm< ElementAccumulator, ElementAccumulator>; -using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA; -using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB; -using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC; -using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD; +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; // Host-side allocations std::vector offset_A; @@ -226,7 +226,7 @@ struct Options { std::string benchmark_path; std::vector problem_sizes_host; int const tma_alignment_bits = 128; - int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; // Parses the command line void parse(int argc, char const **args) { @@ -438,10 +438,10 @@ void allocate(const Options &options) { total_elements_C += elements_C; total_elements_D += elements_D; - stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{}))); - stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{}))); - stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{}))); - stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{}))); + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); } @@ -456,7 +456,7 @@ void allocate(const Options &options) { /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { - + uint64_t seed = 2020; problem_sizes.reset(options.groups); @@ -695,7 +695,6 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // diff --git a/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu b/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu index 0a6c034d19..79bead365b 100644 --- a/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu +++ b/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu @@ -97,7 +97,7 @@ using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWith cutlass::epilogue::thread::ReLu, ElementOutput, ElementAuxOutput, - 128 / cutlass::sizeof_bits::value, + 8, ElementAccumulator, ElementAccumulator >; @@ -106,7 +106,7 @@ template using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, - cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages, kAlignmentA, kAlignmentB, MathOperator >; diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index 1e4dad5f16..b427d9368c 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -53,3 +53,8 @@ cutlass_example_add_executable( tiled_copy.cu ) +cutlass_example_add_executable( + wgmma_sm90 + wgmma_sm90.cu +) + diff --git a/examples/cute/tutorial/sgemm_sm80.cu b/examples/cute/tutorial/sgemm_sm80.cu index e1211aac67..5ae0bf0f8b 100644 --- a/examples/cute/tutorial/sgemm_sm80.cu +++ b/examples/cute/tutorial/sgemm_sm80.cu @@ -153,12 +153,12 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, // Allocate the accumulators -- same size as the projected data Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) - CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K) - CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K) - CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N) - CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K + CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K) + CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K) + CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N) + CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M + CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N + CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K // Clear the accumulators clear(tCrC); @@ -358,7 +358,7 @@ gemm_nt(int m, int n, int k, alpha, beta); } -// Setup params for a NT GEMM +// Setup params for a TN GEMM template void @@ -391,10 +391,10 @@ gemm_tn(int m, int n, int k, auto bP = Int<3>{}; // Pipeline // Define the smem layouts (static) - auto sA_atom = make_layout(make_shape ( bM, bK), - make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major - auto sB_atom = make_layout(make_shape ( bN, bK), - make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major + auto sA_atom = make_layout(make_shape ( bM, bK), + make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major + [[maybe_unused]] auto sB_atom = make_layout(make_shape ( bN, bK), + make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP)); auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP)); auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx diff --git a/examples/cute/tutorial/wgmma_sm90.cu b/examples/cute/tutorial/wgmma_sm90.cu new file mode 100644 index 0000000000..0baa494a37 --- /dev/null +++ b/examples/cute/tutorial/wgmma_sm90.cu @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/cluster_launch.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/device_kernel.h" + +using namespace cute; + +template // (N,K,P) +struct SharedStorage +{ + array_aligned> smem_A; + array_aligned> smem_B; + + uint64_t tma_barrier[size<2>(SmemLayoutA{})]; + uint64_t mma_barrier[size<2>(SmemLayoutA{})]; +}; + +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, CUTLASS_GRID_CONSTANT TmaA const tma_a, + TB const* B, CUTLASS_GRID_CONSTANT TmaB const tma_b, + TC * C, CStride dC, TiledMma mma, + Alpha alpha, Beta beta) +{ + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutA{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutB{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutA{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutB{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + auto [M, N, K] = shape_MNK; + Tensor mA = tma_a.get_tma_tensor(make_shape(M,K)); // (M,K) TMA Tensor + Tensor mB = tma_b.get_tma_tensor(make_shape(N,K)); // (N,K) TMA Tensor + Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory tensors + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Partition the copying of A and B tiles + // + // TUTORIAL: + // These are TMA partitionings, which have a dedicated custom partitioner. + // The Int<0>, Layout<_1> indicates that the TMAs are not multicasted. + // Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host. + // The group_modes<0,2> transforms the (X,Y,Z)-shaped tensors into ((X,Y),Z)-shaped tensors + // with the understanding that the TMA is responsible for everything in mode-0. + // The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info. + // + + auto [tAgA, tAsA] = tma_partition(tma_a, Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sA), group_modes<0,2>(gA)); // (TMA,k) and (TMA,PIPE) + + auto [tBgB, tBsB] = tma_partition(tma_b, Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sB), group_modes<0,2>(gB)); // (TMA,k) and (TMA,PIPE) + + // The TMA is responsible for copying everything in mode-0 of tAsA and tBsB + constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) + + CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB); + + // + // PREFETCH + // + + auto K_PIPE_MAX = size<1>(tAsA); + + // Total count of tiles + int k_tile_count = size<1>(tAgA); + // Current tile index in gmem to read from + int k_tile = 0; + + // Initialize Barriers + int warp_idx = cutlass::canonical_warp_idx_sync(); + int lane_predicate = cute::elect_one_sync(); + uint64_t* producer_mbar = smem.tma_barrier; + uint64_t* consumer_mbar = smem.mma_barrier; + + using ProducerBarType = cutlass::arch::ClusterTransactionBarrier; // TMA + using ConsumerBarType = cutlass::arch::ClusterBarrier; // MMA + CUTE_UNROLL + for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) { + if ((warp_idx == 0) && lane_predicate) { + ProducerBarType::init(&producer_mbar[pipe], 1); + ConsumerBarType::init(&consumer_mbar[pipe], 128); + } + } + // Ensure barrier init is complete on all CTAs + cluster_sync(); + + // Start async loads for all pipes + CUTE_UNROLL + for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) + { + if ((warp_idx == 0) && lane_predicate) + { + // Set expected Tx Bytes after each reset / init + ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes); + copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe)); + copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe)); + } + --k_tile_count; + ++k_tile; + } + + // + // Define A/B partitioning and C accumulators + // + // TUTORIAL: + // The tCrA and tCrB are actually Tensors of MMA Descriptors constructed as views of SMEM. + // The MMA Descriptor generation is automatic via inspection and validation of the SMEM Layouts. + // Because the MMA reads directly from SMEM and the fragments are descriptors rather than registers, + // there is no need for copy(tCsA, tCrA) in the mainloop. + // + + ThrMMA thr_mma = mma.get_thread_slice(threadIdx.x); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + + // Allocate accumulators and clear them + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + clear(tCrC); + + // Allocate "fragments" + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // PIPELINED MAIN LOOP + // + // TUTORIAL: + // Rather than interleaving the stages and instructions like in SM70 and SM80, + // the SM90 mainloops rely on explicit producer-consumer synchronization + // on the purely async instructions TMA and MMA. + // More advanced pipeline and warp-specialization strategies are available in CUTLASS mainloops. + // + + // A PipelineState is a circular pipe index [.index()] and a pipe phase [.phase()] + // that flips each cycle through K_PIPE_MAX. + auto write_state = cutlass::PipelineState(); // TMA writes + auto read_state = cutlass::PipelineState(); // MMA reads + + CUTE_NO_UNROLL + while (k_tile_count > -K_PIPE_MAX) + { + // Wait for Producer to complete + int read_pipe = read_state.index(); + ProducerBarType::wait(&producer_mbar[read_pipe], read_state.phase()); + + // MMAs to cover 1 K_TILE + warpgroup_arrive(); + gemm(mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC); // (V,M) x (V,N) => (V,M,N) + warpgroup_commit_batch(); + + // Wait for all MMAs in a K_TILE to complete + warpgroup_wait<0>(); + + // Notify that consumption is done + ConsumerBarType::arrive(&consumer_mbar[read_pipe]); + ++read_state; + + if ((warp_idx == 0) && lane_predicate) + { + int pipe = write_state.index(); + // Wait for Consumer to complete consumption + ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase()); + // Set expected Tx Bytes after each reset / init + ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes); + copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe)); + copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe)); + ++write_state; + } + --k_tile_count; + ++k_tile; + } + + // + // Epilogue (unpredicated) + // + + axpby(alpha, tCrC, beta, tCgC); +} + +// Setup params for an NT GEMM +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int< 3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(bN,bK,bP)); + + // Define the MMA + TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS{}); + + // Define the TMAs + // Create Global memory tensors for TMA inspection + Tensor mA = make_tensor(A, make_shape(M,K), dA); + Tensor mB = make_tensor(B, make_shape(N,K), dB); + + // Create TMA Atoms with the desired copy operation on the source and destination + Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK)); + Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK)); + + // + // Setup and Launch + // + + // Launch parameter setup + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(tiled_mma)); + dim3 dimCluster(2, 1, 1); + dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), + round_up(size(ceil_div(n, bN)), dimCluster.y)); + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + + void const* kernel_ptr = reinterpret_cast( + &gemm_device); + + CUTE_CHECK_ERROR(cudaFuncSetAttribute( + kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + // Kernel Launch + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr, + prob_shape, cta_tiler, + A, tmaA, + B, tmaB, + C, dC, tiled_mma, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +// Setup params for a TN GEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom{}, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom{}, make_shape(bN,bK,bP)); + + // Define the MMA + TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS{}); + + // Define the TMAs + // Create Global memory tensors for TMA inspection + Tensor mA = make_tensor(A, make_shape(M,K), dA); + Tensor mB = make_tensor(B, make_shape(N,K), dB); + + // Create TMA Atoms with the desired copy operation on the source and destination + Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK)); + Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK)); + + // + // Setup and Launch + // + + // Launch parameter setup + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(tiled_mma)); + dim3 dimCluster(2, 1, 1); + dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), + round_up(size(ceil_div(n, bN)), dimCluster.y)); + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + + void const* kernel_ptr = reinterpret_cast( + &gemm_device); + + CUTE_CHECK_ERROR(cudaFuncSetAttribute( + kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + // Kernel Launch + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr, + prob_shape, cta_tiler, + A, tmaA, + B, tmaB, + C, dC, tiled_mma, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + +int main(int argc, char** argv) +{ + + cudaDeviceProp props; + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major != 9) { + std::cout << "This example requires NVIDIA's Hopper Architecture GPU with compute capability 90a\n" << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + int m = 512; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 256; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 1024; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = cute::half_t; + using TB = cute::half_t; + using TC = cute::half_t; + using TI = cute::half_t; + + TI alpha = TI(1.0f); + TI beta = TI(0.0f); + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + // Initialize the tensors + for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1)); + for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1)); + for (int j = 0; j < m*n; ++j) h_C[j] = TC(0); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + +#else + + std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; + +} diff --git a/include/cute/algorithm/axpby.hpp b/include/cute/algorithm/axpby.hpp index df9605b770..339743f491 100644 --- a/include/cute/algorithm/axpby.hpp +++ b/include/cute/algorithm/axpby.hpp @@ -32,7 +32,7 @@ #include -#include +#include #include namespace cute diff --git a/include/cute/algorithm/clear.hpp b/include/cute/algorithm/clear.hpp index f738b35b61..1c7dd5a334 100644 --- a/include/cute/algorithm/clear.hpp +++ b/include/cute/algorithm/clear.hpp @@ -31,9 +31,7 @@ #pragma once #include - -#include - +#include #include namespace cute diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index 7873071084..b2be11717f 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -1,51 +1,117 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ +* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ #pragma once #include #include - #include -#include +#include #include namespace cute { +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst) +{ + auto N = size(src); + if (tid < N) { + uint32_t upper_bound = (N / NumThreads) * NumThreads; + CUTE_UNROLL + for (uint32_t i = 0; i < upper_bound; i += NumThreads) { // All in-bounds + dst[tid + i] = src[tid + i]; + } + if (N % NumThreads != 0) { // Likely static condition + uint32_t final_idx = tid + upper_bound; + if (final_idx < N) { // Final in-bounds + dst[final_idx] = src[final_idx]; + } + } + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst) +{ + return naive_cooperative_copy(tid, src, dst); +} + +// A heuristic to determine a "good" permutation of two tensors for later vectorization and thr-assignment +template +CUTE_HOST_DEVICE constexpr +auto +heuristic_permutation(Tensor const& a, + Tensor const& b) +{ + constexpr bool swizzleA = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + constexpr bool swizzleB = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + auto a_inv = right_inverse(get_nonswizzle_portion(a.layout())); + auto b_inv = right_inverse(get_nonswizzle_portion(b.layout())); + + constexpr uint8_t scoreA = (uint8_t(swizzleA) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(a_inv) > size(b_inv)) << 0); + + constexpr uint8_t scoreB = (uint8_t(swizzleB) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(b_inv) > size(a_inv)) << 0); + + if constexpr (scoreA >= scoreB) { + return a_inv; + } else { + return b_inv; + } +} + // cooperative_copy(thr_idx, src, dst) -// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits. +// Use NumThreads to copy Tensor src to Tensor dst with element-wise vectorization up to MaxVecBits. // @pre 0 <= @a tid < NumThreads // @pre Tensors @a src and @a dst are aligned up to MaxVecBits. +// That is, pointers and dynamic strides are assumed to be aligned up to MaxVecBits. // template const& src, Tensor & dst) { - // Assumes the shapes are static, can generalize + // Assumes the shapes are static, can generalize/fallback + CUTE_STATIC_ASSERT_V(is_static{} && is_static{}); CUTE_STATIC_ASSERT_V(size(src) == size(dst)); - // Assumes the types are the same, can generalize - static_assert(sizeof_bits_v == sizeof_bits_v); + // Assumes the types are the same, can generalize/fallback + static_assert(cute::is_same::value); static_assert(MaxVecBits == sizeof_bits_v || MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, "Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance."); // Check that the tensors are likely shared across threads: either gmem or smem static_assert((is_gmem::value || is_smem::value), - "cooperative_copy expects shared gmem or smem source tensor."); + "cooperative_copy expects shared gmem or smem source tensor."); static_assert((is_gmem::value || is_smem::value), - "cooperative_copy expects shared gmem or smem destination tensor."); - + "cooperative_copy expects shared gmem or smem destination tensor."); // Precondition on tid in DEBUG assert(tid < NumThreads); + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); - // Fallback - slow path, naive copy, vectorization disabled - if constexpr(size(SrcLayout{}) % NumThreads != 0) { - int index = static_cast(tid); - CUTE_UNROLL - for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) { - if(index < size(SrcLayout{})) { - dst[index] = src[index]; +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy\n"); + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); } - index += NumThreads; - } - } else { - // Fast path with vectorization +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif - // Precondition on pointer alignment in DEBUG - assert(is_byte_aligned(raw_pointer_cast(src.data()))); - assert(is_byte_aligned(raw_pointer_cast(dst.data()))); - constexpr int elem_bits = sizeof_bits_v; + // The common layout of the two tensors that can be vectorized over elements and threads + // vidx -> coord + auto common_layout = heuristic_permutation(src, dst); - // - // Determine val+thr vectorization based on src/dst size and number of threads - // NOTE: This heuristic promotes parallelization over vectorization - // + // Apply + // (V, rest) + Tensor src_a = coalesce(logical_divide(src, common_layout), Shape<_1,_1>{}); + Tensor dst_a = coalesce(logical_divide(dst, common_layout), Shape<_1,_1>{}); - // The number of elements that can be vectorized in values - constexpr int common_elem = decltype(max_common_vector(src, dst))::value; - constexpr int common_bits = common_elem * elem_bits; - constexpr int total_elem = decltype(size(src))::value; - constexpr int total_bits = total_elem * elem_bits; - static_assert(total_bits % NumThreads == 0); - constexpr int total_bits_per_thr = total_bits / NumThreads; - // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits - constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr); - - // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits - constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast(MaxVecBits)); - // Convert back to number of elements, safe_div - static_assert((vec_bits % elem_bits) == 0); - constexpr int vec_elem = vec_bits / elem_bits; + // + // Determine vectorization of elems and thrs based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // - // Use only part of threads if there's not enough work for all threads - constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0) - ? NumThreads - : (total_elem / vec_elem); - static_assert(vec_thrs <= NumThreads); + // The number of elements and number of bits + constexpr int elem_bits = sizeof_bits_v; + constexpr int total_elem = size(SrcLayout{}); - // The common layout of the two tensors that can be vectorized over threads - // vidx -> coord - auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()), - get_nonswizzle_portion(dst.layout())); + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src_a, dst_a))::value; - // Scale up the common_layout to cover the entire tensors - // vidx -> coord - auto full_perm = tile_to_shape(make_layout(common_layout), size(src)); +#if 0 + if (thread0()) { + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("src_a: "); print(src_a); print("\n"); + print(" "); print("dst_a: "); print(dst_a); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif - // Create the Tiler - // ((vid,tid),iter) - auto layout_vt = logical_divide(full_perm, Layout, Int>>{}); + // + if constexpr (total_elem % NumThreads != 0) { + // Not attempting to find a partitioning pattern, fallback to dynamically indexed slowpath - // Apply and slice - Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_); - Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_); + if constexpr (common_elem > 1 && MaxVecBits > elem_bits) { + // If the vectorization is non-trivial and divides the maximum vectorizations, then vectorize + constexpr auto max_align_src = elem_bits * decltype(max_alignment(src_a.layout()))::value; + constexpr auto max_align_dst = elem_bits * decltype(max_alignment(dst_a.layout()))::value; + constexpr auto vec_bits = gcd(max_align_src, max_align_dst, MaxVecBits); + using VecType = uint_bit_t; + + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert((vec_bits >= 8), "No support for subbyte copying"); + + Tensor src_v = recast(src_a); + Tensor dst_v = recast(dst_a); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- naive\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + naive_cooperative_copy(tid, src_v, dst_v); + } else { + naive_cooperative_copy(tid, src_a, dst_a); + } + } else { + // If the tensors can be equally partitioned by the threads, + // compute vectorization widths in elements and threads. + + // If there are too many threads to allow a full vectorized copy, trunc the vectorization + constexpr int total_bits = total_elem * elem_bits; + constexpr int max_bits_per_thr = total_bits / NumThreads; + // At least elem_bits, at most common_bits + constexpr int common_bits = common_elem * elem_bits; + constexpr int vec_bits = cute::max(elem_bits, cute::gcd(common_bits, int(MaxVecBits), max_bits_per_thr)); // Should account for vec_bits < 8 and/or vec_elem <= 1 // And also account for subbyte types, which could cause race conditions // Want to ENFORCE sufficient vectorization in those cases - static_assert((vec_bits >= 8), "No support for subbyte copying"); + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert(vec_bits >= 8, "No support for subbyte copying"); + using VecType = uint_bit_t; + constexpr int vec_elem = vec_bits / elem_bits; + + constexpr int vec_thrs = cute::min(int(NumThreads), total_elem / vec_elem); + + // + // Determine the partitioning patterns for the vec_elems and vec_thrs + // + + // Distribute the rest of the V*T to some consistent portion outside of the common_layout, if needed + auto common_domain_src = domain_distribute(shape(src_a), Int{}); + auto common_domain_dst = domain_distribute(shape(dst_a), Int{}); + + // Make sure for now, could fall back here instead + CUTE_STATIC_ASSERT_V(size(common_domain_src) == Int{}); + CUTE_STATIC_ASSERT_V(compatible(common_domain_src, common_domain_dst) || + compatible(common_domain_dst, common_domain_src)); + // Use the "more specific" domain for the extra elements of V*T + auto common_domain = conditional_return(compatible(common_domain_src, common_domain_dst), + common_domain_dst, common_domain_src); + + // Construct the tiler + auto tiler_vt = common_domain.with_shape(Int{}, Int{}); + + // Apply and slice + Tensor src_v = logical_divide(src_a, tiler_vt)(make_coord(_,tid),_); + Tensor dst_v = logical_divide(dst_a, tiler_vt)(make_coord(_,tid),_); #if 0 - if (thread0()) { - print(" "); print("cooperative_copy -- vec\n"); - print(" "); print("NumThreads: "); print(NumThreads); print("\n"); - print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); - print(" "); print("src: "); print(src); print("\n"); - print(" "); print("dst: "); print(dst); print("\n"); - print(" "); print("common_layout: "); print(common_layout); print("\n"); - print(" "); print("full_perm: "); print(full_perm); print("\n"); - print(" "); print("Used vector: "); print(vec_elem); print("\n"); - print(" "); print("Used threads: "); print(vec_thrs); print("\n"); - print(" "); print("layout_vt: "); print(layout_vt); print("\n"); - print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n"); - print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n"); - print(" "); print("src_v: "); print(src_v); print("\n"); - print(" "); print("dst_v: "); print(dst_v); print("\n"); - print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); - print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); - } + if (thread0()) { + print(" "); print("cooperative_copy -- vec\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("tiler_vt: "); print(tiler_vt); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } #ifdef __CUDA_ARCH__ - __syncthreads(); + __syncthreads(); #endif #endif - // If we're using all threads (static) or the tid is in in-range (dynamic) - if (vec_thrs >= NumThreads or tid < vec_thrs) { + // If we're using all threads (static) or the tid is in-range (dynamic) + if (vec_thrs == NumThreads or tid < vec_thrs) { return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); } } } +// Default max-vectorization size to value_type size template @@ -184,7 +300,10 @@ cooperative_copy(uint32_t const& tid, return cooperative_copy(tid, src, dst); } +// // Accept mutable temporaries +// + template @@ -197,9 +316,7 @@ cooperative_copy(uint32_t const& tid, return cooperative_copy(tid, src, dst); } -// Accept mutable temporaries -template CUTE_HOST_DEVICE diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index b83881590b..da03bfbd11 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -39,7 +39,7 @@ #include #include -#include +#include namespace cute { @@ -76,29 +76,15 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, using TypeB = typename TB::value_type; using TypeC = typename TC::value_type; - // Original, static size of the problem - auto M = size<0>(sC); - auto N = size<1>(sC); - auto K = size<1>(sA); - - // Block size of the compute tile - auto BLK_M = tile_size<0>(thr_mma); - auto BLK_N = tile_size<1>(thr_mma); - auto BLK_K = tile_size<2>(thr_mma); - // // MMA Partitioning // - // Round the layout extents up to BLK_X to satisfy MMA partitioning safety - Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K))); - Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K))); - Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N))); + // Partition the sA, sB, and sC tiles across the threads for the MMA + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) - // Partition the sA and sB tiles across the threads for the MMA - Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) - Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K) - Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N) // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) @@ -109,9 +95,6 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, print(" sA: "); print( sA); print("\n"); print(" sB: "); print( sB); print("\n"); print(" sC: "); print( sC); print("\n"); - print("r_sA: "); print(rounded_sA); print("\n"); - print("r_sB: "); print(rounded_sB); print("\n"); - print("r_sC: "); print(rounded_sC); print("\n"); print(thr_mma); print("tCsA: "); print(tCsA); print("\n"); print("tCsB: "); print(tCsB); print("\n"); @@ -127,8 +110,8 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, // // Create coordinate tensors for the problem - Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k) - Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k) + Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k) + Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k) // Repeat partitioning with thr_mma Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k) @@ -222,7 +205,7 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, // // Create coordinate tensors for the problem - Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n) + Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) // Repeat partitioning with thr_mma Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 50a092d02e..2a37995eea 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -34,7 +34,7 @@ #include -#include +#include #include #include @@ -199,14 +199,14 @@ copy_vec(Tensor const& src, { static_assert(sizeof_bits_v >= 8 && sizeof_bits_v % 8 == 0, "Expected a vectorization type of at least a byte."); - using SrcType = typename SrcEngine::element_type; - using DstType = typename DstEngine::element_type; - if constexpr (sizeof_bits_v == sizeof_bits_v && + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + if constexpr (cute::is_same::value && sizeof_bits_v > sizeof_bits_v) { // Preserve volatility of Src/Dst types. - using SrcVecType = conditional_t, VecType const volatile, VecType const>; - using DstVecType = conditional_t, VecType volatile, VecType >; + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; Tensor src_v = recast(src); Tensor dst_v = recast(dst); @@ -264,22 +264,22 @@ copy(AutoVectorizingCopyWithAssumedAlignment const&, { constexpr int vec_elem = decltype(max_common_vector(src, dst))::value; + constexpr int max_align_src = decltype(max_alignment(src.layout()))::value; + constexpr int max_align_dst = decltype(max_alignment(dst.layout()))::value; + constexpr int max_align = gcd(vec_elem, max_align_src, max_align_dst); + constexpr int src_bits = sizeof_bits::value; - // When layouts are static, accept vec_bits up to 128 - // When layouts are dynamic, accept vec_bits up to MaxVecBits - constexpr int vec_bits = (is_static::value && is_static::value) ? - cute::min(vec_elem * src_bits, 128) : - cute::min(vec_elem * src_bits, MaxVecBits); + constexpr int vec_bits = gcd(src_bits * max_align, MaxVecBits); + if constexpr (vec_elem > 1 && vec_bits >= 8) { + // If more than one element vectorizes to 8bits or more, then copy_vec #if 0 - if (thread0()) { - print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits); - print(" "); print(src); print("\n"); - print(" "); print(dst); print("\n"); - } + if (thread0()) { + print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits); + print(" "); print(src); print("\n"); + print(" "); print(dst); print("\n"); + } #endif - - if constexpr (vec_elem > 1 && vec_bits >= 8) { return copy_vec>(src, dst); } else { return copy_if(TrivialPredTensor{}, src, dst); @@ -294,10 +294,16 @@ void copy(Tensor const& src, Tensor & dst) { - return copy(AutoVectorizingCopy{}, src, dst); + if constexpr (is_static::value && is_static::value) { + // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + } else { + // Do not assume that dynamic layouts are aligned. + return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst); + } } -// Auto-vectorizing copy with assumed alignment of dynamic layout strides up to 128bit. +// Auto-vectorizing copy with assumed alignment up to 128bit. template CUTE_HOST_DEVICE @@ -308,19 +314,6 @@ copy_aligned(Tensor const& src, return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); } -// Specializaton for Atom AutoVectorizingCopy -template -CUTE_HOST_DEVICE -void -copy(Copy_Atom const&, - Tensor const& src, - Tensor & dst) -{ - return copy(AutoVectorizingCopy{}, src, dst); -} - // Specializaton for Atom AutoVectorizingCopyAssumedAlignment template const& atom, // Copy_Traits m { using SrcType = typename SrcEngine::value_type; using DstType = typename DstEngine::value_type; - static_assert(sizeof_bits::value == sizeof_bits::value); + static_assert(cute::is_same::value); static_assert((is_gmem::value && is_smem::value) || (is_smem::value && is_gmem::value), "Bulk Copy only supports gmem -> smem or smem -> gmem movement."); diff --git a/include/cute/algorithm/fill.hpp b/include/cute/algorithm/fill.hpp index 5206065107..3f33a42ade 100644 --- a/include/cute/algorithm/fill.hpp +++ b/include/cute/algorithm/fill.hpp @@ -32,7 +32,7 @@ #include -#include +#include #include namespace cute diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp index 27c322168a..c4713838b6 100644 --- a/include/cute/algorithm/gemm.hpp +++ b/include/cute/algorithm/gemm.hpp @@ -35,7 +35,7 @@ #include #include -#include +#include #include diff --git a/include/cute/algorithm/prefetch.hpp b/include/cute/algorithm/prefetch.hpp index 47aefa87bb..0d638ab58f 100644 --- a/include/cute/algorithm/prefetch.hpp +++ b/include/cute/algorithm/prefetch.hpp @@ -32,7 +32,7 @@ #include -#include +#include #include @@ -90,12 +90,6 @@ constexpr bool has_prefetch = false; template constexpr bool has_prefetch> = true; -template -constexpr bool is_prefetch = false; - -template -constexpr bool is_prefetch> = is_same_v; - } // end namespace detail template - -#include +#include namespace cute { @@ -100,13 +99,13 @@ transform(Tensor&& tensor, UnaryOp&& op) } // Similar to std::transform transforms one tensors and assigns it to another -template CUTE_HOST_DEVICE constexpr void -transform(Tensor const& tensor_in, - Tensor & tensor_out, +transform(Tensor const& tensor_in, + Tensor & tensor_out, UnaryOp&& op) { CUTE_UNROLL @@ -117,30 +116,30 @@ transform(Tensor const& tensor_in, // Accept mutable temporaries template CUTE_HOST_DEVICE constexpr void -transform(Tensor const& tensor_in, - Tensor && tensor_out, +transform(Tensor const& tensor_in, + Tensor && tensor_out, UnaryOp&& op) { return transform(tensor_in, tensor_out, op); } // Similar to std::transform with a binary operation -// Takes two tensors as input and one tensor as output. +// Takes two tensors as input and one tensor as output. // Applies the binary_op to tensor_in1 and tensor_in2 and // assigns it to tensor_out template CUTE_HOST_DEVICE constexpr void transform(Tensor const& tensor_in1, Tensor const& tensor_in2, - Tensor & tensor_out, + Tensor & tensor_out, BinaryOp&& op) { CUTE_UNROLL @@ -152,11 +151,11 @@ transform(Tensor const& tensor_in1, // Accept mutable temporaries template CUTE_HOST_DEVICE constexpr void -transform(Tensor const& tensor_in1, +transform(Tensor const& tensor_in1, Tensor const& tensor_in2, Tensor && tensor_out, BinaryOp&& op) diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index 3157e89719..616960a54a 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -404,29 +404,54 @@ namespace detail { // This impl compiles much faster than cute::apply and variadic args template CUTE_HOST_DEVICE constexpr -decltype(auto) -fold(T&& t, V&& v, F&& f, seq<>) +auto +fold(T&&, V&& v, F&&, seq<>) { - return static_cast(v); + return v; } -template +template CUTE_HOST_DEVICE constexpr -decltype(auto) -fold(T&& t, V&& v, F&& f, seq) +auto +fold(T&& t, V&& v, F&& f, seq) { - if constexpr (sizeof...(Is) == 0) { - return f(static_cast(v), get(static_cast(t))); - } else { - return fold(static_cast(t), - f(static_cast(v), get(static_cast(t))), - f, - seq{}); - } + return f(static_cast(v), get(static_cast(t))); +} - CUTE_GCC_UNREACHABLE; +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V&& v, F&& f, seq) +{ + return f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V&& v, F&& f, seq) +{ + return f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V&& v, F&& f, seq) +{ + return f(f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))); } +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V&& v, F&& f, seq) +{ + return fold(static_cast(t), + f(f(f(f(static_cast(v), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), get(static_cast(t))), + f, + seq{}); +} } // end namespace detail template @@ -448,7 +473,7 @@ fold(T&& t, V&& v, F&& f) template CUTE_HOST_DEVICE constexpr -decltype(auto) +auto fold_first(T&& t, F&& f) { if constexpr (is_tuple>::value) { @@ -457,7 +482,7 @@ fold_first(T&& t, F&& f) f, make_range<1,tuple_size>::value>{}); } else { - return static_cast(t); + return t; } CUTE_GCC_UNREACHABLE; @@ -701,7 +726,14 @@ CUTE_HOST_DEVICE constexpr auto replace(T const& t, X const& x) { - return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); + } else { + static_assert(N == 0); + return x; + } + + CUTE_GCC_UNREACHABLE; } // Replace the first element of the tuple with x @@ -1077,9 +1109,9 @@ zip2_by(T const& t, TG const& guide) /// @return A tuple of the elements of @c t in reverse order. template -CUTE_HOST_DEVICE constexpr +CUTE_HOST_DEVICE constexpr auto -reverse(T const& t) +reverse(T const& t) { if constexpr (is_tuple::value) { return detail::apply(t, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_rseq{}); diff --git a/include/cute/arch/copy.hpp b/include/cute/arch/copy.hpp index b85e6a2002..5139289995 100644 --- a/include/cute/arch/copy.hpp +++ b/include/cute/arch/copy.hpp @@ -68,7 +68,7 @@ struct UniversalCopy // // Placeholder for the copy algorithm's stronger auto-vectorizing behavior -// that assumes alignment of dynamic layouts up to MaxVecBits +// that assumes alignment of pointers and dynamic layouts up to MaxVecBits // template @@ -80,15 +80,17 @@ struct AutoVectorizingCopyWithAssumedAlignment }; // -// Placeholder for the copy algorithm's default auto-vectorizing behavior -// that does not assume alignment of dynamic layouts +// AutoVectorizingCopy alias assumes maximal alignment of pointers and dynamic strides. +// If this is not the case then AutoVectorizingCopyWithAssumedAlignment should be used instead // -using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<8>; +using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; -// Alias -using DefaultCopy = AutoVectorizingCopy; +// +// DefaultCopy alias does not assume alignment of pointers or dynamic strides. +// +using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>; // // Global memory prefetch into L2 diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 856d4dd5dd..21e473ede9 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -95,8 +95,8 @@ wait_barrier(uint64_t& smem_barrier, // 64 bits user-mange ".reg .pred P1;\n" "LAB_WAIT:\n" "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" - "@P1 bra.uni DONE;\n" - "bra.uni LAB_WAIT;\n" + "@P1 bra DONE;\n" + "bra LAB_WAIT;\n" "DONE:\n" "}\n" :: "r"(smem_int_ptr), @@ -134,6 +134,48 @@ enum class SmemSwizzleBits : uint8_t { B128 = 3, }; +enum class OOBFill : uint8_t { + ZERO = 0, + CONSTANT = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(OOBFill const& t) { + switch (t) { + case OOBFill::ZERO: return "ZERO"; + case OOBFill::CONSTANT: return "CONSTANT"; + } + return nullptr; +} + +enum class L2Promotion : uint8_t { + DISABLE = 0, + B64 = 1, + B128 = 2, + B256 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(L2Promotion const& t) { + switch (t) { + case L2Promotion::DISABLE: return "DISABLE"; + case L2Promotion::B64: return "B64"; + case L2Promotion::B128: return "B128"; + case L2Promotion::B256: return "B256"; + } + return nullptr; +} + +// Aux parameters which are independent with the problem size +struct DescriptorAuxParams { + OOBFill oobfill_ = OOBFill::ZERO; + L2Promotion l2promo_ = L2Promotion::DISABLE; +}; + +enum class CacheHintSm90 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + #if (__CUDACC_VER_MAJOR__ >= 12) #if !defined(__CUDACC_RTC__) @@ -168,6 +210,27 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; } } + +inline CUtensorMapFloatOOBfill +to_CUtensorMapFloatOOBfill(OOBFill const& t) { + switch(t) { + default: assert(false && "Unknown OOBFill!"); + case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + } +} + +inline CUtensorMapL2promotion +to_CUtensorMapL2promotion(L2Promotion const& t) { + switch(t) { + default: assert(false && "Unknown L2Promotion!"); + case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE; + case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B; + case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + case L2Promotion::B256: return CU_TENSOR_MAP_L2_PROMOTION_L2_256B; + } +} + #endif // !defined(__CUDACC_RTC__) #endif // (__CUDACC_VER_MAJOR__ >= 12) @@ -257,22 +320,32 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor asm volatile ( "cvt.u64.u32 %0, %1;" :: "l"(smem_int64_desc), "r"(smem_int_desc)); - asm volatile ( - "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" - :: "l"(smem_int64_desc), "r"(prob_shape[0])); - asm volatile ( - "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" - :: "l"(smem_int64_desc), "r"(prob_shape[1])); - asm volatile ( - "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" - :: "l"(smem_int64_desc), "r"(prob_shape[2])); - // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 - asm volatile ( - "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" - :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); - asm volatile ( - "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" - :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[0])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[1])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[2])); + // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) + // 4 LSBs are not included + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2])); + #else + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); + #endif #else CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); #endif diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index 1136c43359..1851482119 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -44,7 +44,7 @@ namespace cute struct SM90_TMA_LOAD_1D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { @@ -53,11 +53,11 @@ struct SM90_TMA_LOAD_1D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( - "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes" - " [%0], [%1, {%3}], [%2];" + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), - "r"(crd0) + "r"(crd0), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -89,7 +89,7 @@ struct SM90_TMA_LOAD_1D struct SM90_TMA_LOAD_2D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { @@ -98,11 +98,11 @@ struct SM90_TMA_LOAD_2D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( - "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes" - " [%0], [%1, {%3, %4}], [%2];" + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), - "r"(crd0), "r"(crd1) + "r"(crd0), "r"(crd1), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -134,7 +134,7 @@ struct SM90_TMA_LOAD_2D struct SM90_TMA_LOAD_3D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { @@ -143,11 +143,11 @@ struct SM90_TMA_LOAD_3D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( - "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes" - " [%0], [%1, {%3, %4, %5}], [%2];" + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), - "r"(crd0), "r"(crd1), "r"(crd2) + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -179,7 +179,7 @@ struct SM90_TMA_LOAD_3D struct SM90_TMA_LOAD_4D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { @@ -188,11 +188,11 @@ struct SM90_TMA_LOAD_4D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( - "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes" - " [%0], [%1, {%3, %4, %5, %6}], [%2];" + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), - "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -224,7 +224,7 @@ struct SM90_TMA_LOAD_4D struct SM90_TMA_LOAD_5D { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { @@ -233,11 +233,11 @@ struct SM90_TMA_LOAD_5D uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); asm volatile ( - "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes" - " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];" + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), - "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) : "memory"); #else CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); @@ -269,39 +269,39 @@ struct SM90_TMA_LOAD_5D struct SM90_TMA_LOAD { CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0) { - return SM90_TMA_LOAD_1D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0); + return SM90_TMA_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1) { - return SM90_TMA_LOAD_2D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1); + return SM90_TMA_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { - return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2); + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { - return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2, crd3); + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); } CUTE_HOST_DEVICE static void - copy(void const* desc_ptr, uint64_t* mbar_ptr, + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, void * smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { - return SM90_TMA_LOAD_5D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + return SM90_TMA_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); } struct PREFETCH diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index a6cb194360..1d6caba89d 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -85,7 +85,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { union GmmaDescriptor { - CUTE_HOST_DEVICE constexpr GmmaDescriptor() noexcept : desc_(0) {} CUTE_HOST_DEVICE constexpr @@ -135,21 +134,22 @@ union GmmaDescriptor // Decay to a uint64_t CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { return desc_; } - - // Printer - CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t) - { - #if !defined(__CUDACC_RTC__) - printf("GmmaDescriptor: 0x%016llx\n", static_cast(t.desc_)); - printf(" start_addr : 0x%04x\n", t.bitfield.start_address_); - printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); - printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); - printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); - printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); - #endif - } }; +// Printer +CUTE_HOST_DEVICE void +print(GmmaDescriptor const& t) +{ +#if !defined(__CUDACC_RTC__) + printf("GmmaDescriptor: 0x%016llx\n", static_cast(t.desc_)); + printf(" start_addr : 0x%04x\n", t.bitfield.start_address_); + printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); + printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); + printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); +#endif // !defined(__CUDACC_RTC__) +} + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 92e342510a..61417d8360 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -235,24 +235,25 @@ explode(Fn fn, } template + class PtrD, int... Id, + class PtrA, int... Ia, + class PtrB, int... Ib, + class PtrC, int... Ic, + class PtrE, int... Ie, + class PtrF, int... If> CUTE_HOST_DEVICE constexpr void explode(Fn fn, - PtrD&& d, int_sequence, - PtrA&& a, int_sequence, - PtrB&& b, int_sequence, - PtrC&& c, int_sequence, - PtrSFA&& sfa, int_sequence, - PtrSFB&& sfb, int_sequence) + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence) { - return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., sfa[Isfa]..., sfb[Isfb]...); + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); } + // // Utility for exploding tuples into functions // diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 48a5fd168b..20a0627627 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -39,7 +39,7 @@ #include -#include +#include namespace cute { diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index 2aa3ba5774..bfbeb4ea51 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -32,7 +32,7 @@ #include -#include +#include namespace cute { @@ -145,4 +145,15 @@ copy_unpack(Copy_Traits const& traits, copy_unpack(traits, src, dst); } +namespace detail { + +template +constexpr bool is_prefetch = false; + +template +constexpr bool is_prefetch> = is_same_v; + +} // end namespace detail + + } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index 15d9979c92..f6c9e258eb 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -39,7 +39,7 @@ #include "cute/tensor.hpp" #include "cute/algorithm/prefetch.hpp" - +#include "cutlass/fast_math.h" namespace cute { @@ -388,18 +388,19 @@ template const& tensor_cwhdn, // (C,W,H,D,N) - uint32_t range_c, // TILE_C - uint32_t range_whdn, // TILE_WHDN - SmemSwizzle const& smem_swizzle, // Swizzle - TMALayout const& tma_layout_vt, // TMA layout - LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer" - UpperCornerStride const& upper_corner_whd, // WHD upper corner - LowerPaddingStride const& lower_padding_whd, // WHD lower padding - UpperPaddingStride const& upper_padding_whd, // WHD upper padding - TraversalStride const& stride_whd, // WHD traversal stride - LowerSRTStride const& lower_srt, // SRT offset of the "base pointer" - DilationStride const& stride_srt) // SRT stride - dilation + Tensor const& tensor_cwhdn, // (C,W,H,D,N) + uint32_t range_c, // TILE_C + uint32_t range_whdn, // TILE_WHDN + SmemSwizzle const& smem_swizzle, // Swizzle + TMALayout const& tma_layout_vt, // TMA layout + LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer" + UpperCornerStride const& upper_corner_whd, // WHD upper corner + LowerPaddingStride const& lower_padding_whd, // WHD lower padding + UpperPaddingStride const& upper_padding_whd, // WHD upper padding + TraversalStride const& stride_whd, // WHD traversal stride + LowerSRTStride const& lower_srt, // SRT offset of the "base pointer" + DilationStride const& stride_srt, // SRT stride - dilation + TMA::DescriptorAuxParams const& aux_params = {}) { static_assert(is_gmem::value, "Tensor must point to GPU global memory."); using value_type = typename EngineA::value_type; @@ -445,8 +446,8 @@ make_im2col_tma_copy_desc( CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; - CUtensorMapFloatOOBfill tma_oob_fill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + CUtensorMapL2promotion tma_l2Promotion = to_CUtensorMapL2promotion(aux_params.l2promo_); + CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(detail::get_tma_swizzle_bits(smem_swizzle)); CUresult encode_result = cuTensorMapEncodeIm2col( @@ -498,7 +499,11 @@ make_im2col_tma_copy_desc( // For fprop/dgrad kernel, gemm_shapes is ((q, p, z, n), (c, s, r, t)) // For wgrad kernel, gemm_shapes is ((c, s, r, t), (q, p, z, n)) - auto gemm_shapes_common = make_shape(gemm_mn, gemm_k); + auto gemm_shapes_common = make_shape( + transform_leaf(gemm_mn, [](auto s) { + return conditional_return(cute::is_static{}, s, cutlass::FastDivmod(s)); + }), + gemm_k); auto gemm_shapes = make_shape( basis_get(stride<0,1>(tma_layout_vt), gemm_shapes_common), basis_get(stride<0,0>(tma_layout_vt), gemm_shapes_common)); @@ -554,17 +559,18 @@ template const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) - SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled - int32_t const& num_multicast, // The number of CTAs involved in multicasting - Layout const& cta_v_map, // V: CTA val idx -> gmem mode - LowerCornerStride const& lower_corner_whd, - UpperCornerStride const& upper_corner_whd, - LowerPaddingStride const& lower_padding_whd, - UpperPaddingStride const& upper_padding_whd, - TraversalStride const& stride_whd, // traversal stride - LowerSRTStride const& lower_srt, - DilationStride const& stride_srt) // dilation + Tensor const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + int32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map, // V: CTA val idx -> gmem mode + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) { // // TMA parameter checking @@ -645,7 +651,8 @@ make_tma_atom_im2col(CopyOp, upper_padding_whd, stride_whd, lower_srt, - stride_srt); + stride_srt, + aux_params); // // Construct the Copy_Traits @@ -697,18 +704,19 @@ template CUTE_HOST_RTC auto -make_tma_copy_im2col(CopyOp const& copy_op, - Tensor const& gtensor, - SLayout const& slayout, - Layout const& cta_t_map, // CTA tid -> logical TMA tid - Layout const& cta_v_map, // CTA vid -> gmem coord - LowerCornerStride const& lower_corner_whd, - UpperCornerStride const& upper_corner_whd, - LowerPaddingStride const& lower_padding_whd, - UpperPaddingStride const& upper_padding_whd, - TraversalStride const& stride_whd, // traversal stride - LowerSRTStride const& lower_srt, - DilationStride const& stride_srt) // dilation +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map, // CTA vid -> gmem coord + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) { // // TMA parameter checking @@ -719,7 +727,7 @@ make_tma_copy_im2col(CopyOp const& copy_op, Copy_Atom atom = make_tma_atom_im2col(copy_op, gtensor, slayout, cosize(cta_t_map), cta_v_map, lower_corner_whd, upper_corner_whd, lower_padding_whd, - upper_padding_whd, stride_whd, lower_srt, stride_srt); + upper_padding_whd, stride_whd, lower_srt, stride_srt, aux_params); // // Construct the TiledCopy diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index d42c82c915..950855a1a2 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -124,17 +124,24 @@ struct Copy_Traits // Construct an executable SM90_TMA_LOAD with tma_mbar CUTE_HOST_DEVICE constexpr Copy_Traits - with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent - return {{}, {&tma_desc_, &tma_mbar}}; + return {{}, {&tma_desc_, &tma_mbar, static_cast(cache_hint)}}; } // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) CUTE_HOST_DEVICE constexpr Copy_Traits - with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent - return {{}, {new_tma_desc, &tma_mbar}}; + return {{}, {new_tma_desc, &tma_mbar, static_cast(cache_hint)}}; } // Generate the TMA coord tensor @@ -171,7 +178,8 @@ struct Copy_Traits // SM90_TMA_LOAD arguments tuple< TmaDescriptor const*, - uint64_t* // smem mbarrier + uint64_t*, // smem mbarrier + uint64_t // cache hint > const opargs_; }; @@ -286,6 +294,38 @@ struct Copy_Traits ///////////////////////////// TMA_STORE ////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// +// Utility for unpacking TMA_STORE arguments into a CopyOp +template +struct TMA_STORE_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + + void const* const desc_ptr = traits.tma_desc_; + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +struct SM90_TMA_STORE_OP : SM90_TMA_STORE {}; + // The executable SM90_TMA_STORE with tma_desc template struct Copy_Traits @@ -343,6 +383,30 @@ struct Copy_Traits make_tuple(desc_ptr, src_ptr), seq<0,1>{}, dst_coord, tuple_seq{}); } + + // Construct Copy_Traits executable (w/ swapped out TMA descriptor) for SM90_TMA_STORE (for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {{}, new_tma_desc}; + } +}; + +// The executable SM90_TMA_STORE with tma_desc +template +struct Copy_Traits + : TMA_STORE_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor const* tma_desc_; }; ////////////////////////////////////////////////////////////////////////////// @@ -1240,14 +1304,14 @@ template + class Cluster_Size = Int<1>> CUTE_HOST_RTC auto make_tma_atom(CopyOp const& copy_op, Tensor const& gtensor, SLayout const& slayout, CTA_Tiler const& cta_tiler, - Cluster_Size const& cluster_size) + Cluster_Size const& cluster_size = {}) { auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); // Prefer TmaInternalType if specified. Fallback to GEngine::value_type @@ -1283,8 +1347,8 @@ tma_partition(Copy_Atom const& copy_atom, auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); // Append with _ until we cover all Rest... modes - auto glayout_V = append>(layout_V, _); - auto slayout_V = append>(layout_V, _); + auto glayout_V = append(layout_V, _); + auto slayout_V = append(layout_V, _); // Transform tile mode and coalesce Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) @@ -1304,8 +1368,8 @@ tma_partition(Copy_Atom const& copy_atom, // Offset inside the TMA-mode for the multicast auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); - auto scoord = append(multicast_coord, Int<0>{}); auto gcoord = append(multicast_coord, Int<0>{}); + auto scoord = append(multicast_coord, Int<0>{}); Tensor gresult = domain_offset(gcoord, gtensor_v); Tensor sresult = domain_offset(scoord, stensor_v); @@ -1332,4 +1396,116 @@ create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, return mcast_mask; } +//////////////////////////////////// +// Make TMA copy A/B/C +/////////////////////////////////// + +template +CUTE_HOST_RTC +auto +make_tma_copy_A_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only MK modes from MNK + auto cta_tiler_mk = remove<1>(cta_tiler); + + // mcast along N mode for this M load, if any + auto cluster_size_n = size<1>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mk, + cluster_size_n); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mk); + auto cta_t_tile = make_layout(cluster_size_n); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_B_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only NK modes from MNK + auto cta_tiler_nk = remove<0>(cta_tiler); + + // mcast along M mode for this N load, if any + auto cluster_size_m = size<0>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_nk, + cluster_size_m); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_nk); + auto cta_t_tile = make_layout(cluster_size_m); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_C_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler) +{ + // Keep only MN modes from MNK + auto cta_tiler_mn = remove<2>(cta_tiler); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mn, + _1{}); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mn); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); + return tma_copy; + } +} } // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index d164a954a7..6dc826ef2f 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -31,11 +31,9 @@ #pragma once #include - #include - #include -#include +#include #include namespace cute { @@ -102,7 +100,7 @@ struct MMA_Atom> static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); - return mma_unpack(*this, D, A, B, C); + return mma_unpack(static_cast(*this), D, A, B, C); } // Three arguments reproduces C @@ -245,12 +243,9 @@ struct TiledMMA : MMA_Atom thrfrg_C(CTensor&& ctensor) const { CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); - //CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - //CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); - // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(get<0>(PermutationMNK{}), - get<1>(PermutationMNK{})); + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<1>()); auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) // Tile the tensor for the Atom @@ -287,12 +282,9 @@ struct TiledMMA : MMA_Atom thrfrg_A(ATensor&& atensor) const { CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); - //CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - //CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); - // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(get<0>(PermutationMNK{}), - get<2>(PermutationMNK{})); + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<2>()); auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) // Tile the tensor for the Atom @@ -329,12 +321,9 @@ struct TiledMMA : MMA_Atom thrfrg_B(BTensor&& btensor) const { CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); - //CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); - //CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); - // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(get<1>(PermutationMNK{}), - get<2>(PermutationMNK{})); + auto t_tile = make_tile(permutation_mnk<1>(), + permutation_mnk<2>()); auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) // Tile the tensor for the Atom @@ -377,21 +366,23 @@ struct TiledMMA : MMA_Atom // Utility for printing and visualization // + // The permutation applied to the MNK-mode data + template + CUTE_HOST_DEVICE constexpr + auto + permutation_mnk() const { + static_assert(0 <= I && I < 3); + auto perm = get(PermutationMNK{}); + return conditional_return(is_underscore{}, size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()), perm); + } + // The size of the MNK-mode template CUTE_HOST_DEVICE constexpr auto tile_size_mnk() const { static_assert(0 <= I && I < 3); - auto core_size = size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()); - [[maybe_unused]] auto perm_size = size(PermutationMNK{}); - if constexpr (is_underscore::value) { - return core_size; - } else { - return cute::max(core_size, perm_size); - } - - CUTE_GCC_UNREACHABLE; + return size(permutation_mnk()); } CUTE_HOST_DEVICE constexpr diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 34275831b8..8b9ac73642 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -32,7 +32,7 @@ #include -#include +#include namespace cute { diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index db6f0fc2e9..b2088b3bf3 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -332,9 +332,6 @@ struct DescriptorIterator { return { GmmaDescriptor{desc_ + uint64_t(offset)} }; } - - CUTE_HOST_DEVICE friend void - print(DescriptorIterator) { printf("GMMA::DescriptorIterator"); } }; template @@ -353,6 +350,11 @@ recast_ptr(DescriptorIterator const& iter) { return iter; // Do nothing, it will still dereference to GmmaDescriptor and decay to uint64_t } +CUTE_HOST_DEVICE void +print(DescriptorIterator) { + printf("GMMA::DescriptorIterator"); +} + // The GMMA Traits below have custom fragment type flags for their smem desc tensors. // These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. template diff --git a/include/cute/container/alignment.hpp b/include/cute/container/alignment.hpp index 509579eeff..4cf60d899f 100644 --- a/include/cute/container/alignment.hpp +++ b/include/cute/container/alignment.hpp @@ -44,7 +44,7 @@ CUTE_HOST_DEVICE constexpr bool is_byte_aligned(void const* const ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check"); + static_assert(has_single_bit(N), "N must be a power of 2 in alignment check"); return (reinterpret_cast(ptr) & (N-1)) == 0; } diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 3ab3bc3205..1963d8ce7b 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -205,18 +205,22 @@ struct subbyte_iterator private: template friend struct swizzle_ptr; + template friend CUTE_HOST_DEVICE constexpr U* raw_pointer_cast(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE constexpr auto recast_ptr(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE void print(subbyte_iterator const&); // Pointer to storage element - storage_type* ptr_ = nullptr; + storage_type* ptr_; // Bit index of value_type starting position within storage_type element. // RI: 0 <= idx_ < sizeof_bit - uint8_t idx_ = 0; + uint8_t idx_; public: - // Ctor - subbyte_iterator() = default; + // Default Ctor + CUTE_HOST_DEVICE constexpr + subbyte_iterator() : ptr_{nullptr}, idx_{0} {}; // Ctor template @@ -286,42 +290,47 @@ struct subbyte_iterator return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; } CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } + CUTE_HOST_DEVICE constexpr friend bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); } CUTE_HOST_DEVICE constexpr friend - bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } - CUTE_HOST_DEVICE constexpr friend bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } CUTE_HOST_DEVICE constexpr friend bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } CUTE_HOST_DEVICE constexpr friend bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } +}; - // Conversion to raw pointer with loss of subbyte index - CUTE_HOST_DEVICE constexpr friend - T* raw_pointer_cast(subbyte_iterator const& x) { - assert(x.idx_ == 0); - return reinterpret_cast(x.ptr_); - } +// Conversion to raw pointer with loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(subbyte_iterator const& x) { + assert(x.idx_ == 0); + return reinterpret_cast(x.ptr_); +} - // Conversion to NewT_ with possible loss of subbyte index - template - CUTE_HOST_DEVICE constexpr friend - auto recast_ptr(subbyte_iterator const& x) { - using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; - if constexpr (cute::is_subbyte_v) { // Making subbyte_iter, preserve the subbyte idx - return subbyte_iterator(x.ptr_, x.idx_); - } else { // Not subbyte, assume/assert subbyte idx 0 - return reinterpret_cast(raw_pointer_cast(x)); - } - CUTE_GCC_UNREACHABLE; - } +// Conversion to NewT_ with possible loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(subbyte_iterator const& x) { + using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; + if constexpr (cute::is_subbyte_v) { // Making subbyte_iter, preserve the subbyte idx + return subbyte_iterator(x.ptr_, x.idx_); + } else { // Not subbyte, assume/assert subbyte idx 0 + return reinterpret_cast(raw_pointer_cast(x)); + } + CUTE_GCC_UNREACHABLE; +} - CUTE_HOST_DEVICE friend void print(subbyte_iterator x) { - printf("subptr[%db](%p.%u)", int(sizeof_bits_v), x.ptr_, x.idx_); - } -}; +template +CUTE_HOST_DEVICE void +print(subbyte_iterator const& x) { + printf("subptr[%db](%p.%u)", int(sizeof_bits_v), x.ptr_, x.idx_); +} // // array_subbyte @@ -365,17 +374,6 @@ struct array_subbyte public: - constexpr - array_subbyte() = default; - - CUTE_HOST_DEVICE constexpr - array_subbyte(array_subbyte const& x) { - CUTE_UNROLL - for (size_type i = 0; i < StorageElements; ++i) { - storage[i] = x.storage[i]; - } - } - CUTE_HOST_DEVICE constexpr size_type size() const { return N; @@ -448,25 +446,16 @@ struct array_subbyte return at(N-1); } + // In analogy to std::vector::data(), these functions are deleted to prevent bugs. + // Instead, prefer + // auto* data = raw_pointer_cast(my_subbyte_array.begin()); + // where the type of auto* is implementation-defined and + // with the knowledge that [data, data + my_subbyte_array.size()) may not be a valid range. CUTE_HOST_DEVICE constexpr - pointer data() { - return reinterpret_cast(storage); - } - - CUTE_HOST_DEVICE constexpr - const_pointer data() const { - return reinterpret_cast(storage); - } + pointer data() = delete; CUTE_HOST_DEVICE constexpr - storage_type* raw_data() { - return storage; - } - - CUTE_HOST_DEVICE constexpr - storage_type const* raw_data() const { - return storage; - } + const_pointer data() const = delete; CUTE_HOST_DEVICE constexpr iterator begin() { diff --git a/include/cute/container/packed_tuple.hpp b/include/cute/container/packed_tuple.hpp new file mode 100644 index 0000000000..c20df2c235 --- /dev/null +++ b/include/cute/container/packed_tuple.hpp @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +namespace cute { + +namespace detail { + +// Empty Structure Optimization +template +struct ESO; + +template +static constexpr bool is_first_empty_v = cute::is_empty::value; +template +static constexpr bool is_rest_empty_v = (cute::is_empty::value && ...); + +template +using ESO_t = ESO, is_rest_empty_v, T...>; + +// Empty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&...) {} +}; + +// NonEmpty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&...) : first_{first} {} + + First first_; +}; + +// Empty First and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&... rest) : rest_{rest...} {} + + ESO_t rest_; +}; + +// NonEmpty T and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{}, rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {} + + First first_; + ESO_t rest_; +}; + +// Get Nth value from ESO +template +CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO const& s) { + if constexpr (N == 0) { + if constexpr (F) { return T{}; } + else { return static_cast(s.first_); } + } else { + if constexpr (R) { return cute::tuple_element_t>{}; } + else { return getv(s.rest_); } + } +} + +template +CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO& s) { + if constexpr (N == 0) { + if constexpr (F) { return T{}; } + else { return static_cast(s.first_); } + } else { + if constexpr (R) { return cute::tuple_element_t>{}; } + else { return getv(s.rest_); } + } +} + +template +CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO&& s) { + if constexpr (N == 0) { + if constexpr (F) { return T{}; } + else { return static_cast(s.first_); } + } else { + if constexpr (R) { return cute::tuple_element_t>{}; } + else { return getv(static_cast&&>(s.rest_)); } + } +} + +// findt: Implementation detail of cute::find. +// If X is the first template argument of the tuple, findt returns C. + +template +CUTE_HOST_DEVICE constexpr +auto +findt(ESO const& t) noexcept +{ + if constexpr (cute::is_same_v) { + return C{}; + } + else { + static_assert(sizeof...(Rest) != 0, + "The type does not appear in the argument list of the tuple."); + if constexpr (IsRestEmpty) { + // The rest is empty, so creating an instance of it is cheap. + return cute::detail::findt(ESO_t{}); + } + else { + return cute::detail::findt(t.rest_); + } + } +} + +} // end namespace detail + +// packed_tuple is a tuple type that is a standard-layout type +// whenever all of its template arguments are standard layout types: +// (cute::is_standard_layout_v && ...) implies (cute::is_standard_layout_v>) + +template +struct packed_tuple : detail::ESO_t +{ + CUTE_HOST_DEVICE constexpr + packed_tuple() {} + + CUTE_HOST_DEVICE constexpr + packed_tuple(T const&... ts) + : detail::ESO_t(ts...) + {} +}; + +template <> +struct packed_tuple<> {}; + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(packed_tuple const& t) { + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(packed_tuple& t) { + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(packed_tuple&& t) { + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast&&>(t)); +} + +template +CUTE_HOST_DEVICE constexpr +packed_tuple +make_packed_tuple(T const&... t) +{ + return {t...}; +} + +// Returns the position of type X (as a static integer) in the tuple +// type's argument list. X must be unique in the argument list. +template +CUTE_HOST_DEVICE constexpr +auto +find(packed_tuple const& t) noexcept +{ + return detail::findt(t); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std { + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index 0af98f5675..54d282419e 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -36,10 +36,13 @@ #include #include +#include +#if defined(CUTLASS_USE_PACKED_TUPLE) +# include +#endif //#include // Advanced optimizations -// // cute::tuple is like std::tuple, with two differences. // // 1. It works on both host and device. @@ -50,19 +53,30 @@ // but do _not_ include references like int& or float&. // (See std::tie for an example of a tuple of references.) // -// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of -// the conversion SFINAE, special overloading, and avoiding cvref template types. -// Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding -// construction calls, and ignoring any need for unique element addresses. -// -// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. +// If the template arguments of cute::tuple are all empty types (in +// the sense of std::is_empty_v), then the cute::tuple is also an +// empty type. Furthermore, if CUTLASS_USE_PACKED_TUPLE is defined, +// cute::tuple is always a standard-layout type if all of its template +// arguments are standard-layout types. namespace cute { +#if defined(CUTLASS_USE_PACKED_TUPLE) + +template +using tuple = packed_tuple; + +#else + namespace detail { +// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of +// the conversion SFINAE, special overloading, and avoiding cvref template types. +// +// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. + // EBO stands for "empty base optimization." // We use this technique to ensure that cute::tuple // doesn't need to waste space storing any template arguments @@ -70,6 +84,12 @@ namespace detail // Otherwise, cute::tuple would need to spend at least 1 byte // for each of its template arguments. // +// This is one way in which cute::tuple differs from std::tuple. +// Empty types in the template argument list are not even constructed, +// and do not have unique element addresses. In fact, they are not +// even members of the tuple or stored in any way. Calling `get` +// constructs and returns an instance of an empty type on demand. +// // EBO always "holds" a single value of type T. // N is like an array index that TupleBase uses // to access the desired tuple element. @@ -109,9 +129,8 @@ struct EBO CUTE_HOST_DEVICE constexpr EBO() : t_{} {} - template CUTE_HOST_DEVICE constexpr - EBO(U const& u) : t_{u} {} + EBO(T const& t) : t_{t} {} T t_; }; @@ -141,15 +160,8 @@ struct TupleBase, T...> CUTE_HOST_DEVICE constexpr TupleBase() {} - template - CUTE_HOST_DEVICE constexpr explicit - TupleBase(U const&... u) - : EBO(u)... {} - - template CUTE_HOST_DEVICE constexpr - TupleBase(TupleBase, U...> const& u) - : EBO(getv(static_cast const&>(u)))... {} + TupleBase(T const&... t) : EBO(t)... {} }; } // end namespace detail @@ -172,16 +184,14 @@ struct tuple : detail::TupleBase, T...> CUTE_HOST_DEVICE constexpr tuple() {} - template - CUTE_HOST_DEVICE constexpr - tuple(U const&... u) : detail::TupleBase, T...>(u...) {} - - template CUTE_HOST_DEVICE constexpr - tuple(tuple const& u) - : detail::TupleBase, T...>(static_cast, U...> const&>(u)) {} + tuple(T const&... t) : detail::TupleBase, T...>(t...) {} }; +template <> +struct tuple<> +{}; + // // get for cute::tuple (just like std::get for std::tuple) // @@ -227,6 +237,8 @@ find(tuple const& t) noexcept return detail::findt(t); } +#endif // CUTLASS_USE_PACKED_TUPLE + // // Custom is_tuple trait simply checks the existence of tuple_size // and assumes std::get(.), std::tuple_element @@ -242,6 +254,9 @@ auto has_tuple_size(...) -> false_type; template struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; +template +constexpr bool is_tuple_v = cute::is_tuple::value; + // // make_tuple (value-based implementation) // @@ -540,20 +555,12 @@ tuple_cat(Tuples const&... ts) namespace detail { -template +template CUTE_HOST_DEVICE constexpr auto -equal_impl(TupleA const& a, TupleB const& b) +equal_impl(TupleA const& a, TupleB const& b, index_sequence) { - if constexpr (I == tuple_size::value) { - return cute::true_type{}; // Terminal: TupleA is exhausted - } else if constexpr (I == tuple_size::value) { - return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted - } else { - return (get(a) == get(b)) && equal_impl(a,b); - } - - CUTE_GCC_UNREACHABLE; + return (cute::true_type{} && ... && (get(a) == get(b))); } } // end namespace detail @@ -564,7 +571,13 @@ CUTE_HOST_DEVICE constexpr auto operator==(TupleT const& t, TupleU const& u) { - return detail::equal_impl<0>(t, u); + if constexpr (tuple_size::value == tuple_size::value) { + return detail::equal_impl(t, u, make_index_sequence::value>{}); + } else { + return cute::false_type{}; + } + + CUTE_GCC_UNREACHABLE; } template -CUTE_HOST_DEVICE void print_tuple(Tuple const& t, - index_sequence, char s = '(', char e = ')') +CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') { using cute::print; - ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); print(e); + print(s); ((void(print(Is == 0 ? '\0' : ',')), void(print(get(t)))), ...); print(e); } #if !defined(__CUDACC_RTC__) template -CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, - index_sequence, char s = '(', char e = ')') +CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') { - (void(os << (Is == 0 ? s : ',') << get(t)), ...); + os << s; (void(os << (Is == 0 ? '\0' : ',') << get(t)), ...); return os << e; } #endif // !defined(__CUDACC_RTC__) @@ -655,6 +666,8 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) } // end namespace cute +#if ! defined(CUTLASS_USE_PACKED_TUPLE) + namespace CUTE_STL_NAMESPACE { @@ -716,5 +729,7 @@ struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; -} // end namepsace std +} // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD + +#endif // CUTLASS_USE_PACKED_TUPLE diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index 41c499ecac..2db934356b 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -30,19 +30,24 @@ **************************************************************************************************/ #pragma once -#include +#include +#include namespace cute { -template -struct type_c { - using type = T; -}; - template struct type_list {}; +// get for type_list +// requires tuple_element_t> to have std::is_default_constructible +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::tuple_element_t> +get(type_list const& t) noexcept { + return {}; +} + } // end namespace cute // @@ -55,26 +60,6 @@ struct type_list {}; #include #endif -#include - -namespace cute -{ - -template -CUTE_HOST_DEVICE constexpr -CUTE_STL_NAMESPACE::tuple_element_t> -get(type_list&) noexcept { - return {}; -} -template -CUTE_HOST_DEVICE constexpr -CUTE_STL_NAMESPACE::tuple_element_t> -get(type_list const& t) noexcept { - return {}; -} - -} // end namespace cute - namespace CUTE_STL_NAMESPACE { @@ -85,8 +70,9 @@ struct tuple_size> template struct tuple_element> - : cute::type_c>::type> -{}; +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; template struct tuple_size> @@ -95,8 +81,9 @@ struct tuple_size> template struct tuple_element> - : cute::type_c>::type> -{}; +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; } // end namespace std @@ -119,8 +106,9 @@ struct tuple_size> template struct tuple_element> - : cute::type_c>::type> -{}; +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; template struct tuple_size> @@ -129,8 +117,9 @@ struct tuple_size> template struct tuple_element> - : cute::type_c>::type> -{}; +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index f8ca467181..ceafba0d80 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -493,6 +493,7 @@ using is_weakly_congruent = decltype(weakly_congruent(declval(), declval() /** Test if Shape A is compatible with Shape B: * the size of A and B are the same, and * any coordinate into A can also be used as a coordinate into B + * Equivalently, the size of Shape B is the same as Shape A at each terminal of Shape A. * compatible is a partial order on A and B: A <= B */ template @@ -523,6 +524,7 @@ using is_compatible = decltype(compatible(declval(), declval())); /** Test if Shape A is weakly compatible with Shape B: * there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B) + * Equivalently, the size of Shape B is a multiple of Shape A at each terminal of Shape A. * weakly_compatible is a partial order on A and B: A <= B */ template @@ -551,6 +553,37 @@ weakly_compatible(IntTupleA const& a, IntTupleB const& b) template using is_weakly_compatible = decltype(weakly_compatible(declval(), declval())); +/** Test if Shape A is softly compatible with Shape B: + * there exists a Shape C congruent to A such that compatible(shape_div(A,C), B) + * Equivalently, the size of Shape B divides Shape A at each terminal of Shape A. + * softly_compatible is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +softly_compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return softly_compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return a % size(b) == Int<0>{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return softly_compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_softly_compatible = decltype(softly_compatible(declval(), declval())); + /** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> */ template diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index b7517a67ce..60581192b0 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -329,54 +329,83 @@ struct is_layout> : true_type {}; // Layout construction // -template ::value || is_integral::value) && - (is_tuple::value || is_integral::value))> +template CUTE_HOST_DEVICE constexpr auto make_layout(Shape const& shape, Stride const& stride) { + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || is_integral::value); return Layout(shape, stride); } -template ::value || is_integral::value)> +template CUTE_HOST_DEVICE constexpr auto make_layout(Shape const& shape) { - return make_layout(shape, compact_col_major(shape)); + static_assert(is_tuple::value || is_integral::value); + return make_layout(shape, compact_major(shape)); } -// Construct a layout from multiple layouts by -// concatenating each layout as an independent mode -template +// +// Convenience tags for common layouts +// + +template CUTE_HOST_DEVICE constexpr auto -make_layout(Layout const&... layouts) +make_layout(Shape const& shape, LayoutLeft) { - return make_layout(make_shape (layouts.shape()...), - make_stride(layouts.stride()...)); + return make_layout(shape, compact_major(shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutRight) +{ + return make_layout(shape, compact_major(shape)); } // -// Convenience tags for common layouts +// Construct a layout from multiple layouts by concatenation // -template +// One argument overload +template CUTE_HOST_DEVICE constexpr auto -make_layout(Shape const& shape, GenColMajor) +make_layout(Layout const& layout0) { - return make_layout(shape, compact_col_major(shape)); + return make_layout(make_shape (layout0.shape() ), + make_stride(layout0.stride())); } -template +// Two argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() ), + make_stride(layout0.stride(), layout1.stride())); +} + +// Var argument overload +template CUTE_HOST_DEVICE constexpr auto -make_layout(Shape const& shape, GenRowMajor) +make_layout(Layout const& layout0, + Layout const& layout1, + Layout const&... layouts) { - return make_layout(shape, compact_row_major(shape)); + return make_layout(make_shape (layout0.shape() , layout1.shape() , layouts.shape()... ), + make_stride(layout0.stride(), layout1.stride(), layouts.stride()...)); } // @@ -428,7 +457,7 @@ make_fragment_like(Layout const& layout) constexpr int R = Layout::rank; if constexpr (R > 1 && is_static::value) { return tiled_product(make_layout(get<0>(layout.shape()), - compact_col_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), + compact_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); } else { return make_layout(layout.shape()); @@ -1131,7 +1160,7 @@ complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget) // Compute the rest_shape and rest_stride auto new_stride = get<0>(stride_) * get<0>(shape_); // new stride = min_stride * curr_shape auto rest_shape = coalesce(ceil_div(cotarget, new_stride)); - auto rest_stride = compact_col_major(rest_shape, new_stride); + auto rest_stride = compact_major(rest_shape, new_stride); // Coalesce and append (rest_shape, rest_stride) return coalesce(make_layout(make_shape (result_shape , rest_shape ), @@ -1220,7 +1249,7 @@ right_inverse(Layout const& layout) return Layout<_1,_0>{}; // Empty case, nothing found } else { // Generate the corresponding new strides and construct - auto rstride = compact_col_major(flat_layout.shape()); + auto rstride = compact_major(flat_layout.shape()); return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), unwrap(transform(iseq, [&](auto i) { return signum(stride(flat_layout)) * get(rstride); }))); } @@ -1318,6 +1347,50 @@ max_common_vector(Layout const& a, CUTE_GCC_UNREACHABLE; } +/* Return a layout that distributes ShapeB over ShapeA. + * + * @returns Layout result + * @post softly_compatible(@a b, @a result) + * @post For all i,j in [0,size(@a result)) with i < j, @a result(i) < @a result(j). Surjective and Ordered. + * @post composition(make_layout(shape(@a a)), @a result) is admissible + * \code + * // Note that 6 does not divide this shape + * Layout layoutA = Layout,Int<14>>>{}; + * + * // Want to tile any 6 elements and don't care where they come from + * Layout dist = domain_distribute(layoutA, Int<6>{}); // (_3,_2):(_1,_15) + * + * // Not guaranteed to find all 6 though... + * CUTE_STATIC_ASSERT_V(Int<6>{} == size(dist)); + * + * Layout result = zipped_divide(layoutA, dist); // (_6,Rest) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +domain_distribute(ShapeA const& a, ShapeB const& b) +{ + static_assert(is_integral::value); + static_assert(is_static::value); + + auto flat_shape_a = flatten(shape(a)); + + static_assert(is_static::value); + + // Compute the shape of the result + auto [result_shape, b_rest] = cute::fold(flat_shape_a, cute::make_tuple(cute::tuple<>{}, size(b)), [](auto init, auto a_) { + auto [result, b_] = init; + auto gcd_ = gcd(a_, b_); + return cute::make_tuple(append(result, gcd_), b_ / gcd_); + }); + + // Compute the stride of the result + auto result_stride = compact_major(flat_shape_a); + + return coalesce(make_layout(result_shape, result_stride)); +} + // // Kernel (Nullspace) of a Layout // @@ -1363,7 +1436,7 @@ nullspace(Layout const& layout) return Layout<_1,_0>{}; // Empty case, nothing found } else { // Generate the corresponding new strides and construct - auto rstride = compact_col_major(flat_layout.shape()); + auto rstride = compact_major(flat_layout.shape()); return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); } @@ -1458,7 +1531,7 @@ auto ceil_div(Target const& target, Layout const& tiler) { - return complement(tiler, size(target)); + return shape(complement(tiler, shape(target))); } // @@ -1753,6 +1826,26 @@ recast_layout(Layout const& layout) CUTE_GCC_UNREACHABLE; } +// Determine the maximum alignment of a Layout. +// The maximum alignment is the largest N for which upcast(layout) will compile. +// upcast(layout) compiles when the static shapes and strides pass divisibility checks. +// Therefore, upcast(layout) will also compile for all divisors M of N. +// Note that this only considers the static shapes and strides of the Layout +// in symmetry with upcast only checking against static shapes and strides and assuming all +// dynamic shapes and strides are large and multiples of N. +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto static_shape = transform( shape(flat_layout), [](auto s){ return conditional_return::value>(s, Int<1>{}); }); + auto static_stride = transform(stride(flat_layout), [](auto d){ return conditional_return::value>(d, Int<0>{}); }); + auto filter_layout = make_layout(static_shape, static_stride); + auto permuted = logical_divide(filter_layout, right_inverse(filter_layout)); + return gcd(size<0>(permuted), stride<1>(permuted)); +} + // // Display utilities // diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 3dbd2cd939..fb62541cb4 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -577,6 +577,7 @@ coalesce(ComposedLayout const& layout, Shape const& trg_profile) return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); } + // // Upcast and Downcast // @@ -597,6 +598,7 @@ downcast(ComposedLayout const& layout) return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); } + template CUTE_HOST_DEVICE constexpr @@ -619,6 +621,16 @@ recast_layout(ComposedLayout const& layout) CUTE_GCC_UNREACHABLE; } +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + //return gcd(max_alignment(layout.layout_a()), max_alignment(layout.offset()), max_alignment(layout.layout_b())); + return Int<1>{}; +} + // // Display utilities // diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 5113719dbd..5aa6664a89 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -48,13 +48,13 @@ template static constexpr auto is_complex_v = is_complex::value; /// Fused multiply-add for complex numbers -template +template CUTE_HOST_DEVICE constexpr void -fma(complex & d, - complex const& a, - complex const& b, - complex const& c) +fma(complex & d, + complex const& a, + complex const& b, + complex const& c) { fma(d.real(), a.real(), b.real(), c.real()); fma(d.imag(), a.real(), b.imag(), c.imag()); @@ -63,12 +63,12 @@ fma(complex & d, } /// Fused multiply-add for triplets -template +template CUTE_HOST_DEVICE constexpr void -fma(complex const& a, - complex const& b, - complex & c) +fma(complex const& a, + complex const& b, + complex & c) { return fma(c, a, b, c); } diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index 904a672638..c7bad24b84 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -33,6 +33,7 @@ #include "cute/util/print.hpp" #include "cute/util/type_traits.hpp" #include "cute/numeric/math.hpp" +#include "cutlass/fast_math.h" namespace cute { @@ -82,8 +83,11 @@ struct is_integral > : true_type {}; template struct is_integral> : true_type {}; -// is_static detects if an (abstract) value is defined completely by it's type (no members) +// Register FastDivmod as the integral type +template<> +struct is_integral : true_type {}; +// is_static detects if an (abstract) value is defined completely by its type (no members) template struct is_static : bool_constant>::value> {}; diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 5be503390a..6d95165de2 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -33,6 +33,7 @@ #include #include +#include namespace cute { @@ -323,4 +324,33 @@ log_2(T x) { return static_cast(bit_width(x)) - 1; } +template +struct DivModReturnType { + IntDiv div_; + IntMod mod_; + CUTE_HOST_DEVICE constexpr + DivModReturnType(IntDiv const& div, IntMod const& mod) : div_(div), mod_(mod) {} +}; + +// General divmod +template +CUTE_HOST_DEVICE constexpr +auto +divmod(CInt0 const& a, CInt1 const& b) { + return DivModReturnType{a / b, a % b}; +} + +// Specialized function with fastDivmod input +template +CUTE_HOST_DEVICE constexpr +auto +divmod(CInt const& a, cutlass::FastDivmod const& b) { + using val_div_type = typename cutlass::FastDivmod::value_div_type; + using val_mod_type = typename cutlass::FastDivmod::value_mod_type; + val_div_type div = 0; + val_mod_type mod = 0; + b(div, mod, a); + return DivModReturnType{div, mod}; +} + } // namespace cute diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index 3b44bb20a1..09a02a00e7 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -31,8 +31,9 @@ #pragma once #include - #include +#include +#include namespace cute { @@ -79,8 +80,9 @@ crd2idx_itt(CInt const& coord, return crd2idx(_0{}, get(shape), get(stride)) + (_0{} + ... + crd2idx(_0{}, get(shape), get(stride))); } else { // General case - return crd2idx(coord % product(get(shape)), get(shape), get(stride)) - + crd2idx_itt(coord / product(get(shape)), shape, stride, seq{}); + auto [div, mod] = divmod(coord, product(get(shape))); + return crd2idx(mod, get(shape), get(stride)) + + crd2idx_itt(div, shape, stride, seq{}); } CUTE_GCC_UNREACHABLE; @@ -229,7 +231,7 @@ idx2crd(Index const& idx, } } else { if constexpr (is_tuple::value) { // "int" tuple - return idx2crd(idx, shape, compact_col_major(shape)); + return transform_leaf(as_arithmetic_tuple(crd2idx(idx, shape, make_basis_like(shape))), identity{}); } else { // "int" "int" return idx; } diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index 57735ce169..9ceb0d32b0 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -360,7 +360,7 @@ shiftr(MixedBits const& m, C s) } // -// upcast and downcast +// Upcast and Downcast // template @@ -410,6 +410,22 @@ downcast(T const& m) return m * C{}; } +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(MixedBits const&) +{ + return C{}; +} + +template +CUTE_HOST_DEVICE constexpr +C +max_alignment(C const& c) +{ + return c; +} + // // Convert a Pow2Layout+Coord to a MixedBits // diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index e345dadb86..82e51c79c6 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -36,6 +36,7 @@ #include #include +#include // get_swizzle /* Specialized functionality for a ComposedLayout of the form * InvolutionFn o Offset o LayoutB @@ -56,6 +57,9 @@ namespace cute { +template +struct get_swizzle,Offset,LayoutB>> { using type = Swizzle; }; + // // Constructors // @@ -117,7 +121,7 @@ CUTE_HOST_DEVICE constexpr auto make_fragment_like(ComposedLayout,Offset,Layout> const& layout) { - return detail::transfer_swizzle(layout.layout_b(), make_fragment_like(layout.layout_b())); + return make_fragment_like(layout.layout_b()); } // @@ -441,7 +445,7 @@ recast_layout(Swizzle const& swizzle) else if constexpr (scale::num == 1) { return downcast(swizzle); } - else if constexpr (scale::den == 1) { + else if constexpr (scale::den == 1) { return upcast(swizzle); } else { @@ -450,6 +454,24 @@ recast_layout(Swizzle const& swizzle) CUTE_GCC_UNREACHABLE; } +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Swizzle const&) +{ + return Int{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(ComposedLayout,Offset,LayoutB> const& layout) +{ + return gcd(max_alignment(layout.layout_a()), + max_alignment(layout.offset()), + max_alignment(layout.layout_b())); +} + // // Other operations // @@ -485,7 +507,7 @@ max_common_vector(ComposedLayout,Offset,LayoutB> const& a, Layout const& b) { // This assumes that Offset is in the YZ domain of the Swizzle... - return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_b(), b)); + return cute::min(max_common_vector(a.layout_b(), b), Int<(1 << M)>{}); } template @@ -504,12 +526,15 @@ auto max_common_vector(ComposedLayout,Offset0,LayoutB0> const& a, ComposedLayout,Offset1,LayoutB1> const& b) { - auto result = coalesce(composition(a, right_inverse(b))); + // Typical impl is composition(a, right_inverse(b)) + // so this is Sw0 o B0 o rinv(Sw1 o B1) = Sw0 o B0 o rinv(B1) o Sw1 + auto vec = max_common_vector(a.layout_b(), b.layout_b()); - if constexpr (is_constant<1, decltype(stride<0>(result.layout_b()))>::value) { - return shape<0>(result); + // This assumes that Offset is in the YZ domain of the Swizzle... + if constexpr (Swizzle{} == Swizzle{}) { + return vec; } else { - return Int<1>{}; + return cute::min(vec, Int<(1 << M0)>{}, Int<(1 << M1)>{}); } CUTE_GCC_UNREACHABLE; diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index 71ace9a81c..a45cbd0132 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -30,1053 +30,7 @@ **************************************************************************************************/ #pragma once -#include - -#include -#include -#include - -#include -#include -#include - -#include -#include - -namespace cute -{ - -// -// Engine -- owning or non-owning data store -// - -// concept Engine { -// using iterator = ; -// using value_type = ; -// using element_type = ; -// using reference = ; -// iterator begin(); -// }; - -template -struct ArrayEngine -{ - using Storage = typename conditional<(sizeof_bits::value % 8 == 0), - array_aligned, - array_subbyte>::type; - using iterator = typename Storage::iterator; - using reference = typename iterator_traits::reference; - using element_type = typename iterator_traits::element_type; - using value_type = typename iterator_traits::value_type; - Storage storage_; - - CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); } - CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } -}; - -template -struct ViewEngine -{ - using iterator = Iterator; - using reference = typename iterator_traits::reference; - using element_type = typename iterator_traits::element_type; - using value_type = typename iterator_traits::value_type; - iterator storage_; - - CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } - CUTE_HOST_DEVICE constexpr iterator & begin() { return storage_; } -}; - -template -struct ConstViewEngine -{ - using iterator = Iterator; - using reference = typename iterator_traits::reference; - using element_type = typename iterator_traits::element_type; - using value_type = typename iterator_traits::value_type; - iterator storage_; - - CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } -}; - -// -// Tensor -// - -template -struct Tensor -{ - using iterator = typename Engine::iterator; - using value_type = typename Engine::value_type; - using element_type = typename Engine::element_type; - using reference = typename Engine::reference; - - using engine_type = Engine; - using layout_type = Layout; - - CUTE_HOST_DEVICE constexpr - Tensor() {} - - template - CUTE_HOST_DEVICE constexpr - Tensor(Ptr const& ptr, Layout const& layout) - : rep_(layout, ptr) { - } - - // - // Accessors - // - - static constexpr int rank = Layout::rank; - - CUTE_HOST_DEVICE constexpr - decltype(auto) - tensor() const { - return *this; - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - layout() const { - return get<0>(rep_); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - engine() const { - return get<1>(rep_); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - engine() { - return get<1>(rep_); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - data() const { - return engine().begin(); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - data() { - return engine().begin(); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - shape() const { - return layout().shape(); - } - - CUTE_HOST_DEVICE constexpr - auto - size() const { - return cute::size(shape()); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - stride() const { - return layout().stride(); - } - - // - // Indexing op() and op[] - // - - // Index into this tensor like an array by computing the offset via layout() - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator[](Coord const& coord) { - return data()[layout()(coord)]; - } - - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator[](Coord const& coord) const { - return data()[layout()(coord)]; - } - - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator()(Coord const& coord) { - if constexpr (has_underscore::value) { - auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); - return make_tensor(data() + offset, sliced_layout); - } else { - return data()[layout()(coord)]; - } - - CUTE_GCC_UNREACHABLE; - } - - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator()(Coord const& coord) const { - if constexpr (has_underscore::value) { - auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); - return make_tensor(data() + offset, sliced_layout); - } else { - return data()[layout()(coord)]; - } - - CUTE_GCC_UNREACHABLE; - } - - // op() convenience function for multi-dimensional coordinates - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { - return operator()(make_coord(c0,c1,cs...)); - } - - template - CUTE_HOST_DEVICE constexpr - decltype(auto) - operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { - return operator()(make_coord(c0,c1,cs...)); - } - - // - // Compose - // - - template - CUTE_HOST_DEVICE constexpr - auto - compose(Layouts const&... layouts) { - return make_tensor(data(), layout().compose(layouts...)); - } - - template - CUTE_HOST_DEVICE constexpr - auto - compose(Layouts const&... layouts) const { - return make_tensor(data(), layout().compose(layouts...)); - } - - // - // Tile - // - - template - CUTE_HOST_DEVICE constexpr - auto - tile(Layouts const&... layouts) { - return make_tensor(data(), layout().tile(layouts...)); - } - - template - CUTE_HOST_DEVICE constexpr - auto - tile(Layouts const&... layouts) const { - return make_tensor(data(), layout().tile(layouts...)); - } - - // - // Utility - // - - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - get_1d_coord(Int const& linear_idx) const { - return layout().get_1d_coord(linear_idx); - } - - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - get_hier_coord(Int const& linear_idx) const { - return layout().get_hier_coord(linear_idx); - } - - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - get_flat_coord(Int const& linear_idx) const { - return layout().get_flat_coord(linear_idx); - } - - cute::tuple rep_; -}; - -template -struct is_tensor : false_type {}; -template -struct is_tensor> : true_type {}; -template -constexpr bool is_tensor_v = is_tensor::value; - -// Customization point for creation of owning and non-owning Tensors -template -struct MakeTensor -{ - template ::value && - is_layout::value)> - CUTE_HOST_DEVICE constexpr auto - operator()(Layout const& layout) const - { - static_assert(is_static::value, "Dynamic owning tensors not supported"); - using Engine = ArrayEngine>; - return Tensor(); - } - - template ::value && - is_layout::value)> - CUTE_HOST_DEVICE constexpr auto - operator()(T const& iter, Layout const& layout) - { - using Engine = ViewEngine; - return Tensor(iter, layout); - } - - template ::value)> - CUTE_HOST_DEVICE constexpr auto - operator()(LayoutArg const& arg, LayoutArgs const&... args) const - { - return operator()(make_layout(arg, args...)); - } - - template ::value)> - CUTE_HOST_DEVICE constexpr auto - operator()(T const& iter, LayoutArg const& arg, LayoutArgs const&... args) - { - return operator()(iter, make_layout(arg, args...)); - } -}; - -// -// make_tensor -// - -// Make an owning Tensor that will allocate a static array -// e.g. make_tensor(Int<12>{}) -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor(Args const&... args) -{ - return MakeTensor{}(args...); -} - -// Make a non-owning Tensor that will use a pointer (view) -// e.g. make_tensor(vec.data(), 12) -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor(Iterator const& iter, Args const&... args) -{ - return MakeTensor{}(iter, args...); -} - -// -// make_tensor_like -// Make a register tensor the same type and shape and (if possible) order as another tensor -// - -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor_like(Layout const& layout) -{ - return make_tensor(make_layout_like(layout)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor_like(Tensor const& tensor) -{ - return make_tensor_like(tensor.layout()); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor_like(Tensor const& tensor) -{ - return make_tensor_like(tensor.layout()); -} - -// -// make_fragment_like -- -// Make a tensor the same shape and (if possible) order as another tensor, with special -// consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms -// so this allocates the 0th mode with LayoutLeft regardless of the reference layout. -// - -template -CUTE_HOST_DEVICE constexpr -auto -make_fragment_like(Layout const& layout) -{ - return make_tensor(make_fragment_like(layout)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_fragment_like(Tensor const& tensor) -{ - return make_fragment_like(tensor.layout()); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_fragment_like(Tensor const& tensor) -{ - return make_fragment_like(tensor.layout()); -} - -// -// make_counting_tensor -// Make a tensor from a layout by binding it to a counting iter with 0-offset of the same profile as the codomain. -// - -template ::value)> -CUTE_HOST_DEVICE constexpr -auto -make_counting_tensor(Layout const& layout) -{ - return make_tensor(make_inttuple_iter(repeat_like(coshape(layout), Int<0>{})), layout); -} - -// -// make_identity_tensor -// Make a tensor that maps coordinates within a shape to themselves. -// - -template -CUTE_HOST_DEVICE constexpr -auto -make_identity_tensor(Shape const& shape) -{ - return make_counting_tensor(make_identity_layout(shape)); -} - -// -// Utilities -// - -// Return the subtensor of a mode -template >::value)> -CUTE_HOST_DEVICE constexpr -decltype(auto) -tensor(Tensor&& tensor) -{ - return static_cast(tensor); -} - -template >::value)> -CUTE_HOST_DEVICE constexpr -decltype(auto) -tensor(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), get(tensor.layout())); -} - -// Return the layout of a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -layout(Tensor const& tensor) -{ - return layout(tensor.layout()); -} - -// Return the shape of a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -shape(Tensor const& tensor) -{ - return shape(tensor.layout()); -} - -// Return the stride of a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -stride(Tensor const& tensor) -{ - return stride(tensor.layout()); -} - -// Return the number of elements in a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -size(Tensor const& tensor) -{ - return size(tensor.layout()); -} - -// Return the rank of a mode -template -CUTE_HOST_DEVICE constexpr -auto -rank(Tensor const& tensor) -{ - return rank(tensor.layout()); -} - -// Return the depth of a mode -template -CUTE_HOST_DEVICE constexpr -auto -depth(Tensor const& tensor) -{ - return depth(tensor.layout()); -} - -// -// Operations to manipulate Tensors like a Layout -// - -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -flatten(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), flatten(tensor.layout())); -} - -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -coalesce(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), coalesce(tensor.layout())); -} - -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -coalesce(Tensor&& tensor, Profile const& profile) -{ - return make_tensor(static_cast(tensor).data(), coalesce(tensor.layout(), profile)); -} - -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -filter_zeros(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), filter_zeros(tensor.layout())); -} - -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -filter(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), filter(tensor.layout())); -} - -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -filter(Tensor&& tensor, Profile const& profile) -{ - return make_tensor(static_cast(tensor).data(), filter(tensor.layout(), profile)); -} - -// Return a tensor with the same shape as input but offset by a given coordinate -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -domain_offset(Coord const& coord, Tensor&& tensor) -{ - auto [layout, ptr_offset] = domain_offset(coord, tensor.layout()); - return make_tensor(static_cast(tensor).data() + ptr_offset, layout); -} - -// Group the modes [B,E) into a single mode -// e.g. group<2,4>(make_tensor(Layout>{})) -// => make_tensor(Layout,_5,_6>>{}) -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -group_modes(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), - group(tensor.layout())); -} - -// Return the subtensor of a range of modes -template >::value)> -CUTE_HOST_DEVICE constexpr -decltype(auto) -take(Tensor&& tensor) -{ - return make_tensor(static_cast(tensor).data(), take(tensor.layout())); -} - -// -// Recast -// - -// NOTE: This is very dangerous to do -// -- doesn't check dynamic integer divisibility -// -- doesn't check alignment - -template -CUTE_HOST_DEVICE constexpr -auto -recast(Tensor&& tensor) -{ - using OldType = typename remove_cvref_t::value_type; - auto old_layout = tensor.layout(); - auto new_layout = recast_layout(old_layout); - - // If this is an upcast of a normal Layout with static negative strides, then offset as well - if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { - auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); - auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); - auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); - - return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); - } else { - return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); - } - - CUTE_GCC_UNREACHABLE; -} - -// -// max_common_vector -// - -/* Return Int such that N is the maximum number of contiguous elements - * that logically correspond in the tensors of @a a and @a b. This is, - * the number of elements that could reasonably be vectorized into a single load/store. - * - * @returns Int with N >= 0 - * - * A return value of Int<0> indicates that no such conclusion can be made and no - * vectorization should be attempted. - * - * Note that the return value does NOT include alignment concerns such as the pointer value and - * the divisbility of dynamic strides. - */ -template -CUTE_HOST_DEVICE constexpr -auto -max_common_vector(Tensor const& a, - Tensor const& b) -{ - using SrcType = typename Tensor::value_type; - using DstType = typename Tensor::value_type; - using SrcRef = typename Tensor::reference; - using DstRef = typename Tensor::reference; - - // Determine if vectorization candidates at all - if constexpr (// Should be the same value_types, else the copy is also performing a cast - sizeof_bits_v == sizeof_bits_v && - // The types should be trivially copyable so that vectorization is valid - is_trivially_copyable::value && - is_trivially_copyable::value && - // Should be load/storing real data, rather than implicit iterators or such - is_reference::value && - is_reference::value) - { - return max_common_vector(a.layout(), b.layout()); - } else { - return Int<0>{}; - } - - CUTE_GCC_UNREACHABLE; -} - -/* Return a layout that points to the maximum number of contiguous elements - * that logically correspond in the tensors of @a a and @a b. This is, - * the elements that could reasonably be "vectorized" into a single load/store. - * - * @returns Layout R such that composition(a.layout(), R) and composition(b.layout(), R) - * are both identity Layouts. - * - * Note that the returned layout does NOT include alignment concerns such as the pointer value and - * the divisbility of dynamic strides. - */ -template -CUTE_HOST_DEVICE constexpr -auto -max_common_layout(Tensor const& a, - Tensor const& b) -{ - using SrcType = typename Tensor::value_type; - using DstType = typename Tensor::value_type; - using SrcRef = typename Tensor::reference; - using DstRef = typename Tensor::reference; - - // Determine if vectorization candidates at all - if constexpr (// Should be the same value_types, else the copy is also performing a cast - sizeof_bits_v == sizeof_bits_v && - // The types should be trivially copyable so that vectorization is valid - is_trivially_copyable::value && - is_trivially_copyable::value && - // Should be load/storing real data, rather than implicit iterators or such - is_reference::value && - is_reference::value) - { - return max_common_layout(a.layout(), b.layout()); - } else { - return Layout<_1,_0>{}; - } - - CUTE_GCC_UNREACHABLE; -} - -// -// Key algebraic operations -- Divide and Product -// - -// Apply a Tiler to the Tensor. -// -// Consider a Tensor with shape (A,B,x,y) -// And a Tiler that is: -// -// * A Layout with shape (BLK_A,BLK_B) -// ** Result Tensor shape ((BLK_A,BLK_B),Rest). -// ** That is, the Tensor and Tile are treated as 1D for the tiling. -// ** See logical_divide(Layout,Layout) -// -// * A Tile with shape -// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). -// ** Each mode of the Tile is applied to the corresponding mode of the Tensor. -// ** See logical_divide(Layout,Tuple) -// -// * A Shape (BLK_A,BLK_B) -// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). -// ** Equivalent to applying Tile. -// ** See logical_divide(Layout,Tuple) and logical_divide(Layout,Int) -// -// Note that the Tile/Shape Tilers must be weakly_congruent to the Tensor -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -logical_divide(Tensor && tensor, - Tiler const& tiler) // Layout or Tile or Shape -{ - return make_tensor(static_cast(tensor).data(), - logical_divide(tensor.layout(), tiler)); -} - -// zipped_divide is logical_divide with Tiler modes and Rest modes gathered together: (Tiler,Rest) -// When Tiler is Layout, this has no effect as logical_divide results in the same. -// When Tiler is Tile or Shape, this zips modes into standard form ((BLK_A,BLK_B),(a,b,x,y)) -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -zipped_divide(Tensor && tensor, - Tiler const& tiler) // Layout or Tile or Shape -{ - return make_tensor(static_cast(tensor).data(), - zipped_divide(tensor.layout(), tiler)); -} - -// tiled_divide is zipped_divide with the second output mode flattened ((BLK_A,BLK_B),a,b,x,y) -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -tiled_divide(Tensor && tensor, - Tiler const& tiler) // Layout or Tile or Shape -{ - return make_tensor(static_cast(tensor).data(), - tiled_divide(tensor.layout(), tiler)); -} - -// flat_divide is zipped_divide with the both modes flattened (BLK_A,BLK_B,a,b,x,y) -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -flat_divide(Tensor && tensor, - Tiler const& tiler) // Layout or Tile or Shape -{ - return make_tensor(static_cast(tensor).data(), - flat_divide(tensor.layout(), tiler)); -} - -// logical_product on a Tensor doesn't make sense since it often increases cosize -// though this might make sense for creating Tensors with broadcasted (stride-0) modes - -// -// Tensor partitioning utilities -// - -// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes. -// With an inner_partition, you get everything that's inside the Tiler. Everything that the Tiler is pointing to. -// Split the modes of tensor according to the Tiler -// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) -// Then slice into the second mode (the "Rest" mode) with Coord -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -inner_partition(Tensor && tensor, - Tiler const& tiler, - Coord const& coord) -{ - auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); - constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; - - // The coord slices into the second mode (the "rest" mode), flatten the first - if constexpr (is_tuple::value) { - // Append trailing modes if coord is tuple - constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;; - return tensor_tiled(repeat(_), append(coord,_)); - } else { - // Flat indexing if coord is not tuple - return tensor_tiled(repeat(_), coord); - } -} - -// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes. -// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor. -// Split the modes of tensor according to the Tiler -// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) -// Then slice into the first mode (the "Tile" mode) with Coord -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -outer_partition(Tensor && tensor, - Tiler const& tiler, - Coord const& coord) -{ - auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); - constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; - - // The coord slices into the first mode (the "tile" mode), flatten the second - if constexpr (is_tuple::value) { - // Append trailing modes if coord is tuple - constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; - return tensor_tiled(append(coord,_), repeat(_)); - } else { - // Flat indexing if coord is not tuple - return tensor_tiled(coord, repeat(_)); - } -} - -// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile. -// This is typical at the CTA level where tiles of data are extracted: -// Tensor data = ... // ( M, N) -// Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y)); // (_32,_64) -template >::value)> -CUTE_HOST_DEVICE constexpr -auto -local_tile(Tensor && tensor, - Tiler const& tiler, // tiler to apply - Coord const& coord) // coord to slice into "remainder" -{ - return inner_partition(static_cast(tensor), - tiler, - coord); -} - -// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience -// when using projections of the same tiler. -// This is typical at the CTA level where tiles of data are extracted as projections: -// Tensor dataA = ... // (M,K) -// Tensor dataB = ... // (N,K) -// Tensor dataC = ... // (M,N) -// auto cta_tiler = Shape<_32, _64, _4>{}; -// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); -// Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (_32,_4,k) -// Tensor ctaB = local_tile(dataA, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (_64,_4,k) -// Tensor ctaC = local_tile(dataA, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (_32,_64) -template >::value)> -CUTE_HOST_DEVICE -auto -local_tile(Tensor && tensor, - Tiler const& tiler, // tiler to apply - Coord const& coord, // coord to slice into "remainder" - Proj const& proj) // projection to apply to tiler and coord -{ - return local_tile(static_cast(tensor), - dice(proj, tiler), - dice(proj, coord)); -} - -// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index. -// This is typical at the Thread level where data is partitioned across repeated patterns of threads: -// Tensor data = ... // (_16,_64) -// Tensor thr_data = local_partition(data, Layout>{}, thr_idx); // ( _8, _4) -template >::value)> -CUTE_HOST_DEVICE -auto -local_partition(Tensor && tensor, - Layout const& tile, // coord -> index - Index const& index) // index to slice for -{ - static_assert(is_integral::value); - return outer_partition(static_cast(tensor), - product_each(shape(tile)), - tile.get_flat_coord(index)); -} - -// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience -// when using projections of the same tiler. -// This is typical at the Thread level where data is partitioned across projected layouts of threads: -// Tensor dataA = ... // (M,K) -// Tensor dataB = ... // (N,K) -// Tensor dataC = ... // (M,N) -// auto thr_layout = Layout, Stride<_16,_1,_0>>{}; -// Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1) -// Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1) -// Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16) -template >::value)> -CUTE_HOST_DEVICE -auto -local_partition(Tensor && tensor, - Layout const& tile, // coord -> index - Index const& index, // index to slice for - Projection const& proj) -{ - return local_partition(static_cast(tensor), - dice(proj, tile), - index); -} - -// -// Display utilities -// - -template -CUTE_HOST_DEVICE void print(Tensor const& tensor) -{ - print(tensor.data()); print(" o "); print(tensor.layout()); -} - -template -CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor, bool print_type = true) -{ - if (print_type) { - print(tensor); print(":\n"); - } - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - pretty_print(tensor(m)); - printf("\n"); - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - pretty_print(tensor(m,n)); - } - printf("\n"); - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor(tensor(_,_,0), false); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); - print_tensor(tensor(_,_,k), false); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor(tensor(_,_,_,0), false); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); - print_tensor(tensor(_,_,_,p), false); - } - } -} - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) -{ - int digits = 9; - - if constexpr (Layout::rank == 1) - { - for (int m = 0; m < size(tensor); ++m) { - os << std::setw(digits) << tensor(m) << std::endl; - } - } else - if constexpr (Layout::rank == 2) - { - for (int m = 0; m < size<0>(tensor); ++m) { - for (int n = 0; n < size<1>(tensor); ++n) { - os << std::setw(digits) << tensor(m,n); - } - os << std::endl; - } - } else - if constexpr (Layout::rank == 3) - { - print_tensor_os(os, tensor(_,_,0)); - for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; - print_tensor_os(os, tensor(_,_,k)); - } - } else - if constexpr (Layout::rank == 4) - { - print_tensor_os(os, tensor(_,_,_,0)); - for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; - print_tensor_os(os, tensor(_,_,_,p)); - } - } - - return os; -} - -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) -{ - os << tensor.layout() << std::endl; - return print_tensor_os(os, tensor); -} -#endif // !defined(__CUDACC_RTC__) - -} // end namespace cute +#include // // Extended Engines @@ -1098,3 +52,4 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const #include #include + diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp new file mode 100644 index 0000000000..da0e245636 --- /dev/null +++ b/include/cute/tensor_impl.hpp @@ -0,0 +1,1153 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains the definition of Tensor as well as classes/functions most closely associated with it. + + For backwards-compatibility, "tensor.hpp" is the "entrypoint" header for a collection of classes and utilities + that are adjacent to Tensor, e.g. fill(). Whereas this file contains the actual definition of Tensor and + a small set of functions central to its usage. + + Within the CUTLASS codebase, favor not including "tensor.hpp" wherever possible; instead include "tensor_impl.hpp" + along with other specific headers that you need. This helps to avoid circular includes and to reduce build time. +*/ + +#pragma once + +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace cute +{ + +// +// Engine -- owning or non-owning data store +// + +// concept Engine { +// using iterator = ; +// using value_type = ; +// using element_type = ; +// using reference = ; +// iterator begin(); +// }; + +template +struct ArrayEngine +{ + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = typename Storage::iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); } + CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } +}; + +template +struct ViewEngine +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + iterator storage_; + + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } + CUTE_HOST_DEVICE constexpr iterator & begin() { return storage_; } +}; + +template +struct ConstViewEngine +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + iterator storage_; + + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } +}; + +// +// Tensor +// + +template +struct Tensor +{ + using iterator = typename Engine::iterator; + using value_type = typename Engine::value_type; + using element_type = typename Engine::element_type; + using reference = typename Engine::reference; + + using engine_type = Engine; + using layout_type = Layout; + + CUTE_HOST_DEVICE constexpr + Tensor() {} + + CUTE_HOST_DEVICE constexpr + Tensor(Engine const& engine, Layout const& layout) + : rep_(layout, engine) { + } + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + tensor() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() const { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() const { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return get<0>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout().shape(); + } + + CUTE_HOST_DEVICE constexpr + auto + size() const { + return cute::size(shape()); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return layout().stride(); + } + + // + // Indexing op() and op[] + // + + // Index into this tensor like an array by computing the offset via layout() + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) const { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { + return operator()(make_coord(c0,c1,cs...)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) { + return make_tensor(data(), layout().compose(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return make_tensor(data(), layout().compose(layouts...)); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) { + return make_tensor(data(), layout().tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return make_tensor(data(), layout().tile(layouts...)); + } + + // + // Utility + // + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(Int const& linear_idx) const { + return layout().get_1d_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(Int const& linear_idx) const { + return layout().get_hier_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(Int const& linear_idx) const { + return layout().get_flat_coord(linear_idx); + } + + cute::tuple rep_; +}; + +template +struct is_tensor : false_type {}; +template +struct is_tensor> : true_type {}; +template +constexpr bool is_tensor_v = is_tensor::value; + +// Customization point for creation of owning and non-owning Tensors +template +struct MakeTensor +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Arg0 const& arg0, Args const&... args) const + { + if constexpr (has_dereference::value) { + // Construct a non-owning Tensor + using Engine = ViewEngine; + if constexpr (sizeof...(Args) == 1 && (is_layout::value && ...)) { + // Forward a Layout + return Tensor{Engine{arg0}, args...}; + } else { + // Construct a Layout from Args + return Tensor{Engine{arg0}, make_layout(args...)}; + } + } else { + // Construct an owning Tensor + static_assert((is_static::value && ... && is_static::value), + "Dynamic owning tensors not supported"); + if constexpr (sizeof...(Args) == 0 && is_layout::value) { + // Forward a Layout + using Layout = Arg0; + using Engine = ArrayEngine>; + return Tensor(); + } else { + // Construct a Layout from Args + using Layout = decltype(make_layout(arg0, args...)); + using Engine = ArrayEngine>; + return Tensor(); + } + } + } +}; + +// +// make_tensor +// + +// Make an owning Tensor that will allocate a static array +// e.g. make_tensor(Int<12>{}) +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Args const&... args) +{ + static_assert((not has_dereference::value && ...), "Expected layout args... in make_tensor(args...)"); + return MakeTensor{}(args...); +} + +// Make a non-owning Tensor that will use a pointer (view) +// e.g. make_tensor(vec.data(), 12) +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& iter, Args const&... args) +{ + static_assert(has_dereference::value, "Expected iterator iter in make_tensor(iter, args...)"); + static_assert((not has_dereference::value && ...), "Expected layout args... in make_tensor(iter, args...)"); + return MakeTensor{}(iter, args...); +} + +// +// make_tensor_like +// Make a register tensor the same type and shape and (if possible) order as another tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Layout const& layout) +{ + return make_tensor(make_layout_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + return make_tensor_like(tensor.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + return make_tensor_like(tensor.layout()); +} + +// +// make_fragment_like +// Make a tensor the same shape and (if possible) order as another tensor, with special +// consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms +// so this allocates the 0th mode with LayoutLeft regardless of the reference layout. +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + return make_tensor(make_fragment_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + +// +// make_counting_tensor +// Make a tensor from a layout by binding it to a counting iter with 0-offset of the same profile as the codomain. +// + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_counting_tensor(Layout const& layout) +{ + return make_tensor(make_inttuple_iter(repeat_like(coshape(layout), Int<0>{})), layout); +} + +// +// make_identity_tensor +// Make a tensor that maps coordinates within a shape to themselves. +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_tensor(Shape const& shape) +{ + return make_counting_tensor(make_identity_layout(shape)); +} + +// +// Utilities +// + +// Return the subtensor of a mode +template +CUTE_HOST_DEVICE constexpr +auto +tensor(Tensor&& tensor) +{ + if constexpr (sizeof...(Is) == 0) { + return tensor; + } else { + return make_tensor(tensor.data(), get(tensor.layout())); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +auto +layout(Tensor const& tensor) +{ + return layout(tensor.layout()); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +auto +shape(Tensor const& tensor) +{ + return shape(tensor.layout()); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +auto +stride(Tensor const& tensor) +{ + return stride(tensor.layout()); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Tensor const& tensor) +{ + return size(tensor.layout()); +} + +// Return the rank of a mode +template +CUTE_HOST_DEVICE constexpr +auto +rank(Tensor const& tensor) +{ + return rank(tensor.layout()); +} + +// Return the depth of a mode +template +CUTE_HOST_DEVICE constexpr +auto +depth(Tensor const& tensor) +{ + return depth(tensor.layout()); +} + +// +// Operations to manipulate Tensors like a Layout or IntTuple +// These are implemented with explicit modifier overloads because these +// methods likely also have a general IntTuple overload that can shadow. +// + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor const& tensor) { + return make_tensor(tensor.data(), flatten(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor& tensor) { + return make_tensor(tensor.data(), flatten(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor&& tensor) { + return make_tensor(tensor.data(), flatten(tensor.layout())); +} + +template > +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor const& tensor, Profile const& profile = {}) { + return make_tensor(tensor.data(), coalesce(tensor.layout(), profile)); +} + +template > +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor& tensor, Profile const& profile = {}) { + return make_tensor(tensor.data(), coalesce(tensor.layout(), profile)); +} + +template > +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor&& tensor, Profile const& profile = {}) { + return make_tensor(tensor.data(), coalesce(tensor.layout(), profile)); +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor const& tensor) { + return make_tensor(tensor.data(), filter_zeros(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor& tensor) { + return make_tensor(tensor.data(), filter_zeros(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor) { + return make_tensor(tensor.data(), filter_zeros(tensor.layout())); +} + +// Remove all of the 0-strides and 1-sizes +template +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor const& tensor) { + return make_tensor(tensor.data(), filter(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor& tensor) { + return make_tensor(tensor.data(), filter(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor&& tensor) { + return make_tensor(tensor.data(), filter(tensor.layout())); +} + +// Group the modes [B,E) into a single mode +// e.g. group<2,4>(make_tensor(Layout>{})) +// => make_tensor(Layout,_5,_6>>{}) +template +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor const& tensor) { + return make_tensor(tensor.data(), group(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor& tensor) { + return make_tensor(tensor.data(), group(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor&& tensor) { + return make_tensor(tensor.data(), group(tensor.layout())); +} + +// Return the subtensor of a range of modes +template +CUTE_HOST_DEVICE constexpr +auto +take(Tensor const& tensor) { + return make_tensor(tensor.data(), take(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(Tensor& tensor) { + return make_tensor(tensor.data(), take(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(Tensor&& tensor) { + return make_tensor(tensor.data(), take(tensor.layout())); +} + +// Return a tensor with the same shape as input but offset by a given coordinate +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Tensor&& tensor) +{ + auto [layout, ptr_offset] = domain_offset(coord, tensor.layout()); + return make_tensor(static_cast(tensor).data() + ptr_offset, layout); +} + +// +// Recast +// + +// NOTE: This is very dangerous to do +// -- doesn't check dynamic integer divisibility +// -- doesn't check alignment + +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor&& tensor) +{ + using OldType = typename remove_cvref_t::value_type; + auto old_layout = tensor.layout(); + auto new_layout = recast_layout(old_layout); + + // If this is an upcast of a normal Layout with static negative strides, then offset as well + if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { + auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); + auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); + auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); + + return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + } else { + return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max_common_vector +// + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the number of elements that could reasonably be vectorized into a single load/store. + * + * @returns Int with N >= 0 + * + * A return value of Int<0> indicates that no such conclusion can be made and no + * vectorization should be attempted. + * + * Note that the return value does NOT include alignment concerns such as the pointer value and + * the divisbility of dynamic strides. + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename Tensor::value_type; + using DstType = typename Tensor::value_type; + using SrcRef = typename Tensor::reference; + using DstRef = typename Tensor::reference; + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + cute::is_same::value && + // The types should be trivially copyable so that vectorization is valid + is_trivially_copyable::value && + is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + is_reference::value && + is_reference::value) + { + return max_common_vector(a.layout(), b.layout()); + } else { + return Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the elements that could reasonably be "vectorized" into a single load/store. + * + * @returns Layout R such that composition(a.layout(), R) and composition(b.layout(), R) + * are both identity Layouts. + * + * Note that the returned layout does NOT include alignment concerns such as the pointer value and + * the divisbility of dynamic strides. + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename Tensor::value_type; + using DstType = typename Tensor::value_type; + using SrcRef = typename Tensor::reference; + using DstRef = typename Tensor::reference; + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + cute::is_same::value && + // The types should be trivially copyable so that vectorization is valid + is_trivially_copyable::value && + is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + is_reference::value && + is_reference::value) + { + return max_common_layout(a.layout(), b.layout()); + } else { + return Layout<_1,_0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Key algebraic operations -- Composition, Divide, and Product +// + +// Apply a Tiler to the Tensor via composition. +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +composition(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + composition(tensor.layout(), tiler)); +} + +// Apply a Tiler to the Tensor. +// +// Consider a Tensor with shape (A,B,x,y) +// And a Tiler that is: +// +// * A Layout with shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,BLK_B),Rest). +// ** That is, the Tensor and Tile are treated as 1D for the tiling. +// ** See logical_divide(Layout,Layout) +// +// * A Tile with shape +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Each mode of the Tile is applied to the corresponding mode of the Tensor. +// ** See logical_divide(Layout,Tuple) +// +// * A Shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Equivalent to applying Tile. +// ** See logical_divide(Layout,Tuple) and logical_divide(Layout,Int) +// +// Note that the Tile/Shape Tilers must be weakly_congruent to the Tensor +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + logical_divide(tensor.layout(), tiler)); +} + +// zipped_divide is logical_divide with Tiler modes and Rest modes gathered together: (Tiler,Rest) +// When Tiler is Layout, this has no effect as logical_divide results in the same. +// When Tiler is Tile or Shape, this zips modes into standard form ((BLK_A,BLK_B),(a,b,x,y)) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + zipped_divide(tensor.layout(), tiler)); +} + +// tiled_divide is zipped_divide with the second output mode flattened ((BLK_A,BLK_B),a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + tiled_divide(tensor.layout(), tiler)); +} + +// flat_divide is zipped_divide with the both modes flattened (BLK_A,BLK_B,a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + flat_divide(tensor.layout(), tiler)); +} + +// logical_product on a Tensor doesn't make sense since it often increases cosize +// though this might make sense for creating Tensors with broadcasted (stride-0) modes + +// +// Tensor partitioning utilities +// + +// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes. +// With an inner_partition, you get everything that's inside the Tiler. Everything that the Tiler is pointing to. +// Split the modes of tensor according to the Tiler +// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// Then slice into the second mode (the "Rest" mode) with Coord +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +inner_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) +{ + auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + + // The coord slices into the second mode (the "rest" mode), flatten the first + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; + return tensor_tiled(repeat(_), append(coord,_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(repeat(_), coord); + } +} + +// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes. +// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor. +// Split the modes of tensor according to the Tiler +// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// Then slice into the first mode (the "Tile" mode) with Coord +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +outer_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) +{ + auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; + + // The coord slices into the first mode (the "tile" mode), flatten the second + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + return tensor_tiled(append(coord,_), repeat(_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(coord, repeat(_)); + } +} + +// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile. +// This is typical at the CTA level where tiles of data are extracted: +// Tensor data = ... // ( M, N) +// Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y)); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord) // coord to slice into "remainder" +{ + return inner_partition(static_cast(tensor), + tiler, + coord); +} + +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the CTA level where tiles of data are extracted as projections: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... // (M,N) +// auto cta_tiler = Shape<_32, _64, _4>{}; +// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); +// Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (_32,_4,k) +// Tensor ctaB = local_tile(dataA, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (_64,_4,k) +// Tensor ctaC = local_tile(dataA, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE +auto +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord, // coord to slice into "remainder" + Proj const& proj) // projection to apply to tiler and coord +{ + return local_tile(static_cast(tensor), + dice(proj, tiler), + dice(proj, coord)); +} + +// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index. +// This is typical at the Thread level where data is partitioned across repeated patterns of threads: +// Tensor data = ... // (_16,_64) +// Tensor thr_data = local_partition(data, Layout>{}, thr_idx); // ( _8, _4) +template >::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index) // index to slice for +{ + static_assert(is_integral::value); + return outer_partition(static_cast(tensor), + product_each(shape(tile)), + tile.get_flat_coord(index)); +} + +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the Thread level where data is partitioned across projected layouts of threads: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... // (M,N) +// auto thr_layout = Layout, Stride<_16,_1,_0>>{}; +// Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1) +// Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1) +// Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16) +template >::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index, // index to slice for + Projection const& proj) +{ + return local_partition(static_cast(tensor), + dice(proj, tile), + index); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Tensor const& tensor) +{ + print(tensor.data()); print(" o "); print(tensor.layout()); +} + +template +CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor, bool print_type = true) +{ + if (print_type) { + print(tensor); print(":\n"); + } + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + pretty_print(tensor(m)); + printf("\n"); + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + pretty_print(tensor(m,n)); + } + printf("\n"); + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor(tensor(_,_,0), false); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); + print_tensor(tensor(_,_,k), false); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor(tensor(_,_,_,0), false); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); + print_tensor(tensor(_,_,_,p), false); + } + } +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) +{ + int digits = 9; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + os << std::setw(digits) << tensor(m) << std::endl; + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + os << std::setw(digits) << tensor(m,n); + } + os << std::endl; + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor_os(os, tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; + print_tensor_os(os, tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor_os(os, tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; + print_tensor_os(os, tensor(_,_,_,p)); + } + } + + return os; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) +{ + os << tensor.layout() << std::endl; + return print_tensor_os(os, tensor); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index 966bb1153e..86da7cae91 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -120,7 +120,7 @@ print_type(T&&...) { CUTE_HOST_DEVICE bool -block(int bid) +block([[maybe_unused]] int bid) { #if defined(__CUDA_ARCH__) return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; @@ -131,7 +131,7 @@ block(int bid) CUTE_HOST_DEVICE bool -thread(int tid, int bid) +thread([[maybe_unused]] int tid, [[maybe_unused]] int bid) { #if defined(__CUDA_ARCH__) return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index 56cc814ec3..f12cdb594f 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -85,6 +85,8 @@ using CUTE_STL_NAMESPACE::is_volatile_v; using CUTE_STL_NAMESPACE::conditional; using CUTE_STL_NAMESPACE::conditional_t; +using CUTE_STL_NAMESPACE::add_const_t; + using CUTE_STL_NAMESPACE::remove_const_t; using CUTE_STL_NAMESPACE::remove_cv_t; using CUTE_STL_NAMESPACE::remove_reference_t; @@ -107,6 +109,13 @@ using CUTE_STL_NAMESPACE::is_convertible_v; using CUTE_STL_NAMESPACE::is_same; using CUTE_STL_NAMESPACE::is_same_v; +using CUTE_STL_NAMESPACE::is_constructible; +using CUTE_STL_NAMESPACE::is_constructible_v; +using CUTE_STL_NAMESPACE::is_default_constructible; +using CUTE_STL_NAMESPACE::is_default_constructible_v; +using CUTE_STL_NAMESPACE::is_standard_layout; +using CUTE_STL_NAMESPACE::is_standard_layout_v; + using CUTE_STL_NAMESPACE::is_arithmetic; using CUTE_STL_NAMESPACE::is_unsigned; using CUTE_STL_NAMESPACE::is_unsigned_v; @@ -131,6 +140,9 @@ using CUTE_STL_NAMESPACE::common_type_t; using CUTE_STL_NAMESPACE::remove_pointer; using CUTE_STL_NAMESPACE::remove_pointer_t; +using CUTE_STL_NAMESPACE::alignment_of; +using CUTE_STL_NAMESPACE::alignment_of_v; + // using CUTE_STL_NAMESPACE::declval; @@ -261,4 +273,5 @@ struct conditional_template { template using type = False; }; + } // end namespace cute diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index b08ee546d3..cd2d7be3cb 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -50,11 +50,11 @@ namespace arch { // Enumerates the reserved named barriers to avoid potential conflicts // This enum class specifies the NamedBarriers reserved by CUTLASS. enum class ReservedNamedBarriers { - EpilogueBarrier = 0, - TransposeBarrier = 1, - TransformBarrier = 2, - StreamkBarrier0 = 3, - StreamkBarrier1 = 4 + EpilogueBarrier = 1, + TransposeBarrier = 2, + TransformBarrier = 3, + StreamkBarrier0 = 4, + StreamkBarrier1 = 5 , FirstUserBarrier = StreamkBarrier1 + 1 }; @@ -204,12 +204,12 @@ struct ClusterBarrier { } CUTLASS_DEVICE - uint32_t test_wait(uint32_t phase, uint32_t pred=true) const { + bool test_wait(uint32_t phase, uint32_t pred=true) const { return ClusterBarrier::test_wait(&this->barrier_, phase, pred); } CUTLASS_DEVICE - uint32_t try_wait(uint32_t phase) const { + bool try_wait(uint32_t phase) const { return ClusterBarrier::try_wait(&this->barrier_, phase); } @@ -260,8 +260,8 @@ struct ClusterBarrier { ".reg .pred P1; \n\t" "LAB_WAIT: \n\t" "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" "DONE: \n\t" "}" : @@ -273,7 +273,7 @@ struct ClusterBarrier { } CUTLASS_DEVICE - static uint32_t test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { + static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); uint32_t waitComplete; @@ -289,7 +289,7 @@ struct ClusterBarrier { : "=r"(waitComplete) : "r"(smem_addr), "r"(phase), "r"(pred)); - return waitComplete; + return static_cast(waitComplete); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -297,7 +297,7 @@ struct ClusterBarrier { } CUTLASS_DEVICE - static uint32_t try_wait(ValueType const* smem_ptr, uint32_t phase) { + static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); uint32_t waitComplete; @@ -311,7 +311,7 @@ struct ClusterBarrier { : "=r"(waitComplete) : "r"(smem_addr), "r"(phase)); - return waitComplete; + return static_cast(waitComplete); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif @@ -323,16 +323,17 @@ struct ClusterBarrier { static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - ".reg .b32 remAddr32;\n\t" - "setp.eq.u32 p, %2, 1;\n\t" - "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" - "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" - "}" - : - : "r"(smem_addr), "r"(cta_id), "r"(pred)); + if (pred) { + asm volatile( + "{\n\t" + ".reg .b32 remAddr32;\n\t" + "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id)); + } + #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 0a3b33a556..6cced190e8 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -201,252 +201,6 @@ struct Mma< } }; -//////////////////////////////////////////////////////////////////////////////// -// -// Integer matrix multiply .8816 (8b) -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - //////////////////////////////////////////////////////////////////////////////// // // Integer matrix multiply (8b) with SATURATE @@ -693,252 +447,6 @@ struct Mma< } }; -//////////////////////////////////////////////////////////////////////////////// -// -// Integer matrix multiply (4b) -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - int4b_t, - layout::RowMajor, - int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - uint4b_t, - layout::RowMajor, - int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - int4b_t, - layout::RowMajor, - uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 32>, - 32, - uint4b_t, - layout::RowMajor, - uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 32>; - - using ElementA = uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - //////////////////////////////////////////////////////////////////////////////// // // Integer matrix multiply (4b) - SATURATE diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index 82152ecbd5..f990c1ac27 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -537,756 +537,14 @@ struct Mma< //////////////////////////////////////////////////////////////////////////////// // -// Matrix Multiply 16816 - S8 input, S32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16832 - S8 input, S32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE +// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16,8,16>, 32, int8_t, layout::RowMajor, @@ -1296,21 +554,21 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16,8,16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1324,18 +582,18 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - uint32_t const * A = reinterpret_cast(&a); - uint32_t const * B = reinterpret_cast(&b); + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); #else assert(0); @@ -1346,7 +604,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16,8,16>, 32, uint8_t, layout::RowMajor, @@ -1356,15 +614,15 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16,8,16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -1385,17 +643,17 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); + uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); #else assert(0); @@ -1406,7 +664,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16,8,16>, 32, int8_t, layout::RowMajor, @@ -1416,21 +674,21 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16,8,16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1445,18 +703,18 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); + uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + #else assert(0); #endif @@ -1466,7 +724,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16,8,16>, 32, uint8_t, layout::RowMajor, @@ -1476,15 +734,15 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16,8,16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -1505,17 +763,17 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); + uint32_t const &B = reinterpret_cast(b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); #else assert(0); @@ -1525,38 +783,38 @@ struct Mma< //////////////////////////////////////////////////////////////////////////////// // -// Matrix Multiply 16864 - S4 input, S32 accumulation +// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<16, 8, 64>, + gemm::GemmShape<16,8,32>, 32, - cutlass::int4b_t, + int8_t, layout::RowMajor, - cutlass::int4b_t, + int8_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16, 8, 64>; + using Shape = gemm::GemmShape<16,8,32>; - using ElementA = cutlass::int4b_t; + using ElementA = int8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = cutlass::int4b_t; + using ElementB = int8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1570,57 +828,53 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<16, 8, 64>, + gemm::GemmShape<16,8,32>, 32, - cutlass::uint4b_t, + uint8_t, layout::RowMajor, - cutlass::int4b_t, + int8_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16, 8, 64>; + using Shape = gemm::GemmShape<16,8,32>; - using ElementA = cutlass::uint4b_t; + using ElementA = uint8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = cutlass::int4b_t; + using ElementB = int8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1641,50 +895,46 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<16, 8, 64>, + gemm::GemmShape<16,8,32>, 32, - cutlass::int4b_t, + int8_t, layout::RowMajor, - cutlass::uint4b_t, + uint8_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16, 8, 64>; + using Shape = gemm::GemmShape<16,8,32>; - using ElementA = cutlass::int4b_t; + using ElementA = int8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = cutlass::uint4b_t; + using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1705,50 +955,46 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<16, 8, 64>, + gemm::GemmShape<16,8,32>, 32, - cutlass::uint4b_t, + uint8_t, layout::RowMajor, - cutlass::uint4b_t, + uint8_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16, 8, 64>; + using Shape = gemm::GemmShape<16,8,32>; - using ElementA = cutlass::uint4b_t; + using ElementA = uint8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = cutlass::uint4b_t; + using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1769,23 +1015,18 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); assert(0); #endif } }; - //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE @@ -1819,7 +1060,7 @@ struct Mma< using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1947,7 +1188,7 @@ struct Mma< using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -2261,5 +1502,4 @@ struct Mma< } // namespace arch } // namespace cutlass - ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sparse_sm80.h b/include/cutlass/arch/mma_sparse_sm80.h index 0bf842e193..7041d04dd4 100644 --- a/include/cutlass/arch/mma_sparse_sm80.h +++ b/include/cutlass/arch/mma_sparse_sm80.h @@ -54,6 +54,7 @@ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) #define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED #endif + #endif ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,6 +122,27 @@ struct SparseMma< uint32_t const *C = reinterpret_cast(&c); uint32_t *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else { + assert(0); + } +#else if (id2 == 0) { asm volatile( "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " @@ -140,6 +162,8 @@ struct SparseMma< else { assert(0); } +#endif + #else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); @@ -204,6 +228,29 @@ struct SparseMma< float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else { + assert(0); + } +#else if (id2 == 0) { asm volatile( "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " @@ -226,8 +273,9 @@ struct SparseMma< assert(0); } -#else +#endif +#else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); @@ -284,23 +332,43 @@ struct SparseMma, 32, bfloat16_t, layout::RowMajor, float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else { - assert(0); + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); } +#endif #else @@ -360,23 +428,43 @@ struct SparseMma, 32, tfloat32_t, layout::RowMajor, float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else { - assert(0); + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); } +#endif #else @@ -391,7 +479,7 @@ struct SparseMma, 32, tfloat32_t, layout::RowMajor, //////////////////////////////////////////////////////////////////////////////// // -// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// @@ -406,7 +494,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -425,7 +513,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -453,18 +541,31 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif +#else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); @@ -485,7 +586,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -504,7 +605,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -532,15 +633,29 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else @@ -564,7 +679,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -583,7 +698,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -611,18 +726,31 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif +#else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); @@ -643,7 +771,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -662,7 +790,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -690,18 +818,31 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif +#else CUTLASS_UNUSED(a); CUTLASS_UNUSED(b); CUTLASS_UNUSED(c); @@ -713,33 +854,33 @@ struct SparseMma< //////////////////////////////////////////////////////////////////////////////// // -// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - int8_t, + cutlass::int4b_t, layout::RowMajor, - int8_t, + cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = int8_t; + using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = int8_t; + using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -747,7 +888,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -775,15 +916,29 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else @@ -796,29 +951,29 @@ struct SparseMma< } }; -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - int8_t, + cutlass::int4b_t, layout::RowMajor, - uint8_t, + cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = int8_t; + using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = uint8_t; + using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -826,7 +981,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -854,15 +1009,29 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else @@ -875,29 +1044,29 @@ struct SparseMma< } }; -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - uint8_t, + cutlass::uint4b_t, layout::RowMajor, - int8_t, + cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = uint8_t; + using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = int8_t; + using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -905,7 +1074,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -933,15 +1102,29 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else @@ -954,29 +1137,29 @@ struct SparseMma< } }; -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - uint8_t, + cutlass::uint4b_t, layout::RowMajor, - uint8_t, + cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = uint8_t; + using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = uint8_t; + using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -984,7 +1167,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -1012,659 +1195,29 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); #else diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index f7b12a706b..c1ffbeeb57 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -37,12 +37,9 @@ #include "cutlass/cutlass.h" -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) - #if (defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#if (defined(__CUDA_ARCH__) &&\ + (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) #define CUDA_CTA_RECONFIG_ACTIVATED 1 - #endif -#else - #define CUDA_CTA_RECONFIG_ACTIVATED 0 #endif namespace cutlass { @@ -55,6 +52,7 @@ void warpgroup_reg_alloc(){ asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); #endif } + template CUTLASS_DEVICE void warpgroup_reg_dealloc(){ diff --git a/include/cutlass/array.h b/include/cutlass/array.h index dcaa1093c8..499d45c724 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -785,6 +785,24 @@ struct reciprocal_approximate> { } }; +template +struct reciprocal_approximate_ftz> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate_ftz scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + template struct maximum, false> { @@ -979,6 +997,10 @@ struct minimum, true> { } }; +template +struct minimum_with_nan_propagation> : minimum, true> +{}; + template struct negate> { diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index c2e6cb0de6..50506c73be 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -100,6 +100,17 @@ struct alignas(2) bfloat16_t { /// Default constructor bfloat16_t() = default; + /// Reinterpret cast from CUDA's __nv_bfloat16 type + CUTLASS_HOST_DEVICE + explicit bfloat16_t(__nv_bfloat16 const & x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + __nv_bfloat16_raw raw(x); + std::memcpy(&storage, &raw.x, sizeof(storage)); + #endif + } + /// Floating-point conversion - round toward nearest CUTLASS_HOST_DEVICE explicit bfloat16_t(float x) { diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 36f603e31a..3d140eaa84 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -133,7 +133,8 @@ struct ClusterLauncher { size_t const smem_size, cudaStream_t cuda_stream, void const* kernel, - void** kernel_params) { + void** kernel_params, + bool launch_with_pdl = false) { #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) if (check_cluster_dims(grid_dims, cluster_dims) != Status::kSuccess) { CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); @@ -152,14 +153,19 @@ struct ClusterLauncher { launch_config.dynamicSmemBytes = smem_size; launch_config.stream = cuda_stream; - cudaLaunchAttribute launch_attribute[1]; + cudaLaunchAttribute launch_attribute[2]; + launch_attribute[0].id = cudaLaunchAttributeClusterDimension; launch_attribute[0].val.clusterDim.x = cluster_dims.x; launch_attribute[0].val.clusterDim.y = cluster_dims.y; launch_attribute[0].val.clusterDim.z = cluster_dims.z; + launch_attribute[1].id = cudaLaunchAttributeProgrammaticStreamSerialization; + launch_attribute[1].val.programmaticStreamSerializationAllowed = 1; + + launch_config.numAttrs = launch_with_pdl ? 2 : 1; + launch_config.attrs = launch_attribute; - launch_config.numAttrs = 1; CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " "(" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z << "), " diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 1f92b667e6..6d0bf31df6 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -43,6 +43,7 @@ #include "cutlass/cutlass.h" #include "cutlass/functional.h" +#include "cutlass/platform/platform.h" #include "cutlass/real.h" #include "cutlass/numeric_types.h" @@ -117,6 +118,18 @@ double const &imag(cuDoubleComplex const &z) { return z.y; } /// Returns the imaginary part of the complex number CUTLASS_HOST_DEVICE double &imag(cuDoubleComplex &z) { return z.y; } + +// Returns the conjugate of the complex number +CUTLASS_HOST_DEVICE cuFloatComplex +conj(cuFloatComplex const& z) { + return make_cuFloatComplex(z.x, -z.y); +} + +// Returns the conjugate of the complex number +CUTLASS_HOST_DEVICE cuDoubleComplex +conj(cuDoubleComplex const& z) { + return make_cuDoubleComplex(z.x, -z.y); +} #endif /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -315,60 +328,90 @@ class complex #endif }; +// Complex conjugate +template +CUTLASS_HOST_DEVICE complex conj(complex const& z) { + return {z.real(), -z.imag()}; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // // Accessors for complex template // -/// Returns the real part of the complex number -template -CUTLASS_HOST_DEVICE T const &real(complex const &z) { - return z.real(); -} +// Nonmember real and imag need to work for non-complex numbers too. +// That means cutlass::complex, std::complex, cuda::std::complex, and +// any user-defined complex number type that looks like std::complex. +// It's reasonable to assume that a "complex number type" has +// zero-argument real() and imag() member functions returning +// non-void. While cuFloatComplex and cuDoubleComplex lack those +// member functions, one-argument nonmember real and imag overloads +// for those types are defined above. -/// Returns the real part of the complex number -template -CUTLASS_HOST_DEVICE T &real(complex &z) { - return z.real(); -} +namespace detail { -/// Returns the imaginary part of the complex number -template -CUTLASS_HOST_DEVICE T const &imag(complex const &z) { - return z.imag(); -} +template +struct has_zero_argument_real_member_function : + cutlass::platform::false_type +{}; -/// Returns the imaginary part of the complex number template -CUTLASS_HOST_DEVICE T &imag(complex &z) { - return z.imag(); -} +struct has_zero_argument_real_member_function().real()) + > + > +> : cutlass::platform::true_type +{}; -/// Returns the real part of the real number template -CUTLASS_HOST_DEVICE T const &real(T const &r) { - return r; -} +constexpr bool has_zero_argument_real_member_function_v = + has_zero_argument_real_member_function::value; -/// Returns the real part of the real number -template -CUTLASS_HOST_DEVICE T &real(T &r) { - return r; -} +template +struct has_zero_argument_imag_member_function : + cutlass::platform::false_type +{}; -/// Returns the imaginary part of the real number template -CUTLASS_HOST_DEVICE T const &imag(T const &r) { - return T(); -} +struct has_zero_argument_imag_member_function().imag()) + > + > +> : cutlass::platform::true_type +{}; -/// Returns the imaginary part of the complex number template -CUTLASS_HOST_DEVICE T &imag(T &r) { - return T(); -} +constexpr bool has_zero_argument_imag_member_function_v = + has_zero_argument_imag_member_function::value; +} // namespace detail + +template +CUTLASS_HOST_DEVICE auto real(T z) { + if constexpr (detail::has_zero_argument_real_member_function_v) { + return z.real(); + } else { + return z; + } +} + +template +CUTLASS_HOST_DEVICE auto imag(T z) { + if constexpr (detail::has_zero_argument_imag_member_function_v) { + return z.imag(); + } else { + // Imaginary part of a non-complex input has the same type as the + // input, and its value is zero. CUTLASS assumes in this case + // that value-initializing T is well-formed and results in zero. + return T{}; + } +} + // // Output operators // @@ -395,10 +438,36 @@ std::ostream &operator<<(std::ostream &out, complex const &z) { // Non-member functions defined for complex numbers // -/// Returns the magnitude of the complex number +// abs returns the magnitude of the complex number. + +CUTLASS_HOST_DEVICE float abs(complex const &z) { + return ::hypot(z.real(), z.imag()); +} + +CUTLASS_HOST_DEVICE double abs(complex const &z) { + return ::hypot(z.real(), z.imag()); +} + +// In theory, it would make sense to add a complex +// specialization of abs here, since hypot works for long double too. +// In practice, long double doesn't have a portable number of bits or +// behavior, so users who care about higher-precision floating-point +// computation should probably insist on an actual FP128 type. + template CUTLASS_HOST_DEVICE T abs(complex const &z) { - return sqrt(norm(z)); + // cutlass::complex permits all kinds of T, including types that + // don't have NaN. For a generic floating-point type with Inf + // and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it + // would handle issues like avoiding unwarranted overflow if + // z.real() or z.imag() is slightly bigger than the square root of + // the max finite number. That could be a future improvement; for + // now, the code just uses the naive algorithm. + // + // Use the "swap two-step" idiom so that argument-dependent lookup + // can find any CUTLASS-specific overloads. + using cutlass::sqrt; + return sqrt(z.real() * z.real() + z.imag() * z.imag()); } /// Returns the magnitude of the complex number @@ -438,67 +507,66 @@ CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) static_cast(imag(z)) * static_cast(imag(z)); } -CUTLASS_HOST_DEVICE float conj(float const &z) { - return z; -} - -CUTLASS_HOST_DEVICE double conj(double const &z) { - return z; -} - -CUTLASS_HOST_DEVICE half_t conj(half_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE int32_t conj(int32_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE uint32_t conj(uint32_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE int64_t conj(int64_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE uint64_t conj(uint64_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE int4b_t conj(int4b_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE uint4b_t conj(uint4b_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE bfloat16_t conj(bfloat16_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE uint1b_t conj(uint1b_t const& z) { - return z; -} - -CUTLASS_HOST_DEVICE tfloat32_t conj(tfloat32_t const& z) { - return z; +namespace detail { + +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) { + return conj(z); } -CUTLASS_HOST_DEVICE float_e4m3_t conj(float_e4m3_t const& z) { +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) { return z; } -CUTLASS_HOST_DEVICE float_e5m2_t conj(float_e5m2_t const& z) { - return z; +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z) { + constexpr bool use_unqualified_conj = + ! cutlass::platform::is_arithmetic_v && + ! detail::has_cutlass_conj_v && + detail::has_unqualified_conj_v; + return conj_impl(z, cutlass::platform::bool_constant{}); } + +} // namespace detail - -/// Returns the complex conjugate -template -CUTLASS_HOST_DEVICE complex conj(complex const &z) { - return complex(real(z), -imag(z)); +// Return the complex conjugate of the input. +// +// This MUST be a function and not a function object, because it may +// be common practice for downstream types to define specifically +// cutlass::conj overloads, instead of overloads in their namespace. +// +// As a result of this being a function and not a function object, +// CUTLASS code needs to declare "using cutlass::conj;" in scope and +// then call this function unqualified, just like std::swap. +// +// If an overload already exists for cutlass::conj(T), that overload +// will be called instead of this one. Otherwise: +// +// 1. for arithmetic types, return z; +// +// 2. for types where (namespace-unqualified) conj(z) is well formed +// and cutlass::conj(z) is NOT well formed, return conj(z); and, +// +// 3. for everything else, return z. +// +// Regarding (1), the C++ Standard Library makes std::conj always +// return std::complex, even for (noncomplex) arithmetic types. +// cutlass::conj(T t) needs to return type T. This follows the +// convention of linear algebra software like the BLAS, where +// "conjugate transpose" means the same thing as "transpose" for a +// matrix of noncomplex numbers. +// +// Case (2) covers std::complex, cuda::std::complex, and non-Standard +// (including user-defined) complex number types (for which "conj(z)" +// is findable via argument-dependent lookup, but does not live in the +// cutlass namespace). It excludes cutlass::conj(z) in order to +// prevent infinite recursion. +// +// Case (3) covers non-Standard non-complex number types. +template +CUTLASS_HOST_DEVICE T conj(T const& z) { + return detail::conj_impl(z); } /// Projects the complex number z onto the Riemann sphere @@ -699,10 +767,30 @@ template struct conjugate> { CUTLASS_HOST_DEVICE complex operator()(complex const &a) const { - return conj(a); + // Invoke the complex overload specifically, rather than + // wasting the compiler's effort on overload resolution. + return cutlass::conj(a); + } +}; + +#if ! defined(__CUDACC_RTC__) +template <> +struct conjugate { + CUTLASS_HOST_DEVICE + cuFloatComplex operator()(cuFloatComplex const& z) const { + return make_cuFloatComplex(z.x, -z.y); } }; +template <> +struct conjugate { + CUTLASS_HOST_DEVICE + cuDoubleComplex operator()(cuDoubleComplex const& z) const { + return make_cuDoubleComplex(z.x, -z.y); + } +}; +#endif + /// Computes the square of a difference with optional conversion template struct magnitude_squared_difference, Output> { diff --git a/include/cutlass/conv/collective/detail.hpp b/include/cutlass/conv/collective/detail.hpp index 0f192209de..ac272c8e20 100644 --- a/include/cutlass/conv/collective/detail.hpp +++ b/include/cutlass/conv/collective/detail.hpp @@ -246,6 +246,9 @@ compute_lower_srt(ConvProblemShape const& problem_ return lower; } +template struct is_im2col_load { static constexpr bool value = false; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv::collective::detail diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 124c781b08..13bb7c515c 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -131,6 +131,9 @@ struct CollectiveConv< && (cute::is_same_v || cute::is_same_v)), "GmemTiledCopyB - invalid SM90 TMA copy atom specified."); + static constexpr bool is_im2col_A = detail::is_im2col_load::value; + static constexpr bool is_im2col_B = detail::is_im2col_load::value; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. static constexpr bool ConvertF32toTF32A = cute::is_same_v; @@ -169,12 +172,11 @@ struct CollectiveConv< // Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled for // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor // B since operand A, B is swapped. - // Get tma_load_a instantce. template static constexpr auto get_tma_load_a_instance(TensorA const& tensor_a, typename Arguments::ProblemShape const& problem_shape) { - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + if constexpr (is_im2col_A) { // compute the upper and lower corners based on the conv padding auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); @@ -203,7 +205,7 @@ struct CollectiveConv< shape(stride_srt)); } // TMA tiled mode for tensor A in wgrad kernel. - else if constexpr (ConvOp == conv::Operator::kWgrad) { + else { return make_tma_copy( GmemTiledCopyA{}, tensor_a, @@ -217,16 +219,8 @@ struct CollectiveConv< template static constexpr auto get_tma_load_b_instance(TensorB const& tensor_b, typename Arguments::ProblemShape const& problem_shape) { - if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { - return make_tma_copy( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,_0{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); - } // TMA im2col mode for tensor B in wgrad kernel. - else if constexpr (ConvOp == conv::Operator::kWgrad) { + if constexpr (is_im2col_B) { // compute the upper and lower corners based on the conv padding auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); @@ -246,6 +240,26 @@ struct CollectiveConv< shape(lower_srt), cute::reverse(shape(problem_shape.dilation))); } + else { + return make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); + } + } + + static constexpr auto + get_problem_shape_MNKL(typename Arguments::ProblemShape const& problem_shape) { + if constexpr (is_im2col_A || is_im2col_B) { + // transformation + im2col linearization + return problem_shape.get_linearized_problem_shape_MNKL(); + } + else { + // transformation + return problem_shape.get_transformed_problem_shape_MNKL(); + } } public: @@ -253,9 +267,7 @@ struct CollectiveConv< // Device side kernel params struct Params { using _Submode = decltype(take<0,NumTensorDimensions-1>(typename Arguments::ProblemShape::TensorExtent{})); - using ProblemShape = cute::conditional_t, - Shape<_Submode, int, _Submode>>; + using ProblemShape = decltype(get_problem_shape_MNKL(typename Arguments::ProblemShape{})); // Assumption: StrideA is congruent with Problem_MK // Select TMA load type according to convolution operator. @@ -283,6 +295,7 @@ struct CollectiveConv< TMA_A tma_load_a; TMA_B tma_load_b; ProblemShape problem_shape; + uint32_t tma_transaction_bytes = TmaTransactionBytes; }; // @@ -314,17 +327,18 @@ struct CollectiveConv< auto tma_load_a = get_tma_load_a_instance(tensor_a, args.problem_shape); auto tma_load_b = get_tma_load_b_instance(tensor_b, args.problem_shape); - auto problem_shape_mnk = args.problem_shape.get_transformed_problem_shape_MNK(); + auto problem_shape_mnkl = get_problem_shape_MNKL(args.problem_shape); return { tma_load_a, tma_load_b, - problem_shape_mnk + problem_shape_mnkl, + TmaTransactionBytes }; } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, Arguments const& args) { @@ -389,13 +403,12 @@ struct CollectiveConv< TensorA const& gA, TMA_LOAD_A& tma_load_a, TensorB const& gB, TMA_LOAD_B& tma_load_b, KTileIterator k_tile_iter, int k_tile_count, - int thad_idx, + int thread_idx, + uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -403,7 +416,8 @@ struct CollectiveConv< // Prepare the TMA loads for A and B // - dim3 cluster_local_block_id = cute::block_id_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); @@ -462,12 +476,10 @@ struct CollectiveConv< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) { - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits - if (warp_idx_in_warp_group == 0 and lane_predicate) { + if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used diff --git a/include/cutlass/conv/convnd_problem_shape.hpp b/include/cutlass/conv/convnd_problem_shape.hpp index a32389f61b..0172120538 100644 --- a/include/cutlass/conv/convnd_problem_shape.hpp +++ b/include/cutlass/conv/convnd_problem_shape.hpp @@ -352,15 +352,16 @@ struct ConvProblemShape { } } - // Get problem shape MNK according to following table: - // | | Fprop | Dgrad | Wgrad | - // | ---- | --------- | -------- | -------- | - // | Shape_M | (Q,P,Z,N) | (W,H,D,N) | (K) | - // | Shape_N | (K) | (C) | (C,S,R,T) | - // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | + // Get problem shape MNKL according to following table: + // | | Fprop | Dgrad | Wgrad | + // | ---- | --------- | -------- | -------- | + // | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | + // | Shape_N | (K) | (C) | (C,S,R,T) | + // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | + // | Shape_L | _1 | (V,U,O) | _1 | CUTLASS_HOST_DEVICE constexpr auto - get_transformed_problem_shape_MNK() const { + get_transformed_problem_shape_MNKL() const { using cute::insert; using cute::make_shape; using cute::reverse; @@ -370,32 +371,56 @@ struct ConvProblemShape { auto M_xformed = shape_C[0]; auto N_xformed = reverse(take<1, RankT>(shape_C)); auto K_xformed = reverse(take<0, RankT - 1>(shape_A)); + auto L_xformed = cute::Int<1>{}; - return make_shape(M_xformed, N_xformed, K_xformed); + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); } else if constexpr (ConvOp == conv::Operator::kFprop){ auto M_xformed = reverse(take<0, RankT - 1>(shape_C)); auto N_xformed = shape_C[RankT - 1]; auto K_xformed = reverse(take<1, RankT>(shape_B)); + auto L_xformed = cute::Int<1>{}; - return make_shape(M_xformed, N_xformed, K_xformed); + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); } else if constexpr (ConvOp == conv::Operator::kDgrad) { - auto M_xformed = reverse(take<0,RankT - 1>(shape_C)); + auto L_xformed = reverse(traversal_stride); // (V,U,O) + auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(shape_C)), L_xformed); auto N_xformed = shape_C[RankT - 1]; // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] auto K_xformed = insert<0>( (reverse(take<1,RankT - 1>(shape_B))), shape_B[0]); - return make_shape(M_xformed, N_xformed, K_xformed); + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); } } + // Assuming im2col linearization + // Get problem shape MNKL according to following table: + // | | Fprop | Dgrad | Wgrad | + // | ---- | --------- | -------- | -------- | + // | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | + // | Shape_N | (K) | (C) | (C,S,R,T) | + // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | + // | Shape_L | _1 | (V*U*O) | _1 | + CUTLASS_HOST_DEVICE + constexpr auto + get_linearized_problem_shape_MNKL() const { + auto [M, N, K, L] = get_transformed_problem_shape_MNKL(); + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + return cute::make_shape(cute::product(M), N, K, cute::product(L)); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return cute::make_shape(M, N, cute::product(K), L); + } + } // Get A extents. // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) - // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) CUTLASS_HOST_DEVICE constexpr auto get_shape_A() const { @@ -418,8 +443,8 @@ struct ConvProblemShape { // Get B extents. // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) - // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N)) // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T)) + // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N)) CUTLASS_HOST_DEVICE constexpr auto get_shape_B() const { @@ -447,6 +472,30 @@ struct ConvProblemShape { } } + // Get C extents. + // fprop: C extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + // dgrad: C extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) + // wgrad: C extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_C() const { + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kDgrad) { + return make_shape( + reverse(take<0, RankT - 1>(shape_C)), + shape_C[RankT - 1]); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_C[0], + reverse(take<1, RankT>(shape_C))); + } + } + // Static method that returns the canonical strides of tensors (layouts are right major and compact) CUTLASS_HOST_DEVICE static constexpr TensorStride @@ -529,7 +578,9 @@ struct ConvProblemShape { // calculate n,z,p,q,k. // a helper lambda to compute a single spatial extent of the nzpqk tensor auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { - return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; + auto tmp = act_ext + pad_total - ((filter_ext -1) * dilation + 1); + CUTLASS_ASSERT(tmp % tstride == 0); + return 1 + tmp / tstride; }; shape_xformed_act[0] = shape_act[0]; // Activation N extent diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 603c47e84b..9812937e2e 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -228,29 +228,18 @@ class ConvUniversalAdapter /// Initializes conv state from arguments. Status initialize( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); - size_t workspace_bytes = ConvKernel::get_workspace_size(args); - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - if (workspace_bytes) { - if (!workspace) { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - return Status::kErrorWorkspaceNull; - } - - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } + // Initialize the workspace + Status status = ConvKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; } // Initialize the Params structure @@ -297,7 +286,7 @@ class ConvUniversalAdapter /// Primary run() entry point API that is static allowing users to create and manage their own params. /// Supplied params struct must be construct by calling ConvKernel::to_underling_arguments() static Status - run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { CUTLASS_TRACE_HOST("ConvUniversal::run()"); dim3 const block = ConvKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); @@ -319,9 +308,13 @@ class ConvUniversalAdapter CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - launch_result = cuda_adapter->launch( - grid, cluster, block, smem_size, stream, kernel_params, 0 - ); + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + kernel_index); } else { return Status::kErrorInternal; @@ -379,11 +372,12 @@ class ConvUniversalAdapter Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0 ) { Status status = initialize(args, workspace, stream, cuda_adapter); if (Status::kSuccess == status) { - status = run(params_, stream, cuda_adapter); + status = run(params_, stream, cuda_adapter, kernel_index); } return status; } diff --git a/include/cutlass/conv/device/direct_convolution.h b/include/cutlass/conv/device/direct_convolution.h index 8c13e2ed64..84953d8036 100644 --- a/include/cutlass/conv/device/direct_convolution.h +++ b/include/cutlass/conv/device/direct_convolution.h @@ -197,7 +197,7 @@ class DirectConvolution { params_.ptr_C = args.ref_C.data(); params_.ptr_D = args.ref_D.data(); params_.output_op = args.output_op; - params_.ptr_reordered_B = args.ref_reordered_B.data();; + params_.ptr_reordered_B = args.ref_reordered_B.data(); params_.semaphore = static_cast(workspace); return Status::kSuccess; diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index 6c7876dc96..039f4539c4 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -31,6 +31,7 @@ #pragma once #include "cutlass/conv/convolution.h" +#include "cutlass/epilogue/thread/activation.h" #include "cutlass/arch/arch.h" #include "cute/layout.hpp" @@ -38,6 +39,8 @@ ////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + namespace cutlass::conv { ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d.h b/include/cutlass/conv/kernel/default_conv2d.h index f629bbb2d0..79bedb2c84 100644 --- a/include/cutlass/conv/kernel/default_conv2d.h +++ b/include/cutlass/conv/kernel/default_conv2d.h @@ -114,7 +114,10 @@ template < typename ElementTensor, typename ElementVector, typename OutputOp, - int ElementsPerAccess + int ElementsPerAccess, + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, + int Rank = 4 > struct DefaultConvEpilogueWithBroadcastSimt { using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimt< @@ -124,7 +127,11 @@ struct DefaultConvEpilogueWithBroadcastSimt { ElementTensor, ElementVector, OutputOp, - ElementsPerAccess + ElementsPerAccess, + false, + PermuteDLayout, + StrideSupport, + Rank >::Epilogue; }; diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h index 38e4de5c26..0fc291e605 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -197,7 +197,10 @@ struct DefaultConv3dFpropWithBroadcast < typename EpilogueOutputOp::ElementT, typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_deconv2d.h b/include/cutlass/conv/kernel/default_deconv2d.h index ace21b92fa..4db152cd7a 100644 --- a/include/cutlass/conv/kernel/default_deconv2d.h +++ b/include/cutlass/conv/kernel/default_deconv2d.h @@ -181,7 +181,11 @@ struct DefaultDeconv2d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 >::Epilogue; // Define the kernel @@ -405,7 +409,11 @@ struct DefaultDeconv2d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 >::Epilogue; // Define the kernel @@ -627,7 +635,11 @@ struct DefaultDeconv2d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 >::Epilogue; // Define the kernel @@ -852,7 +864,11 @@ struct DefaultDeconv2d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_deconv3d.h b/include/cutlass/conv/kernel/default_deconv3d.h index e9eb4cc5b0..70800c7af7 100644 --- a/include/cutlass/conv/kernel/default_deconv3d.h +++ b/include/cutlass/conv/kernel/default_deconv3d.h @@ -170,7 +170,11 @@ struct DefaultDeconv3d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 >::Epilogue; // Define the kernel @@ -282,7 +286,11 @@ struct DefaultDeconv3d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 >::Epilogue; // Define the kernel @@ -389,7 +397,11 @@ struct DefaultDeconv3d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 >::Epilogue; // Define the kernel @@ -501,7 +513,11 @@ struct DefaultDeconv3d < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h index 5c50c766d9..affe7a06f4 100644 --- a/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h @@ -196,7 +196,10 @@ struct DefaultDeconv3dWithBroadcast < typename EpilogueOutputOp::ElementT, typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport::kStrided, + 5 >::Epilogue; // Define the kernel @@ -273,7 +276,7 @@ struct DefaultDeconv3dWithBroadcast < >::Kernel; // Define epilogue - using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< ArchTag, typename ImplicitGemmBase::Epilogue::Shape, typename ImplicitGemmBase::Epilogue::WarpMmaOperator, @@ -281,7 +284,10 @@ struct DefaultDeconv3dWithBroadcast < typename EpilogueOutputOp::ElementT, typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, - ImplicitGemmBase::Epilogue::kElementsPerAccess + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport::kStrided, + 5 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index c4de265e8d..b1e0b477a8 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -233,9 +233,9 @@ struct ImplicitGemmConvolution { ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_D(args.ref_D.data()), output_op(args.output_op), semaphore(semaphore), diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index c768a2966e..1f27e0686d 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -257,9 +257,9 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_D(args.ref_D.data()), output_op(args.output_op), semaphore(semaphore), diff --git a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp index 43c6d5959b..95780bf84e 100644 --- a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -51,12 +51,12 @@ namespace cutlass::conv::kernel { template < class CollectiveMainloop_, class CollectiveEpilogue_, - class TileSchedulerTag + class TileSchedulerTag_ > class ConvUniversal< CollectiveMainloop_, CollectiveEpilogue_, - TileSchedulerTag, + TileSchedulerTag_, cute::enable_if_t>> { @@ -90,6 +90,7 @@ class ConvUniversal< using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; + using TileSchedulerTag = TileSchedulerTag_; static_assert(cute::is_void_v, "TMA warp-specialized kernel does not support specializing the tile scheduler."); using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< @@ -144,7 +145,7 @@ class ConvUniversal< to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace); - auto problem_shape_MNKL = append<4>(mainloop_params.problem_shape, Int<1>{}); + auto problem_shape_MNKL = args.mainloop.problem_shape.get_transformed_problem_shape_MNKL(); return { mainloop_params, @@ -157,7 +158,7 @@ class ConvUniversal< can_implement(Arguments const& args) { bool implementable = true; implementable &= CollectiveMainloop::can_implement(args.mainloop.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.mainloop.problem_shape.get_transformed_problem_shape_MNK(), args.epilogue); + implementable &= CollectiveEpilogue::can_implement(args.mainloop.problem_shape.get_transformed_problem_shape_MNKL(), args.epilogue); return implementable; } @@ -166,19 +167,17 @@ class ConvUniversal< return 0; } + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { - // The CONV mainloop params problem shape will be the cute::Shape<> rank-3 MNK tuple we want for grid planning - // Although conv problems do not have an L mode, we add it here to comply with the scheduler API - auto linear_problem_shape_MNKL = make_shape( - size<0>(params.mainloop.problem_shape), // M mode is linearized. - shape<1>(params.mainloop.problem_shape), - shape<2>(params.mainloop.problem_shape), - Int<1>{}); - return cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( - linear_problem_shape_MNKL, TileShape{}, ClusterShape{}); + params.mainloop.problem_shape, TileShape{}, ClusterShape{}); } static dim3 @@ -205,14 +204,25 @@ class ConvUniversal< Consumer = 1, }; + enum class ProducerWarpRole { + MainloopEpilogue = 0, + Warp1 = 1, + Warp2 = 2, + Warp3 = 3 + }; + // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { @@ -223,7 +233,7 @@ class ConvUniversal< // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer) { @@ -231,22 +241,24 @@ class ConvUniversal< } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; } epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -266,16 +278,26 @@ class ConvUniversal< PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + // Separate out problem shape for convenience - auto M = get<0>(params.mainloop.problem_shape); - auto N = get<1>(params.mainloop.problem_shape); - auto K = get<2>(params.mainloop.problem_shape); - // output strides are coalesced so we linearize the output shape to match the shape/stride profiles - auto linear_problem_shape_MNKL = make_shape(size(M), N, K, Int<1>{}); + auto problem_shape_MNKL = append<4>(params.mainloop.problem_shape, _1{}); + auto [M, N, K, L] = problem_shape_MNKL; // TMA requires special handling of strides to deal with coord codomain mapping // Represent the full tensors -- get these from TMA - Tensor mA_mk = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M, size(K))); + Tensor mA_mk = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M, K)); Tensor mB_nk = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N, K)); // Get the appropriate blocks for this thread block -- potential for thread block locality @@ -288,7 +310,8 @@ class ConvUniversal< // Compute m_coord, n_coord, and l_coord with their post-tiled shapes auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mk)); - auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nk)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nk), compact_col_major(shape<2>(gB_nk))); + // The output shape M is linearized so the output coord M here should also be linearized. auto output_tile_coord = make_coord(int(blockIdx.x), n_coord, _, Int<0>{}); @@ -300,51 +323,43 @@ class ConvUniversal< auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); auto k_tile_count = size<2>(gA); - auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(cta_tile_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(cta_tile_shape); - - // Make sure pipeline init is visible to all producers and consumer CTAs in cluster - if constexpr (size(ClusterShape{}) > 1) { - cute::cluster_arrive_relaxed(); - cute::cluster_wait(); - } - else { - __syncthreads(); - } - // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + // Wait for all thread blocks in Cluster + cluster_wait_fn(); + if (warp_group_role == WarpGroupRole::Producer) { - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - k_tile_iter, k_tile_count, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting mainloop pipeline state for the pipeline drain - mainloop_pipe_producer_state.advance(k_tile_count); - // Make sure mainloop consumer has been waited upon before issuing epilogue load - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_producer_load_needed()) { - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - linear_problem_shape_MNKL, - cta_tile_shape, - output_tile_coord, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue + if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) { + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop ); - // Update starting load pipeline state for the pipeline drain - epi_load_pipe_producer_state.advance(c_tile_count); - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + cta_tile_shape, + output_tile_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } } } else if (warp_group_role == WarpGroupRole::Consumer) { @@ -368,12 +383,13 @@ class ConvUniversal< ); // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = collective_epilogue.store( epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state, - linear_problem_shape_MNKL, + problem_shape_MNKL, cta_tile_shape, output_tile_coord, accumulators, @@ -381,6 +397,13 @@ class ConvUniversal< warp_group_thread_idx, shared_storage.tensors.epilogue ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); } } }; diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h index 890b45b365..943ab88cfc 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -251,7 +251,7 @@ class Conv3dDgradFilterTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h index aa8c0cc18e..2d5837dd3d 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -272,7 +272,7 @@ class Conv3dDgradFilterTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h index 21f68d97d5..30b7f2fcf6 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -325,7 +325,7 @@ class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < static Status can_implement(ConvProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h index 79d6302e8d..5a53c8cbd5 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -466,7 +466,7 @@ class Conv3dDgradOutputGradientTileAccessIteratorOptimized { } // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorNotSupported; } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h index cd853503d5..f0f9a86a34 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -272,7 +272,7 @@ class Conv3dFpropActivationTileAccessIteratorAnalytic { static Status can_implement(ConvProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h index 40860d271f..78b270eb9a 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -455,7 +455,7 @@ class Conv3dFpropActivationTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index 85dd37ffdb..9f04adc40b 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -238,11 +238,10 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { /// Determines whether the Implicit GEMM can execute the given problem. CUTLASS_HOST_DEVICE static Status can_implement(ConvProblemSize const &problem_size) { - auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); auto output_channels = (IsDeconv ? problem_size.C : problem_size.K); // check alignment constraint on iterator's contiguous dimension - if (input_channels % (128/sizeof_bits::value)) { + if (input_channels % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index ac49cf0781..efe34497f5 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -260,14 +260,12 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ /// Determines whether the Implicit GEMM can execute the given problem. CUTLASS_HOST_DEVICE static Status can_implement(Conv3dProblemSize const &problem_size) { - auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); // check alignment constraint on iterator's contiguous dimension - if (input_channels % (128/sizeof_bits::value)) { + if (input_channels % AccessType::kElements) { return Status::kErrorInvalidProblem; } - return Status::kSuccess; } }; diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index bd08293ff6..cc8faea701 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -270,7 +270,7 @@ class Conv3dWgradActivationTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 95ac69404b..2b10d207fa 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -250,7 +250,7 @@ class Conv3dWgradActivationTileAccessIteratorOptimized { fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_]; - int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];; + int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_]; int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); @@ -300,7 +300,7 @@ class Conv3dWgradActivationTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index 67e49f7ab6..be9d4fb7ac 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -248,7 +248,7 @@ class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index 487009f54e..0ef145f19d 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -291,7 +291,7 @@ class Conv3dWgradOutputGradientTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index a62ebee7a6..28f5ae0e8d 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -44,10 +44,29 @@ #include #endif -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) -# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +///////////////////////////////////////////////////////////////////////////////////////////////// + +// NVRTC doesn't need definitions for these host classes + +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) \ + && !defined(__CUDACC_RTC__) +#define CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)) +#define CUDA_HOST_ADAPTER_TENSORMAP_ENABLED #endif +// Include for CUDA Driver API calls if any of these capabilities are enabled. +#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || \ + defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +#include + +#endif // defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || + // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -63,6 +82,45 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter +/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute +/// is not introduced. +struct CudaHostLaunchAttributes { + +#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) + + /// Reasonable maximum launch attributes that are commonly applied + static constexpr int32_t kMaximumAttributeCount = 5; + + /// Launch attributes + CUlaunchAttribute launch_attributes[kMaximumAttributeCount]; + int32_t attribute_count = 0; + + CUTLASS_HOST_DEVICE + CudaHostLaunchAttributes(CUlaunchAttribute *launch_attributes_ = nullptr, + int32_t attribute_count_ = 0) { + CUTLASS_ASSERT(attribute_count_ >= 0 && attribute_count_ < kMaximumAttributeCount); + for (int32_t i = 0; i < attribute_count_ && i < kMaximumAttributeCount; ++i) { + launch_attributes[i] = launch_attributes_[i]; + } + attribute_count = attribute_count_; + } + + CUTLASS_HOST_DEVICE + CUlaunchAttribute const* data() const { + return launch_attributes; + } + + CUTLASS_HOST_DEVICE + size_t size() const { + return attribute_count; + } + +#endif // (CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) + +}; + + /// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and /// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API. struct CudaHostAdapter { @@ -81,6 +139,8 @@ struct CudaHostAdapter { void *kernel_handles[kMaximumKernelCount]; int32_t kernel_count = 0; + CudaHostLaunchAttributes launch_attributes; + // // Methods // @@ -89,70 +149,80 @@ struct CudaHostAdapter { CudaHostAdapter() = default; /// Dtor - virtual ~CudaHostAdapter() {} + virtual ~CudaHostAdapter() = default; /// Copy Ctor - inline CudaHostAdapter(const CudaHostAdapter & rhs): - kernel_count(rhs.kernel_count) - { + CUTLASS_HOST_DEVICE + CudaHostAdapter(const CudaHostAdapter & rhs) + : kernel_count(rhs.kernel_count), + launch_attributes(rhs.launch_attributes) { CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { kernel_handles[i] = rhs.kernel_handles[i]; } } /// Copy Assignment - inline CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { - + CUTLASS_HOST_DEVICE + CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { kernel_handles[i] = rhs.kernel_handles[i]; } kernel_count = rhs.kernel_count; + + launch_attributes = rhs.launch_attributes; + return *this; } + /// Move ctor - inline CudaHostAdapter(CudaHostAdapter && rhs): - kernel_count(rhs.kernel_count) - { + CUTLASS_HOST_DEVICE + CudaHostAdapter(CudaHostAdapter && rhs) + : kernel_count(rhs.kernel_count), + launch_attributes(std::move(rhs.launch_attributes)) { CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { kernel_handles[i] = rhs.kernel_handles[i]; } } - /// Move assignment - inline CudaHostAdapter& operator=(CudaHostAdapter && rhs) { - + // / Move assignment + CUTLASS_HOST_DEVICE + CudaHostAdapter& operator=(CudaHostAdapter && rhs) { CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { kernel_handles[i] = rhs.kernel_handles[i]; } - kernel_count = rhs.kernel_count; - + launch_attributes = std::move(rhs.launch_attributes); return *this; } /// Ctor - inline CudaHostAdapter( - void **kernel_handles_, - int32_t kernel_count_ - ): - kernel_count(kernel_count_) - { - CUTLASS_ASSERT(kernel_count >= 0); + CUTLASS_HOST_DEVICE + CudaHostAdapter(void **kernel_handles_, + int32_t kernel_count_, + CudaHostLaunchAttributes const &launch_attributes_ = { }) + : kernel_count(kernel_count_), + launch_attributes(launch_attributes_) { + CUTLASS_ASSERT(kernel_count >= 0 && kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) { kernel_handles[i] = kernel_handles_[i]; } } /// Returns true if the CudaHostAdapter is empty (kernel_count == 0) - inline bool empty() const { return !kernel_count; } + CUTLASS_HOST_DEVICE + bool empty() const { return !kernel_count; } /// Returns kernel_count - inline size_t size() const { return static_cast(kernel_count); } + CUTLASS_HOST_DEVICE + size_t size() const { return static_cast(kernel_count); } /// Queries the occupancy of a kernel virtual Status query_occupancy( @@ -181,6 +251,48 @@ struct CudaHostAdapter { void** kernel_params, int32_t kernel_index) const = 0; +#if defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + + /// Create a tensor map descriptor object representing im2col memory region. + virtual CUresult tensorMapEncodeIm2col ( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const int* pixelBoxLowerCorner, + const int* pixelBoxUpperCorner, + cuuint32_t channelsPerPixel, + cuuint32_t pixelsPerColumn, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) const = 0; + + /// Create a tensor map descriptor object representing tiled memory region. + virtual CUresult tensorMapEncodeTiled ( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) const = 0; + + /// Modify an existing tensor map descriptor with an updated global address. + virtual CUresult tensorMapReplaceAddress( + CUtensorMap* tensorMap, + void* globalAddress) const = 0; + +#endif // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + protected: /** @@ -198,12 +310,12 @@ struct CudaHostAdapter { /// Fills a buffer in Global Memory with a byte sequence copied from host memory template + CUTLASS_HOST_DEVICE Status memsetDevice( - void* destination, - FillValueType fill_value, - size_t count, - cudaStream_t stream) const - { + void* destination, + FillValueType fill_value, + size_t count, + cudaStream_t stream) const { return this->memsetDeviceImpl( destination, &fill_value, diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 1616544291..429e5c2f06 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -126,36 +126,18 @@ struct TagToStrideC { using type = cute::Stride, cute::Int<1>, cute::Int<0>>; }; -// Conv: Maps to modes (PN, C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, cute::Int<0>>; -}; - // Conv: Maps to modes ((P,Q,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride template <> struct TagToStrideC { using type = cute::Stride, cute::Int<1>, cute::Int<0>>; }; -// Conv: Maps to modes (PQN, C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, cute::Int<0>>; -}; - // Conv: Maps to modes ((P,Q,Z,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride template <> struct TagToStrideC { using type = cute::Stride, cute::Int<1>, cute::Int<0>>; }; -// Conv: Maps to modes (PQZN, C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride -template <> -struct TagToStrideC { - using type = cute::Stride, cute::Int<0>>; -}; - // Conv: Maps to modes (K, (C,S), _0) for compatiblity with GEMM epilogues expecting a batch mode stride template <> struct TagToStrideC { @@ -174,6 +156,24 @@ struct TagToStrideC { using type = cute::Stride, int64_t, int64_t, int64_t>, cute::Int<0>>; }; +// Conv: Maps to modes ((C,S), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t>, int64_t, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S,R), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t>, int64_t, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S,R,T), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t, int64_t>, int64_t, cute::Int<0>>; +}; + // Convenience aliases template using TagToStrideA_t = typename TagToStrideA::type; @@ -318,6 +318,23 @@ get_alignment_count_from_gmem_tiled_copy() { } } +// Return alignment bit requirements for the GEMM inputs. +template < + class ElementType +> +constexpr int +get_input_alignment_bits() { + return 128; +} + +// Return alignment bit requirements for the GEMM outputs. +template +constexpr int +get_output_alignment_bits() { + return 128; +} + + // Return the shape that is associated with stride-1 mode, or 1 if not found template CUTLASS_HOST_DEVICE constexpr diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index b8952d15dd..2ca62c9794 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -39,6 +39,7 @@ #include "cutlass/gemm/collective/builders/sm90_common.inl" #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_generic.h" #include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" @@ -69,12 +70,15 @@ sm90_get_tma_dispatch_policy() { constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 256 : 128); // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); - constexpr bool DelayTmaStore = is_void_v; // TMA store delay performs worse with residual loads + // TMA store delay performs worse with residual loads and compilicates tensormap updates for Ptr-Array GEMMs + constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_tma_ptr_array_v; constexpr int StagesD = cute::min(EpiTiles, 2); constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) : cute::min(EpiTiles, 4); - return Sm90TmaWarpSpecialized{}; + return cute::conditional_t, + Sm90PtrArrayTmaWarpSpecialized, + Sm90TmaWarpSpecialized>{}; } // Returns the smem layout atom to be used for C or D matrix @@ -159,45 +163,6 @@ sm90_compute_tile_shape_or_override() { } } -// Selects the largest vectorized smem store atom available -template -constexpr auto -sm90_get_smem_store_op_for_accumulator() { - using namespace cute; - - if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) { - return SM90_U16x8_STSM_T{}; - } - else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) { - return SM90_U32x4_STSM_N{}; - } - else { - // auto-vectorizing store - return AutoVectorizingCopyWithAssumedAlignment{}; - } -} - -// Selects the largest vectorized smem load atom available -template -constexpr auto -sm90_get_smem_load_op_for_source() { - using namespace cute; - - // Reuse the logic from smem store selector - using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator()); - - if constexpr (cute::is_same_v) { - return SM75_U16x8_LDSM_T{}; - } - else if constexpr (cute::is_same_v) { - return SM75_U32x4_LDSM_N{}; - } - else { - // auto-vectorizing load - return AutoVectorizingCopyWithAssumedAlignment<128>{}; - } -} - // callbacks builder with TMA aux out template < int StagesC, @@ -299,11 +264,20 @@ struct Sm90TmaBuilderImpl { SM90_TMA_LOAD >; + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + + using FusionDispatchPolicy = Sm90TmaWarpSpecialized; + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination using FusionCallbacks = typename CallbacksBuilder< - DispatchPolicy, + FusionDispatchPolicy, FusionOpOrCallbacks, TileShape_MNK, EpilogueTile_MN, @@ -324,7 +298,8 @@ struct Sm90TmaBuilderImpl { decltype(detail::sm90_get_smem_load_op_for_source()), CopyOpS2G, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), - decltype(detail::sm90_get_smem_store_op_for_accumulator()) + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC >; }; @@ -404,19 +379,6 @@ struct AuxStoreDescriptor { decltype(detail::sm90_get_smem_store_op_for_accumulator()); }; -template< - typename EpilogueDescriptor, - typename ElementVector -> -struct RowBroadcastDescriptor { - constexpr static int Stages = ceil_div( - EpilogueDescriptor::StagesC, - size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) - ) + 1; - - using Element = ElementVector; -}; - } // namespace detail /////////////////////////////////////////////////////////////////////////////// @@ -520,7 +482,8 @@ struct CollectiveBuilder< Schedule, FusionOperation, cute::enable_if_t || - cute::is_same_v >> { + cute::is_same_v || + cute::is_same_v >> { private: using ElementD = cute::conditional_t, fusion::get_element_aux_t, ElementD_>; @@ -748,6 +711,9 @@ private: using GmemStrideTypeC = gemm::TagToStrideC_t; using GmemStrideTypeD = gemm::TagToStrideC_t; + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + public: using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< DispatchPolicy::StagesC, @@ -765,7 +731,8 @@ public: decltype(detail::sm90_get_smem_load_op_for_source()), SM90_TMA_STORE, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), - decltype(detail::sm90_get_smem_store_op_for_accumulator()) + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC >; }; diff --git a/include/cutlass/epilogue/collective/builders/sm90_common.inl b/include/cutlass/epilogue/collective/builders/sm90_common.inl new file mode 100644 index 0000000000..cd2639c5dd --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm90_common.inl @@ -0,0 +1,80 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Selects the largest vectorized smem store atom available +template +constexpr auto +sm90_get_smem_store_op_for_accumulator() { + using namespace cute; + + if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) { + return SM90_U16x8_STSM_T{}; + } + else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) { + return SM90_U32x4_STSM_N{}; + } + else { + // auto-vectorizing store + return AutoVectorizingCopyWithAssumedAlignment{}; + } +} + +// Selects the largest vectorized smem load atom available +template +constexpr auto +sm90_get_smem_load_op_for_source() { + using namespace cute; + + // Reuse the logic from smem store selector + using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator()); + + if constexpr (cute::is_same_v) { + return SM75_U16x8_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U32x4_LDSM_N{}; + } + else { + // auto-vectorizing load + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } +} + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective::detail diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index d61f59f729..f8179b0a0e 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -59,4 +59,5 @@ class CollectiveEpilogue { #include "sm70_epilogue_vectorized.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +#include "sm90_epilogue_array_tma_warpspecialized.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index bbeeacacd3..cd4a6ccddb 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -129,7 +129,7 @@ class DefaultEpilogue { } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp index c2f8423e44..0f6f329311 100644 --- a/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -65,6 +65,7 @@ class DefaultEpilogueArray { // Type Aliases // using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; // derived types of output thread level operator using ThreadEpilogueOp = ThreadEpilogueOp_; @@ -74,10 +75,10 @@ class DefaultEpilogueArray { using ElementScalar = ElementCompute; using ElementC = typename ThreadEpilogueOp::ElementC; using StrideC = StrideC_; - using UnderlyingStrideC = cute::remove_pointer_t; + using InternalStrideC = cute::remove_pointer_t; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - using UnderlyingStrideD = cute::remove_pointer_t; + using InternalStrideD = cute::remove_pointer_t; using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -85,12 +86,14 @@ class DefaultEpilogueArray { static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(cute::is_same_v, "Incompatible epilogue schedule."); - static_assert(rank(UnderlyingStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(UnderlyingStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { }; + using TensorMapStorage = SharedStorage; + // Host side epilogue arguments struct Arguments { typename ThreadEpilogueOp::Params thread{}; @@ -118,7 +121,7 @@ class DefaultEpilogueArray { template static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { return 0; } @@ -130,7 +133,7 @@ class DefaultEpilogueArray { } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -195,9 +198,9 @@ class DefaultEpilogueArray { assert(0); } - UnderlyingStrideC stride_c; - UnderlyingStrideD stride_d; - if constexpr (!cute::is_same_v) { + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (!cute::is_same_v) { // If grouped gemm if (epilogue_op.is_source_needed()) { stride_c = detail::get_epilogue_stride(params.dC[l_coord]); diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 6b01a22e16..a01781440c 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -78,7 +78,12 @@ static constexpr int elements_per_access_v = cutlass::sizeof_bits::val template static constexpr bool sm90_is_cooperative_v = - cute::is_base_of_v; + cute::is_base_of_v || + cute::is_base_of_v; + +template +static constexpr bool sm90_is_tma_ptr_array_v = + cute::is_base_of_v; template static constexpr bool sm90_is_warp_specialized_v = @@ -151,24 +156,26 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { using LoadPipeline = cutlass::PipelineTransactionAsync<0>; using LoadPipelineState = cutlass::PipelineState<0>; constexpr static uint32_t TmaTransactionBytes = 0; + constexpr static bool RequiresTransactionBytes = false; using StorePipeline = cutlass::PipelineTmaStore<0>; using StorePipelineState = cutlass::PipelineState<0>; using TensorStorage = typename EpilogueOp::SharedStorage; + using TensorMapStorage = typename EpilogueOp::SharedStorage; using PipelineStorage = typename LoadPipeline::SharedStorage; - template + template CUTLASS_HOST_DEVICE static constexpr int - get_load_pipe_increment([[maybe_unused]] TileShapeMNK) { + get_load_pipe_increment(CtaTileMNK) { return 1; } - template + template CUTLASS_HOST_DEVICE static constexpr int - get_store_pipe_increment([[maybe_unused]] TileShapeMNK) { + get_store_pipe_increment(CtaTileMNK) { return 1; } @@ -191,11 +198,38 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { return false; } + CUTLASS_DEVICE auto + load_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, [[maybe_unused]] int32_t const sm_idx) const { + return cute::make_tuple(nullptr); + } + + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] CtaTileMNK cta_tile_mnk, + [[maybe_unused]] CtaCoordMNKL cta_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] int subtile_idx=-1) + { + return load_pipe_producer_state; + } + template< class ProblemShapeMNKL, class TileShapeMNK, class TileCoordMNKL, - class TiledMma + class TiledMma, + class TensorMapC > CUTLASS_DEVICE auto load( @@ -207,7 +241,9 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { [[maybe_unused]] TiledMma tiled_mma, [[maybe_unused]] int thread_idx, [[maybe_unused]] TensorStorage& shared_tensors, - [[maybe_unused]] int subtile_idx=-1) + [[maybe_unused]] TensorMapC const& load_tensormap, + [[maybe_unused]] int subtile_idx=-1, + [[maybe_unused]] bool return_prior_state = false) { return load_pipe_producer_state; } @@ -220,12 +256,66 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { return load_pipe_producer_state; } + CUTLASS_DEVICE auto + store_init([[maybe_unused]] typename EpilogueOp::Params const& params, [[maybe_unused]] int32_t const sm_count, + [[maybe_unused]] int32_t const sm_idx) const { + return cute::make_tuple(nullptr); + } + + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_index = -1) + { + constexpr int BLK_M_RANK = cute::rank<0>(cta_tile_mnk); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + })); + + constexpr int BLK_N_RANK = cute::rank<1>(cta_tile_mnk); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + })); + + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + template< class ProblemShapeMNKL, class TileShapeMNK, class TileCoordMNKL, class AccEngine, class AccLayout, - class TiledMma + class TiledMma, + class TensorMapD > CUTLASS_DEVICE auto store( @@ -240,6 +330,7 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { TiledMma tiled_mma, int thread_idx, TensorStorage& shared_tensors, + [[maybe_unused]] TensorMapD const& store_tensormap, int subtile_index = -1) { constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); @@ -276,6 +367,29 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); } + // Dummy methods to perform different parts of TMA/Tensormap modifications + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] int32_t next_batch) { } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] uint32_t lane_predicate) { } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } }; } // namespace detail diff --git a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp index c870b706db..48833ecf10 100644 --- a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp +++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -149,7 +149,7 @@ class EpilogueTensorBroadcast { } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index be19944d1b..69170f75ea 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -135,7 +135,7 @@ class Epilogue { } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp new file mode 100644 index 0000000000..981ea3e274 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -0,0 +1,998 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_ +> +class CollectiveEpilogue< + Sm90PtrArrayTmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128> { + cute::TmaDescriptor smem_tensormap_C; + cute::TmaDescriptor smem_tensormap_D; + } tensormaps; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC; + ElementD ** ptr_D = nullptr; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(InternalStrideC{}, int32_t(0)), InternalStrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(InternalStrideD{}, int32_t(0)), InternalStrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + cute::TmaDescriptor* tensormaps; + ElementC const** ptr_C; + ElementD** ptr_D; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1); + auto [M, N, K, mock_L] = problem_shape_MNKL; + // Manage batches/groups through pointers to input matricies + mock_L = 1; + + static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c = {}; + if constexpr (is_source_supported) { + ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dC, _0{}))); + tma_load_c = make_tma_copy_C_sm90( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}); + + } + + typename Params::TMA_D tma_store_d; + if constexpr (is_destination_supported) { + ElementD const* ptr_D_first_batch = reinterpret_cast(args.ptr_D); + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(M,N,mock_L), append<3>(args.dD, _0{}))); + tma_store_d = make_tma_copy_C_sm90( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}); + } + + auto fusion_workspace = static_cast(workspace); + auto fusion_workspace_size = FusionCallbacks::get_workspace_size(problem_shape, args.thread); + auto tma_descriptor_workspace = reinterpret_cast( + static_cast(workspace) + fusion_workspace_size); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), + tma_load_c, + tma_store_d, + tma_descriptor_workspace, + args.ptr_C, + args.ptr_D, + transaction_bytes, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = cute::is_void_v ? 1 : 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + return implementable && fusion_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + CUTLASS_DEVICE auto + load_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { + // Initialize tma for loading + constexpr bool IsLoad = true; + auto load_tensormaps = tensormaps_init(params, sm_count, sm_idx); + return load_tensormaps; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class TensorMapC + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + TensorMapC const& load_tensormap, + int subtile_idx=-1, + bool return_prior_state = false) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + static_assert(!is_im2col_D, "Do not support im2col"); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Acquire the lock for the first stage + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + + auto prior_state = load_pipe_producer_state; + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + prior_state = load_pipe_producer_state; + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + if (not return_prior_state) { + return load_pipe_producer_state; + } else { + return prior_state; + } + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + TensorMapD const& store_tensormap, + int subtile_idx=-1) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + + static_assert(!is_im2col_D, "Do not support im2col"); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_r2s, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + }; + auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { + tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { + // Initialize tma + constexpr bool IsLoad = false; + auto store_tensormaps = tensormaps_init(params, sm_count, sm_idx); + return store_tensormaps; + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init(Params const& params, int32_t const sm_count, int32_t const sm_idx) const { + cute::TmaDescriptor* tma_desc = nullptr; + cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + if constexpr (IsLoad) { + if (not cute::is_void_v) { + tma_desc = &gmem_tensormap[sm_idx]; + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gC_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{}); + copy(recast(pC_tensormap), recast(gC_tensormap)); + } + } + } else { + int const offset_Ddesc = cute::is_void_v ? 0 : sm_count; + tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gD_tensormap = make_tensor(tma_desc, Int<1>{}, Int<1>{}); + copy(recast(pD_tensormap), recast(gD_tensormap)); + } + } + + return cute::make_tuple(tma_desc); + } + + // Bringing tensormaps to smem (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_fetch_to_smem( + TensorMapStorage& shared_tensormap, + cute::TmaDescriptor const* tensormap) const { + if constexpr (IsLoad) { + if (not cute::is_void_v) { + Tensor gC_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{}); + copy(recast(gC_tensormap), recast(sC_tensormap)); + } + } else { + Tensor gD_tensormap = make_tensor(make_gmem_ptr(tensormap), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{}); + copy(recast(gD_tensormap), recast(sD_tensormap)); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + // Replace address for the global tensor (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& params, + int32_t next_batch) { + // Replacing global_address for the next batch + if constexpr (IsLoad) { + if (not cute::is_void_v) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C, + params.ptr_C[next_batch]); + } + } else { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, + params.ptr_D[next_batch]); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& params, + cute::TmaDescriptor const* tensormap, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Bringing tensormaps to smem + tensormaps_fetch_to_smem(shared_tensormap, tensormap); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, params, next_batch); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormap, + cute::TmaDescriptor const* tensormap, + [[maybe_unused]] uint32_t lane_predicate) { + // Entire warp must do this (ie its aligned) + if constexpr (IsLoad) { + if (not cute::is_void_v) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); + } + } else { + tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { + if constexpr (IsLoad) { + if (not cute::is_void_v) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } else { + cute::tma_descriptor_fence_acquire(tensormap); + } + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 9eb4c4b123..c03aed33fe 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -41,6 +41,7 @@ #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" @@ -73,7 +74,8 @@ template < class CopyOpS2R_, class CopyOpS2G_, class SmemLayoutAtomD_, - class CopyOpR2S_ + class CopyOpR2S_, + class CopyAtomC_ > class CollectiveEpilogue< Sm90TmaWarpSpecialized, @@ -89,7 +91,8 @@ class CollectiveEpilogue< CopyOpS2R_, CopyOpS2G_, SmemLayoutAtomD_, - CopyOpR2S_ + CopyOpR2S_, + CopyAtomC_ > { public: // @@ -109,6 +112,7 @@ class CollectiveEpilogue< using CopyOpS2G = CopyOpS2G_; using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = CopyOpG2S; @@ -125,9 +129,13 @@ class CollectiveEpilogue< private: constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_destination_supported = not cute::is_void_v; - using SmemElementD = cute::conditional_t, ElementD>; - static_assert(not cute::is_void_v, "SmemElementD is void"); - using SmemElementC = cute::conditional_t; // prevents void ref breakages + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + constexpr static int StagesC = StagesC_; constexpr static int StagesD = StagesD_; constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; @@ -154,30 +162,32 @@ class CollectiveEpilogue< constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; using EmptyType = cute::tuple<>; using SmemCStorage = cute::conditional_t, + SmemArrayTypeC, EmptyType>; using SmemDStorage = cute::conditional_t, + SmemArrayTypeD, EmptyType>; - struct TensorStorageImpl: cute::tuple { - using Base = cute::tuple; - - constexpr decltype(auto) - smem_C() { - return cute::get<0>(static_cast(*this)); - } + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; - constexpr decltype(auto) - smem_D() { - return cute::get<1>(static_cast(*this)); - } + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; }; public: @@ -186,6 +196,7 @@ class CollectiveEpilogue< using LoadPipelineState = cutlass::PipelineState; constexpr static uint32_t TmaTransactionBytes = (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; // TMA pipeline for storing D using StorePipeline = cute::conditional_t; struct SharedStorage { - using TensorStorage = TensorStorageImpl; + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; - TensorStorage tensors; + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; using PipelineStorage = typename LoadPipeline::SharedStorage; PipelineStorage pipeline; @@ -217,14 +233,14 @@ class CollectiveEpilogue< struct Params { using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideC{}, int32_t(0)), StrideC{}), take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{})); using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideD{}, int32_t(0)), StrideD{}), take<0,2>(SmemLayoutD{}), EpilogueTile{}, @@ -233,6 +249,7 @@ class CollectiveEpilogue< typename FusionCallbacks::Params thread{}; TMA_C tma_load_c; TMA_D tma_store_d; + uint32_t tma_transaction_bytes = TmaTransactionBytes; }; // @@ -248,26 +265,33 @@ class CollectiveEpilogue< // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; - // For fprop/dgrad kernel, problem shape M is multimodal which should be linearized under tiled mode - auto M_C = conditional_return(M, size(M)); - auto M_D = conditional_return(M, size(M)); + uint32_t transaction_bytes = TmaTransactionBytes; typename Params::TMA_C tma_load_c = {}; if constexpr (is_source_supported) { - Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC)); - tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{}); + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + tma_load_c = make_tma_copy_C_sm90( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}); } typename Params::TMA_D tma_store_d; if constexpr (is_destination_supported) { - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD)); - tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, take<0,2>(SmemLayoutD{}), EpilogueTile{}, _1{}); + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + tma_store_d = make_tma_copy_C_sm90( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}); } return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), tma_load_c, - tma_store_d + tma_store_d, + transaction_bytes }; } @@ -285,30 +309,37 @@ class CollectiveEpilogue< } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); bool implementable = true; if constexpr (is_destination_supported) { - constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = cutlass::detail::check_alignment(shape, StrideD{}); } if constexpr (not cute::is_void_v) { - constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); } if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } - return implementable; + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + return implementable && fusion_implementable; } template @@ -377,24 +408,13 @@ class CollectiveEpilogue< make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); - // Tile residue - auto residue_mn = make_coord(M,N); - // Represent the full source tensor, slice to get the tile this CTA is currently responsible for Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) // Apply epilogue subtile, get matching smem tensor - SmemElementC* ptr_sC = nullptr; - - if constexpr (is_source_supported) { - if constexpr (ReuseSmemC) { - ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); - } else { - ptr_sC = shared_tensors.smem_C().data(); - } - } + auto ptr_sC = shared_tensors.collective.smem_C.begin(); Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) @@ -404,14 +424,14 @@ class CollectiveEpilogue< Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) // Get the fusion callbacks for the producer load warp - auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs( problem_shape_mnkl, CtaTileMNK{}, tile_coord_mnkl, - residue_mn, + tiled_mma, EpilogueTile{}, thread_idx - }; + ); auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); @@ -507,10 +527,6 @@ class CollectiveEpilogue< // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); // The tma tensor D under im2col mode only has two modes (M, N) which // should be local tiled with only (m_coord, n_coord). @@ -527,27 +543,13 @@ class CollectiveEpilogue< Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) // Construct the corresponding pipelined smem tensors - SmemElementC* ptr_sC = nullptr; - if constexpr (is_source_supported) { - if constexpr (ReuseSmemC) { - ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); - } else { - ptr_sC = shared_tensors.smem_C().data(); - } - } - - SmemElementD* ptr_sD = nullptr; - if constexpr (is_destination_supported) { - ptr_sD = shared_tensors.smem_D().data(); - } - + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); Tensor sC_epi = cute::as_position_independent_swizzle_tensor( make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) Tensor sD_epi = cute::as_position_independent_swizzle_tensor( make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) - // Get the smallest tiled copy we can use to retile the accumulators - using CopyAtomC = Copy_Atom; TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); // (t)hread-partition for (r)egister to (s)mem copy (tRS_) @@ -556,6 +558,11 @@ class CollectiveEpilogue< Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + // Allocate D registers Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) @@ -587,32 +594,46 @@ class CollectiveEpilogue< Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) // OOB predication for tile quantization "residue" - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); - Tensor cD = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); - Tensor tRS_cD = thread_r2s.partition_S(flat_divide(cD, EpilogueTile{})); - auto residue_mn = make_coord(M,N); + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - // Get the fusion callbacks for the consumer store warps constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout - auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( problem_shape_mnkl, CtaTileMNK{}, tile_coord_mnkl, - residue_mn, + tiled_mma, EpilogueTile{}, - tiled_copy_C_atom, - thread_idx, + tiled_r2s, cD, + residue_cD, tRS_cD, - tRS_rC - }; + residue_tRS_cD, + tRS_rC, + thread_idx + ); auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; @@ -694,10 +715,8 @@ class CollectiveEpilogue< if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { continue; } - // The current tile in accumulator - int mma_m = epi_m; - int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + cst_callbacks.begin_loop(epi_m, epi_n); if (is_producer_load_needed) { // Wait for the producer load to fill smem @@ -722,14 +741,17 @@ class CollectiveEpilogue< ++load_wait_state; } + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + // Vectorized fragment loop with visitor callback entry point int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD_frg); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { - tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); } - // The latest we can delay the TMA store is right before the smem store of the next iteration // since the current TMA store needs to be committed before we can acquire the next smem buffer if constexpr (DelayTmaStore) { @@ -743,7 +765,12 @@ class CollectiveEpilogue< // Smem reduction callback entry point using current store buffer for workspace cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), - synchronize, epi_m, epi_n, is_last_iteration); + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } // Copy tile from register to smem if constexpr (is_destination_supported) { @@ -758,6 +785,9 @@ class CollectiveEpilogue< // Issue TMA stores for this subtile tma_store_fn(epi_m, epi_n); } + + cst_callbacks.end_loop(epi_m, epi_n); + } // for epi_m } // for epi_n diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp index 8eeb43c2dd..b67c229c27 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -61,7 +61,8 @@ template < class CopyOpS2R_, class CopyOpS2G_, class SmemLayoutAtomD_, - class CopyOpR2S_ + class CopyOpR2S_, + class CopyAtomC_ > class Sm90EpilogueTmaWarpSpecializedBiasElementwise : public CollectiveEpilogue< @@ -78,7 +79,8 @@ class Sm90EpilogueTmaWarpSpecializedBiasElementwise CopyOpS2R_, CopyOpS2G_, SmemLayoutAtomD_, - CopyOpR2S_ + CopyOpR2S_, + CopyAtomC_ > { private: using Impl = @@ -96,7 +98,8 @@ class Sm90EpilogueTmaWarpSpecializedBiasElementwise CopyOpS2R_, CopyOpS2G_, SmemLayoutAtomD_, - CopyOpR2S_ + CopyOpR2S_, + CopyAtomC_ >; public: using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 409ff74dd9..9f9576b417 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -45,10 +45,13 @@ namespace cutlass::epilogue { // ////////////////////////////////////////////////////////////////////////////// +struct PtrArrayDefault {}; struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; +struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; +struct PtrArrayTmaWarpSpecializedCooperative {}; // DEPRECATED schedules, will be removed in next release struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; @@ -143,6 +146,20 @@ struct Sm90TmaWarpSpecialized { constexpr static bool DelayTmaStore = DelayTmaStore_; }; +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm90PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; // DEPRECATED policies, will be removed in next release template< diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 6c729c10de..1de0a28e0f 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -544,7 +544,6 @@ struct FusionCallbacks< }; ///////////////////////////////////////////////////////////////////////////////////////////////// - // D = per-row alpha * acc + per-row beta * C + per-row bias template< class CtaTileShapeMNK, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index b8cac85634..0b12badc7d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -115,6 +115,12 @@ struct Sm90Compute { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const&, Arguments const&) { @@ -123,7 +129,7 @@ struct Sm90Compute { template static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return cutlass::Status::kSuccess; } @@ -139,7 +145,8 @@ struct Sm90Compute { } CUTLASS_HOST_DEVICE - Sm90Compute() { } + Sm90Compute() + : params() {} CUTLASS_HOST_DEVICE Sm90Compute(Params const& params, SharedStorage const& shared_storage) @@ -336,6 +343,12 @@ struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -421,27 +434,27 @@ struct Sm90TreeVisitor< Params const& params; - template + template struct ConsumerStoreCallbacks : CallbacksImpl { CUTLASS_DEVICE ConsumerStoreCallbacks( RTensor&& tC_rAux, GTensor&& tC_gAux, CTensor tC_cAux, - ResidueMN residue_mn, + ThrResidue residue_tC_cAux, Params const& params, CallbacksImpl&& impl) : tC_rAux(cute::forward(tC_rAux)), tC_gAux(cute::forward(tC_gAux)), tC_cAux(tC_cAux), - residue_mn(residue_mn), + residue_tC_cAux(residue_tC_cAux), params(params), CallbacksImpl(cute::forward(impl)) {} RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ResidueMN residue_mn; + ThrResidue residue_tC_cAux; Params const& params; template @@ -506,25 +519,27 @@ struct Sm90TreeVisitor< } } + // Compute vectorization + constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); // Copy vectorizes into byte-aligned stores - constexpr int V = cute::min(Alignment, decltype(max_common_vector(tC_rAux, tC_gAux))::value); - if constexpr (V > 0 && V % 8 == 0) { + if constexpr (V > 1 && V % 8 == 0) { using VecType = uint_bit_t; Tensor tC_rAux_vec = recast(tC_rAux); Tensor tC_gAux_vec = recast(tC_gAux); - Tensor tC_cAux_vec = tC_cAux.compose(make_layout(Int{}, Int{})); // only works if vector is logically sequential - auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_mn); }; - copy_if(FunctionPredTensor(predicate_fn), tC_rAux_vec, tC_gAux_vec); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_rAux_vec, tC_gAux_vec); } // sub-byte vectorization, must serialize threads else { // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) int lane_idx = canonical_lane_idx(); - auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_mn); }; + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; CUTLASS_PRAGMA_NO_UNROLL for (int i = 0; i < NumThreadsPerWarp; ++i) { if (lane_idx == i) { - copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux); + copy_if(predicate_fn, tC_rAux, tC_gAux); } __syncwarp(); } @@ -553,8 +568,8 @@ struct Sm90TreeVisitor< Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params, cute::move(callbacks_impl)); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params, cute::move(callbacks_impl)); } }; @@ -596,6 +611,12 @@ struct Sm90AuxLoad< return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -634,20 +655,20 @@ struct Sm90AuxLoad< return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ResidueMN residue_mn_, Params const& params_) + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ThrResidue residue_tC_cAux_, Params const& params_) : tC_rAux(cute::forward(tC_rAux_)), tC_gAux(cute::forward(tC_gAux_)), tC_cAux(tC_cAux_), - residue_mn(residue_mn_), + residue_tC_cAux(residue_tC_cAux_), params(params_) {} RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ResidueMN residue_mn; + ThrResidue residue_tC_cAux; Params const& params; CUTLASS_DEVICE void @@ -659,24 +680,25 @@ struct Sm90AuxLoad< } } - constexpr int V = cute::min(Alignment, decltype(max_common_vector(tC_rAux, tC_gAux))::value); - if constexpr (V > 0) { + constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + if constexpr (V > 1) { using VecType = uint_bit_t; Tensor tC_gAux_vec = recast(tC_gAux); Tensor tC_rAux_vec = recast(tC_rAux); - Tensor tC_cAux_vec = tC_cAux.compose(make_layout(Int{}, Int{})); // only works if vector is logically sequential - auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_mn); }; - copy_if(FunctionPredTensor(predicate_fn), tC_gAux_vec, tC_rAux_vec); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_gAux_vec, tC_rAux_vec); } else { - auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_mn); }; - copy_if(FunctionPredTensor(predicate_fn), tC_gAux, tC_rAux); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_gAux, tC_rAux); } } } CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + begin_loop(int epi_m, int epi_n) { if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { if constexpr (EnableNullptr) { if (params.ptr_aux == nullptr) { @@ -684,8 +706,8 @@ struct Sm90AuxLoad< } } - auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_mn); }; - copy_if(FunctionPredTensor(predicate_fn), tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); } } @@ -734,8 +756,8 @@ struct Sm90AuxLoad< } } - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 1ea663f6f0..4eb326b3dd 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -208,6 +208,12 @@ struct Sm90AuxLoad { return Params{tma_load_aux, args.null_default, use_default}; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -386,10 +392,8 @@ template< template class ReductionFn = multiplies > struct Sm90ScalarBroadcast { - static_assert( - (cute::is_same_v>) || // scalar broadcast, e.g. alpha - (cute::is_same_v>) || // batched scalar broadcast, e.g. per-batch alpha - (cute::is_same_v>)); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); struct SharedStorage { }; @@ -407,6 +411,12 @@ struct Sm90ScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -536,7 +546,7 @@ struct Sm90ScalarBroadcast { namespace detail { template -constexpr int +[[deprecated("row broadcast only uses 0 stages")]] constexpr int compute_row_broadcast_stages() { return ceil_div(StagesC, size<1>(zipped_divide(make_layout(take<0,2>(CtaTileShapeMNK{})), EpilogueTile{}))) + 1; } @@ -545,8 +555,6 @@ compute_row_broadcast_stages() { // Row vector broadcast template< - // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races int Stages, class CtaTileShapeMNK, class Element, @@ -555,14 +563,12 @@ template< bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct Sm90RowBroadcast { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias - (cute::is_same_v>)); // batched row vector broadcast + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); - // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; }; struct Arguments { @@ -579,6 +585,12 @@ struct Sm90RowBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -597,15 +609,15 @@ struct Sm90RowBroadcast { CUTLASS_HOST_DEVICE Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), - smem_row(const_cast(shared_storage.smem_row.data())) { } + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } Params params; - Element* smem_row; + Element *smem = nullptr; CUTLASS_DEVICE bool is_producer_load_needed() const { - return true; + return false; } CUTLASS_DEVICE bool @@ -618,82 +630,80 @@ struct Sm90RowBroadcast { return (params.ptr_row == nullptr && params.null_default == Element(0)); } - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) - : gRow(cute::forward(gRow)), - sRow(cute::forward(sRow)), - params(params) {} - - GTensor gRow; // (CTA_M,CTA_N) - STensor sRow; // (CTA_M,CTA_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { - if constexpr (EnableNullptr) { - if (params.ptr_row == nullptr) { - return; - } - } - - if (issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); - // Filter so we don't issue redundant copies over stride-0 modes - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); - } - } - }; - template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(gRow), cute::move(sRow), params); + return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) - : tCrRow(cute::forward(tCrRow)), - tCsRow(cute::forward(tCsRow)), - params(params) {} - - RTensor tCrRow; // (CPY,CPY_M,CPY_N) - STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) + , is_nullptr(EnableNullptr && params_.ptr_row == nullptr) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; Params const& params; + bool is_nullptr; CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + begin() { if constexpr (EnableNullptr) { if (params.ptr_row == nullptr) { - fill(tCrRow, params.null_default); + fill(tSR_rRow, params.null_default); return; } } + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + if (is_nullptr) return; // Do not issue LDS when bias is nullptr + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -704,7 +714,7 @@ struct Sm90RowBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tCrRow(epi_v * FragmentSize + i); + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); } return frg_row; @@ -717,17 +727,41 @@ struct Sm90RowBroadcast { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCsRow), params); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); } }; @@ -743,11 +777,9 @@ template< bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct Sm90ColBroadcast { - static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias - (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + static_assert(Stages == 0, "Column broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem struct SharedStorage { }; @@ -766,6 +798,12 @@ struct Sm90ColBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -809,16 +847,20 @@ struct Sm90ColBroadcast { return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) + ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, CTensor tCcCol, ThrResidue residue_tCcCol, Params const& params) : tCgCol(cute::forward(tCgCol)), tCrCol(cute::forward(tCrCol)), + tCcCol(tCcCol), + residue_tCcCol(residue_tCcCol), params(params) {} GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcCol; Params const& params; CUTLASS_DEVICE void @@ -832,7 +874,24 @@ struct Sm90ColBroadcast { // Filter so we don't issue redundant copies over stride-0 modes // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCgCol), filter(tCrCol)); + Tensor tCgCol_flt = filter_zeros(tCgCol); + Tensor tCrCol_flt = filter_zeros(tCrCol); + Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + + constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + if constexpr (V > 1) { + using VecType = uint_bit_t>; + Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); + Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); + Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); + auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_vec(coords...), residue_tCcCol); }; + copy_if(pred_fn, tCgCol_vec, tCrCol_vec); + } + else { + auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_flt(coords...), residue_tCcCol); }; + copy_if(pred_fn, tCgCol_flt, tCrCol_flt); + } } template @@ -864,8 +923,8 @@ struct Sm90ColBroadcast { mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks( - cute::move(tCgCol), cute::move(tCrCol), params); + return ConsumerStoreCallbacks( + cute::move(tCgCol), cute::move(tCrCol), args.tCcD, args.residue_tCcD, params); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index c8d941b62b..ae7b42b2bd 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -124,6 +124,12 @@ struct Sm90AuxStore { return {tma_store_aux, is_nullptr}; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -322,6 +328,12 @@ struct Sm90ScalarReduction { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -368,24 +380,24 @@ struct Sm90ScalarReduction { return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE ConsumerStoreCallbacks( int l_coord, CTensor tCcScalar, - ResidueMN residue_mn, + ThrResidue residue_tCcScalar, Params const& params) : scalar(params.reduction_identity), l_coord(l_coord), tCcScalar(tCcScalar), - residue_mn(residue_mn), + residue_tCcScalar(residue_tCcScalar), params(params) {} ElementCompute scalar; int l_coord; CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ResidueMN residue_mn; + ThrResidue residue_tCcScalar; Params params; template @@ -408,7 +420,7 @@ struct Sm90ScalarReduction { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_mn)) { + if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_tCcScalar)) { scalar = reduce_input(scalar, frg_I[i]); } } @@ -442,8 +454,8 @@ struct Sm90ScalarReduction { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_mn, params); + return ConsumerStoreCallbacks( + get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_tCcD, params); } }; @@ -505,18 +517,18 @@ struct Sm90RowReduction { if constexpr (IsAtomic) { reduction_buffer = nullptr; } - else if constexpr (not FinalReduction) { - reduction_buffer = reinterpret_cast(args.ptr_row); - } - else { + else if constexpr (FinalReduction) { auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); reduction_buffer = reinterpret_cast(workspace); tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); } + else { + reduction_buffer = reinterpret_cast(args.ptr_row); + } return { args.ptr_row, @@ -527,6 +539,12 @@ struct Sm90RowReduction { }; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -540,7 +558,7 @@ struct Sm90RowReduction { // Increment by size of reduction buffer workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); // Align and increment by size of tile counters - workspace_size = round_nearest(workspace_size, sizeof(int)); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); workspace_size += cute::ceil_div(size<>(N), tile_N) * sizeof(int); return workspace_size; } @@ -551,19 +569,25 @@ struct Sm90RowReduction { CudaHostAdapter* cuda_adapter = nullptr) { if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; - Layout mRow_layout = make_layout(make_shape(M,N,L), args.dRow); + Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); if (args.ptr_row != nullptr) { return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); } return Status::kSuccess; } - auto [M, N, K, L] = problem_shape; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + else if constexpr (FinalReduction) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); - return zero_workspace(tile_counters, tile_counters_size, stream); + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); + } + else { + return Status::kSuccess; + } } CUTLASS_DEVICE bool @@ -615,7 +639,7 @@ struct Sm90RowReduction { auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); @@ -627,13 +651,7 @@ struct Sm90RowReduction { Array frg_I = convert_input(frg_input); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - if constexpr (VisitCheckOOB) { - if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_mn)) { - ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); - tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); - } - } - else { + if (!VisitCheckOOB || elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); } @@ -642,16 +660,16 @@ struct Sm90RowReduction { return frg_input; } - template + template CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { if (not is_last_iteration) { return; } auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; auto [m, n, k, l] = tile_coord_mnkl; constexpr bool ReferenceSrc = decltype(ref_src)::value; if constexpr (EnableNullptr) { @@ -661,7 +679,7 @@ struct Sm90RowReduction { } // fully OOB CTA in partially OOB cluster - if (not elem_less(cRow(_0{},_0{}), residue_mn)) { + if (not elem_less(cRow(_0{},_0{}), residue_cRow)) { return; } @@ -702,7 +720,7 @@ struct Sm90RowReduction { if (is_reduced_lane) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tCrRow_flt); ++i) { - if (elem_less(tCcRow_flt(i), residue_mn)) { + if (elem_less(tCcRow_flt(i), residue_tCcRow)) { reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); } } @@ -733,6 +751,7 @@ struct Sm90RowReduction { static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), "smem reduction buffer not large enough, use a larger epilogue tile"); + sync_fn(); // Dump warp reduction to smem workspace Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); @@ -807,7 +826,7 @@ struct Sm90RowReduction { auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; using ReduceOutput = GmemReduceFn; using ConvertOutput = NumericConverter; @@ -824,7 +843,7 @@ struct Sm90RowReduction { for (int ml = 1; ml < size(tRgBuf_ml); ++ml) { output = reduce_output(output, tRgBuf_ml(ml)); } - if (elem_less(cRow(_0{},n), residue_mn)) { + if (elem_less(cRow(_0{},n), residue_cRow)) { gRow_l(_0{},n,_0{}) = convert_output(output); } } @@ -833,7 +852,7 @@ struct Sm90RowReduction { else { CUTLASS_PRAGMA_NO_UNROLL for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { - bool do_store = elem_less(cRow(_0{},n), residue_mn); + bool do_store = elem_less(cRow(_0{},n), residue_cRow); CUTLASS_PRAGMA_NO_UNROLL for (int l = 0; l < size<3>(gBuf_ml); ++l) { Tensor tRgBuf_m = gBuf_ml(_0{},n,_,l); @@ -910,7 +929,7 @@ struct Sm90RowReduction { auto args_tuple = make_tuple( bool_constant{}, cute::move(tCrRow), args.tCcD, gRow_l, args.cD, gBuf_ml, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx); + args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks(cute::move(args_tuple), params); } }; @@ -971,18 +990,18 @@ struct Sm90ColReduction { if constexpr (IsAtomic) { reduction_buffer = nullptr; } - else if constexpr (not FinalReduction) { - reduction_buffer = reinterpret_cast(args.ptr_col); - } - else { + else if constexpr (FinalReduction) { auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); reduction_buffer = reinterpret_cast(workspace); tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); } + else { + reduction_buffer = reinterpret_cast(args.ptr_col); + } return { args.ptr_col, @@ -993,6 +1012,12 @@ struct Sm90ColReduction { }; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -1007,7 +1032,7 @@ struct Sm90ColReduction { // Increment by size of reduction buffer workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); // Align and increment by size of tile counters - workspace_size = round_nearest(workspace_size, sizeof(int)); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); return workspace_size; @@ -1019,21 +1044,25 @@ struct Sm90ColReduction { CudaHostAdapter* cuda_adapter = nullptr) { if constexpr (IsAtomic) { auto [M, N, K, L] = problem_shape; - Layout mCol_layout = make_layout(make_shape(M,N,L), args.dCol); + Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); if (args.ptr_col != nullptr) { return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); } return Status::kSuccess; } + else if constexpr (FinalReduction) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); - auto [M, N, K, L] = problem_shape; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; - size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); - tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); - - int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); - size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); - return zero_workspace(tile_counters, tile_counters_size, stream); + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); + } + else { + return Status::kSuccess; + } } CUTLASS_DEVICE bool @@ -1084,7 +1113,7 @@ struct Sm90ColReduction { auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); @@ -1096,33 +1125,25 @@ struct Sm90ColReduction { Array frg_I = convert_input(frg_input); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - if constexpr (VisitCheckOOB) { - if (elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_mn)) { - ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); - tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); - } - } - else { - if (elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_mn)) { - ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); - tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); - } + if (!VisitCheckOOB || elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_tCcCol)) { + ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); + tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); } } return frg_input; } - template + template CUTLASS_DEVICE void - reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { if (not is_last_iteration) { return; } auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; auto [m, n, k, l] = tile_coord_mnkl; constexpr bool ReferenceSrc = decltype(ref_src)::value; @@ -1134,7 +1155,7 @@ struct Sm90ColReduction { } // fully OOB CTA in partially OOB cluster - if (not elem_less(cCol(_0{},_0{}), residue_mn)) { + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { return; } @@ -1176,7 +1197,7 @@ struct Sm90ColReduction { if (is_reduced_lane) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tCrCol_flt); ++i) { - if (elem_less(tCcCol_flt(i), residue_mn)) { + if (elem_less(tCcCol_flt(i), residue_tCcCol)) { reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); } } @@ -1207,6 +1228,7 @@ struct Sm90ColReduction { static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), "smem reduction buffer not large enough, use a larger epilogue tile"); + sync_fn(); // Dump warp reduction to smem workspace Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); @@ -1281,7 +1303,7 @@ struct Sm90ColReduction { auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; using ReduceOutput = GmemReduceFn; using ConvertOutput = NumericConverter; @@ -1298,7 +1320,7 @@ struct Sm90ColReduction { for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { output = reduce_output(output, tRgBuf_nl(nl)); } - if (elem_less(cCol(m,_0{}), residue_mn)) { + if (elem_less(cCol(m,_0{}), residue_cCol)) { gCol_l(m,_0{},_0{}) = convert_output(output); } } @@ -1307,7 +1329,7 @@ struct Sm90ColReduction { else { CUTLASS_PRAGMA_NO_UNROLL for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { - bool do_store = elem_less(cCol(m,_0{}), residue_mn); + bool do_store = elem_less(cCol(m,_0{}), residue_cCol); CUTLASS_PRAGMA_NO_UNROLL for (int l = 0; l < size<3>(gBuf_nl); ++l) { Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); @@ -1378,7 +1400,7 @@ struct Sm90ColReduction { auto args_tuple = make_tuple( bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx); + args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks(std::move(args_tuple), params); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 1e07cc891f..843640127d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -119,14 +119,14 @@ template< class ProblemShapeMNKL, class TileShapeMNK, class TileCoordMNKL, - class ResidueMN, + class TiledMma, class EpilogueTile > struct ProducerLoadArgs { ProblemShapeMNKL problem_shape_mnkl; TileShapeMNK tile_shape_mnk; TileCoordMNKL tile_coord_mnkl; - ResidueMN residue_mn; + TiledMma tiled_mma; EpilogueTile epi_tile; int thread_idx; @@ -135,13 +135,13 @@ struct ProducerLoadArgs { ProblemShapeMNKL problem_shape_mnkl, TileShapeMNK tile_shape_mnk, TileCoordMNKL tile_coord_mnkl, - ResidueMN residue_mn, + TiledMma tiled_mma, EpilogueTile epi_tile, int thread_idx) : problem_shape_mnkl(problem_shape_mnkl), tile_shape_mnk(tile_shape_mnk), tile_coord_mnkl(tile_coord_mnkl), - residue_mn(residue_mn), + tiled_mma(tiled_mma), epi_tile(epi_tile), thread_idx(thread_idx) {} }; @@ -150,47 +150,55 @@ template< class ProblemShapeMNKL, class TileShapeMNK, class TileCoordMNKL, - class ResidueMN, + class TiledMma, class EpilogueTile, class TiledCopy, class CoordTensor, + class Residue, class ThrCoordTensor, + class ThrResidue, class ThrSrcTensor > struct ConsumerStoreArgs { ProblemShapeMNKL problem_shape_mnkl; TileShapeMNK tile_shape_mnk; TileCoordMNKL tile_coord_mnkl; - ResidueMN residue_mn; + TiledMma tiled_mma; EpilogueTile epi_tile; TiledCopy tiled_copy; - int thread_idx; CoordTensor cD; + Residue residue_cD; ThrCoordTensor tCcD; + ThrResidue residue_tCcD; ThrSrcTensor const& tCrC; + int thread_idx; CUTLASS_DEVICE ConsumerStoreArgs( ProblemShapeMNKL problem_shape_mnkl, TileShapeMNK tile_shape_mnk, TileCoordMNKL tile_coord_mnkl, - ResidueMN residue_mn, + TiledMma tiled_mma, EpilogueTile epi_tile, TiledCopy tiled_copy, - int thread_idx, CoordTensor cD, + Residue residue_cD, ThrCoordTensor tCcD, - ThrSrcTensor const& tCrC) + ThrResidue residue_tCcD, + ThrSrcTensor const& tCrC, + int thread_idx) : problem_shape_mnkl(problem_shape_mnkl), tile_shape_mnk(tile_shape_mnk), tile_coord_mnkl(tile_coord_mnkl), - residue_mn(residue_mn), + tiled_mma(tiled_mma), epi_tile(epi_tile), tiled_copy(tiled_copy), - thread_idx(thread_idx), cD(cD), + residue_cD(residue_cD), tCcD(tCcD), - tCrC(tCrC) {} + residue_tCcD(residue_tCcD), + tCrC(tCrC), + thread_idx(thread_idx) {} }; template @@ -220,6 +228,20 @@ struct Sm90VisitorImplBase { ); } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = cute::remove_cvref_t; + return Op::can_implement(problem_shape, op_args); + }, + [&] (auto&&... implementable) { + return (true && ... && implementable); + } + ); + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -409,7 +431,17 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { ); } - // Start of subtile store iteration. Smem broadcasts usually performed here. + // Start of subtile store iteration + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.begin_loop(epi_m, epi_n); + } + ); + } + + // Before visit callback. Smem broadcasts usually performed here. // Upon entry, all producer loads for this subtile are completed and visible. CUTLASS_DEVICE void previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { @@ -432,12 +464,15 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // It is each nodes reponsibility to assert that this buffer is sufficiently sized // and to ensure that this buffer is no longer needed upon callback exit // i.e. results are synchronized and no longer in the reduction buffer - template + // + // visit_results is a rmem tensor that contains the results of visit() for an entire + // on the current epilogue subtile + template CUTLASS_DEVICE void - reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { for_each(callbacks_tuple, [&] (auto& callbacks) { - callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration); + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); } ); } @@ -466,6 +501,16 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { ); } + // End of subtile store iteration + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.end_loop(epi_m, epi_n); + } + ); + } + // Exit of subtile store loop. Gmem reductions usually performed here. CUTLASS_DEVICE void end() { @@ -723,6 +768,12 @@ struct Sm90VisitorImplBase { }; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0); + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -793,6 +844,13 @@ struct Sm90VisitorImplBase { }; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1); + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -880,6 +938,14 @@ struct Sm90VisitorImplBase { }; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1) && + Op2::can_implement(problem_shape, args.op_2); + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -983,6 +1049,15 @@ struct Sm90VisitorImplBase { Op3::to_underlying_arguments(problem_shape, args.op_3, op_3_workspace) }; } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1) && + Op2::can_implement(problem_shape, args.op_2) && + Op3::can_implement(problem_shape, args.op_3); + } template static size_t diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index c37f2b9ab5..92407733f8 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -379,6 +379,9 @@ struct SiLu> { } }; +template +using ScaledSiLu = Scale>; + // Hardswish operator introduced by Howard et al. in the following paper // "Searching for MobileNetV3" (2019) // https://arxiv.org/pdf/1905.02244.pdf diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index 5e1c847d22..aad9b52389 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -219,10 +219,10 @@ class LinearCombinationClamp { /// Clamping constant value ElementCompute const kClampMax = - ElementCompute(platform::numeric_limits::max()); + ElementCompute(cutlass::platform::numeric_limits::max()); ElementCompute const kClampMin = - ElementCompute(platform::numeric_limits::lowest()); + ElementCompute(cutlass::platform::numeric_limits::lowest()); intermediate = max_accumulator(intermediate, kClampMin); intermediate = min_accumulator(intermediate, kClampMax); @@ -260,10 +260,10 @@ class LinearCombinationClamp { /// Clamping constant value ElementCompute const kClampMax = - ElementCompute(platform::numeric_limits::max()); + ElementCompute(cutlass::platform::numeric_limits::max()); ElementCompute const kClampMin = - ElementCompute(platform::numeric_limits::lowest()); + ElementCompute(cutlass::platform::numeric_limits::lowest()); intermediate = max_accumulator(intermediate, kClampMin); intermediate = min_accumulator(intermediate, kClampMax); @@ -299,7 +299,7 @@ class LinearCombinationClamp { using ElementCompute = float; static_assert( - platform::numeric_limits::is_integer, + cutlass::platform::numeric_limits::is_integer, "This elementwise op expects the output to be int."); static int const kCount = Count; @@ -499,7 +499,7 @@ class FastLinearCombinationClamp { using ElementCompute = float; static_assert( - platform::numeric_limits::is_integer, + cutlass::platform::numeric_limits::is_integer, "This elementwise op expects the output to be int."); static int const kCount = Count; diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 07ebdec93d..bbdc498622 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -447,7 +447,7 @@ class LinearCombinationRelu { // Compute threshold optionally intermediate = relu(threshold_, intermediate); - if (platform::numeric_limits::is_integer) { + if (cutlass::platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; @@ -492,7 +492,7 @@ class LinearCombinationRelu { // Compute threshold optionally intermediate = relu(threshold_, intermediate); - if (platform::numeric_limits::is_integer) { + if (cutlass::platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; @@ -540,7 +540,7 @@ class LinearCombinationRelu { // Compute threshold optionally intermediate = relu(threshold_, intermediate); - if (platform::numeric_limits::is_integer) { + if (cutlass::platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; diff --git a/include/cutlass/epilogue/thread/linear_combination_relu0.h b/include/cutlass/epilogue/thread/linear_combination_relu0.h index 798b8228fb..76ad59244d 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu0.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu0.h @@ -418,7 +418,7 @@ class LinearCombinationRelu0 { // Compute threshold optionally intermediate = relu(intermediate); - if (platform::numeric_limits::is_integer) { + if (cutlass::platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; @@ -463,7 +463,7 @@ class LinearCombinationRelu0 { // Compute threshold optionally intermediate = relu(intermediate); - if (platform::numeric_limits::is_integer) { + if (cutlass::platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; @@ -511,7 +511,7 @@ class LinearCombinationRelu0 { // Compute threshold optionally intermediate = relu(intermediate); - if (platform::numeric_limits::is_integer) { + if (cutlass::platform::numeric_limits::is_integer) { // Convert floats back to INT FragmentAccumulator scaled_accumulator; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h index d21382b41f..16e045e1e3 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -69,10 +69,17 @@ template < typename OutputOp, int ElementsPerAccess, bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, + int Rank = 4 > struct DefaultEpilogueWithBroadcastSimt { + static conv::StrideSupport const kStrideSupport = StrideSupport; + static int const kRank = Rank; + + static bool const UseCUDAStore = platform::is_same::value; + /// Use defaults related to the existing epilogue using Base = DefaultEpilogueSimt< Shape, @@ -81,16 +88,30 @@ struct DefaultEpilogueWithBroadcastSimt { ElementsPerAccess >; - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< typename Base::OutputTileThreadMap, ElementOutput, ScatterD, - PermuteDLayout + PermuteDLayout, + UseCUDAStore + >; + + using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv< + typename Base::OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore, + kRank >; + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = typename platform::conditional::type; + // // Additional tensor tile iterator - stores t = Elementwise(z) // @@ -98,7 +119,6 @@ struct DefaultEpilogueWithBroadcastSimt { typename Base::OutputTileThreadMap, ElementTensor >; - /// Define the epilogue using Epilogue = EpilogueWithBroadcast< Shape, diff --git a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h index 0f417485e2..7300882730 100644 --- a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h +++ b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h @@ -107,6 +107,32 @@ struct ConvOutputIteratorParameter +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNHWC; + using OutputIteratorLayout = layout::TensorNHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kDeconv; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + template< typename TensorRef_, ///! Input tensor to epilogue output iterator typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem @@ -133,6 +159,32 @@ struct ConvOutputIteratorParameter +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNDHWC; + using OutputIteratorLayout = layout::TensorNDHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kDeconv; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + template < int InterleavedK, typename TensorRef_, diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 14a854476e..9943ea2563 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -135,14 +135,14 @@ class PredicatedTileIterator { CUTLASS_HOST_DEVICE Params(Layout const &layout, // Not needed. Added to be compatible with strided conv epilogue. - conv::Conv2dProblemSize const &problem_size): + cutlass::Tensor4DCoord const &tensor_extent): Params(layout) { } CUTLASS_HOST_DEVICE Params(Layout const &layout, // Not needed. Added to be compatible with strided conv epilogue. - conv::Conv3dProblemSize const &problem_size): + cutlass::Tensor5DCoord const &tensor_extent): Params(layout) { } @@ -1141,7 +1141,7 @@ class InterleavedConvPredicatedTileIterator { CUTLASS_HOST_DEVICE Params(Layout const &layout, // Not needed. Added to be compatible with strided conv epilogue. - conv::Conv2dProblemSize const &problem_size): + cutlass::Tensor4DCoord const &tensor_extent): Params(layout) { } diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h index c3c722bc4d..a59437c091 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h @@ -138,13 +138,13 @@ class PredicatedTileIteratorConv { Params() { } CUTLASS_HOST_DEVICE - Params(Layout const &layout, conv::Conv2dProblemSize const &problem_size): + Params(Layout const &layout, cutlass::Tensor4DCoord const &tensor_extent): PredicatedTileIteratorParams( layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, make_OutputTileThreadMapDesc() ) { - divmod[0] = FastDivmod(problem_size.Q); - divmod[1] = FastDivmod(problem_size.P); + divmod[0] = FastDivmod(tensor_extent[2] /* Q for Fprop & W for Deconv*/); + divmod[1] = FastDivmod(tensor_extent[1] /* P for Fprop & H for Deconv*/); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kStrideRank; ++i) { @@ -153,14 +153,14 @@ class PredicatedTileIteratorConv { } CUTLASS_HOST_DEVICE - Params(Layout const &layout, conv::Conv3dProblemSize const &problem_size): + Params(Layout const &layout, cutlass::Tensor5DCoord const &tensor_extent): PredicatedTileIteratorParams( layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, make_OutputTileThreadMapDesc() ) { - divmod[0] = FastDivmod(problem_size.Q); - divmod[1] = FastDivmod(problem_size.P); - divmod[2] = FastDivmod(problem_size.Z); + divmod[0] = FastDivmod(tensor_extent[3] /* Q for Fprop & W for Deconv*/); + divmod[1] = FastDivmod(tensor_extent[2] /* P for Fprop & H for Deconv*/); + divmod[2] = FastDivmod(tensor_extent[1] /* Z for Fprop & D for Deconv*/); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kStrideRank; ++i) { diff --git a/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h b/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h index 475fb73bf8..a69f0fd25a 100644 --- a/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h +++ b/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h @@ -164,6 +164,107 @@ class FragmentIteratorTensorOp +class FragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::ColumnMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + 4 * Policy::OperatorCount::kRow * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + using TileIterations = typename Policy::TileIterations; + static int const kIterationsPerTile = kIterations / TileIterations::kCount; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp(AccumulatorTile const &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Policy::kAccumulatorRowStride; ++i) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < (Policy::OperatorCount::kRow * 2); ++m) { + + int accumulator_access_offset = + index * Policy::kAccumulatorColumnStride + m * Policy::kAccumulatorRowStride / Policy::kElementsPerAccess + i; + + frag_ptr[m + i * Policy::OperatorCount::kRow * 2] = accumulators_[accumulator_access_offset]; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Dedicated to interleaved layout template < /// shape of the warp-level GEMM tile diff --git a/include/cutlass/epilogue/warp/tensor_op_policy.h b/include/cutlass/epilogue/warp/tensor_op_policy.h index e4175f4819..b3f3a4f59c 100644 --- a/include/cutlass/epilogue/warp/tensor_op_policy.h +++ b/include/cutlass/epilogue/warp/tensor_op_policy.h @@ -98,6 +98,47 @@ struct TensorOpPolicy { //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for row-major +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape ///< matrix multiply operation shape (concept: gemm::GemmShape) +> +struct TensorOpPolicy { + + /// Number of operations + using OperatorCount = MatrixShape< + (WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM, + (WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN + >; + + // + // Hard-coded constants regarding Tensor Operations + // + + static int const kElementsPerAccess = 1; + static int const kColumnsPerIteration = 8; + static bool const kDivisible = + !(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN); + + // + // Derived quantities + // + + // Number of 'externally visible' iterations per actual instruction + static int const kIterationsPerInstruction = OperatorShape::kN / kColumnsPerIteration; + + // Number of externally visible iterations + static int const kIterations = OperatorCount::kColumn * kIterationsPerInstruction; + + using TileIterations = MatrixShape; + + // Hard code for 16x8 + static int const kAccumulatorRowStride = 2; + static int const kAccumulatorColumnStride = 4 * OperatorCount::kRow; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for column-major-interleaved template < typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 84fb06def2..fa3873c5e7 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -65,13 +65,13 @@ CUTLASS_HOST_DEVICE void swap(T &lhs, T &rhs) { * Static math utilities ******************************************************************************/ -/// Mixed precision dot product +/// Mixed precision dot product template CUTLASS_HOST_DEVICE LongIndex dot( - Coord const &coord, - Coord const &stride, + Coord const &coord, + Coord const &stride, LongIndex acc = LongIndex()) { - + CUTLASS_PRAGMA_UNROLL for (int n = 0; n < N; ++n) { acc += LongIndex(coord[n]) * stride[n]; @@ -312,18 +312,19 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, /// /// FastDivmod divmod(divisor); /// -/// divmod(quotient, remainder, dividend); +/// divmod(quotient, remainder, dividend); /// /// // quotient = (dividend / divisor) /// // remainder = (dividend % divisor) /// struct FastDivmod { + using value_div_type = int; + using value_mod_type = int64_t; + int32_t divisor = 1; + uint32_t multiplier = 0u; + uint32_t shift_right = 0u; - int divisor; - unsigned int multiplier; - unsigned int shift_right; - - /// Find quotient and remainder using device-side intrinsics + // Find quotient and remainder using device-side intrinsics CUTLASS_HOST_DEVICE void fast_divmod(int& quotient, int& remainder, int dividend) const { @@ -357,21 +358,17 @@ struct FastDivmod { /// /// This precomputes some values based on the divisor and is computationally expensive. - CUTLASS_HOST_DEVICE - FastDivmod(): divisor(0), multiplier(0), shift_right(0) { } + constexpr FastDivmod() = default; CUTLASS_HOST_DEVICE - FastDivmod(int divisor): divisor(divisor) { - + FastDivmod(int divisor_): divisor(divisor_) { + assert(divisor_ >= 0); if (divisor != 1) { unsigned int p = 31 + find_log2(divisor); unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor)); multiplier = m; shift_right = p - 32; - } else { - multiplier = 0; - shift_right = 0; } } @@ -429,7 +426,6 @@ struct FastDivmod { operator int() const { return divisor; } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Object to encapsulate the fast division+modulus operation for 64b integer division. @@ -445,7 +441,7 @@ struct FastDivmod { /// /// FastDivmodU64 divmod(divisor); /// -/// divmod(quotient, remainder, dividend); +/// divmod(quotient, remainder, dividend); /// /// // quotient = (dividend / divisor) /// // remainder = (dividend % divisor) @@ -517,7 +513,7 @@ struct FastDivmodU64 { /// Computes the remainder given a computed quotient and dividend CUTLASS_HOST_DEVICE uint64_t modulus(uint64_t quotient, uint64_t dividend) const { - return uint32_t(dividend - quotient * divisor); + return dividend - quotient * divisor; } /// Returns the quotient of floor(dividend / divisor) and computes the remainder @@ -697,8 +693,8 @@ template CUTLASS_HOST_DEVICE int64_t OffsetBytes(int64_t index) { static_assert( - (sizeof_bits::value >= 8 && !(sizeof_bits::value % 8)) || - (sizeof_bits::value < 8 && !(8 % sizeof_bits::value)), + (sizeof_bits::value >= 8 && !(sizeof_bits::value % 8)) || + (sizeof_bits::value < 8 && !(8 % sizeof_bits::value)), "Size of numeric type in bits must either be divisible by 8 bits, or 8 bits must be divisible by the size."); if (sizeof_bits::value >= 8) { @@ -931,7 +927,7 @@ double fast_tanh(double x) { CUTLASS_HOST_DEVICE half_t fast_tanh(half_t x) { #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) - + asm volatile ( "tanh.approx.f16 %0, %1;" : "=h"(x.raw()) : "h"(x.raw())); return x; @@ -1010,13 +1006,13 @@ template struct fast_tanh_op> { CUTLASS_DEVICE Array operator()(Array const &rhs) const { - + Array result; // use x2 specialization uint32_t const *in = reinterpret_cast(&rhs); uint32_t *out = reinterpret_cast(&result); - + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 2; ++i) { asm volatile ("tanh.approx.f16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i])); @@ -1026,7 +1022,7 @@ struct fast_tanh_op> { if (N % 2) { uint16_t const *in = reinterpret_cast(&rhs); uint16_t *out = reinterpret_cast(&result); - asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1])); + asm volatile ("tanh.approx.f16 %0, %1;" : "=h"(out[N - 1]) : "h"(in[N - 1])); } return result; @@ -1038,7 +1034,7 @@ template struct fast_tanh_op> { CUTLASS_HOST_DEVICE Array operator()(Array const &rhs) const { - + fast_tanh_op fast_op; Array y; diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index a2d062a04b..5709ec9fed 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -1153,6 +1153,7 @@ struct numeric_limits : } // namespace std #endif +namespace cutlass { namespace platform { /// Numeric limits common to all float8 types @@ -1208,7 +1209,7 @@ struct float8_base_numeric_limits { static F8Type denorm_min() { return F8Type::bitcast(0x01); } }; -/// std::numeric_limits +/// Forward Declaration template struct numeric_limits; @@ -1240,6 +1241,8 @@ struct numeric_limits : } // namespace platform +} // namespace cutlass + /////////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/floating_point_nvrtc.h b/include/cutlass/floating_point_nvrtc.h index 67b35f8e8d..fdbd80fcdd 100644 --- a/include/cutlass/floating_point_nvrtc.h +++ b/include/cutlass/floating_point_nvrtc.h @@ -58,6 +58,39 @@ enum { FP_NORMAL }; +CUTLASS_HOST_DEVICE +int fpclassify(float const& f) { + + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(f); + #else + std::memcpy(&s, &f, sizeof(s)); + #endif + + uint32_t exp = s & 0x7f800000; + uint32_t mantissa = s & 0x007fffff; + + if (exp == 0x7f800000) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 964d2ff35f..f1444b31e3 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -37,6 +37,11 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" + +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#endif #include @@ -262,7 +267,36 @@ struct reciprocal_approximate { CUTLASS_HOST_DEVICE float operator()(float lhs) const { float ret; + #if defined(__CUDA_ARCH__) + asm volatile ("rcp.approx.f32 %0, %1;\n" : "=f"(ret) : "f"(lhs)); + #else + ret = 1.0f / lhs; + #endif + return ret; + } +}; + +/// reciprocal_approximate with ftz +template +struct reciprocal_approximate_ftz : reciprocal_approximate +{}; + +template <> +struct reciprocal_approximate_ftz { + CUTLASS_HOST_DEVICE + float operator()(float lhs) const { + float ret; + #if defined(__CUDA_ARCH__) + asm volatile ("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(ret) : "f"(lhs)); + #else + if (std::fpclassify(lhs) == FP_SUBNORMAL) { + lhs = 0.0f; + } ret = 1.0f / lhs; + if (std::fpclassify(ret) == FP_SUBNORMAL) { + ret = 0.0f; + } + #endif return ret; } }; @@ -336,7 +370,7 @@ struct maximum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { #if defined(__CUDA_ARCH__) - return lhs > rhs or isnan(lhs) ? lhs : rhs; + return lhs > rhs or ::isnan(lhs) ? lhs : rhs; #else return lhs > rhs or std::isnan(lhs) ? lhs : rhs; #endif @@ -359,7 +393,7 @@ struct maximum { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); #elif defined(__CUDA_ARCH__) - res = lhs > rhs or isnan(lhs) ? lhs : rhs; + res = lhs > rhs or ::isnan(lhs) ? lhs : rhs; #else res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; #endif @@ -394,7 +428,7 @@ struct minimum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { #if defined(__CUDA_ARCH__) - return lhs < rhs or isnan(lhs) ? lhs : rhs; + return lhs < rhs or ::isnan(lhs) ? lhs : rhs; #else return lhs < rhs or std::isnan(lhs) ? lhs : rhs; #endif @@ -409,6 +443,10 @@ struct minimum { } }; +template +struct minimum_with_nan_propagation : minimum +{}; + template struct maximum_absolute_value { CUTLASS_HOST_DEVICE @@ -469,6 +507,83 @@ struct multiply_add_relu0 { } }; +/// Guarded-multiply-add +template +struct guarded_multiply_add { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + if (isnan(a) || isnan(b)) { + return C(0); + } + return C(a) * C(b) + c; + } +}; + +/// Guarded-multiply-add +template <> +struct guarded_multiply_add { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &a, half_t const &b, half_t const &c) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + half_t result; + asm ("fma.rn.oob.f16 %0, %1, %2, %3;\n" + : "=h"(*reinterpret_cast(&result)) + : "h"(*reinterpret_cast(&a)), "h"(*reinterpret_cast(&b)), "h"(*reinterpret_cast(&c))); + return result; +#else + if (isnan(a) || isnan(b)) { + return half_t(0); + } + return a * b + c; +#endif + } +}; + +/// Guarded-multiply-add-relu0 +template +struct guarded_multiply_add_relu0 { + CUTLASS_HOST_DEVICE + C operator()(A const &a, B const &b, C const &c) const { + if ( +#if defined(__CUDA_ARCH__) + ::isnan(a) || ::isnan(b) +#else + std::isnan(a) || std::isnan(b) +#endif + ) { + return C(0); + } + maximum mx; + return mx(C(a) * C(b) + c, C(0)); + } +}; + +template <> +struct guarded_multiply_add_relu0 { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &a, half_t const &b, half_t const &c) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + half_t result; + asm ("fma.rn.oob.relu.f16 %0, %1, %2, %3;\n" + : "=h"(*reinterpret_cast(&result)) + : "h"(*reinterpret_cast(&a)), "h"(*reinterpret_cast(&b)), "h"(*reinterpret_cast(&c))); + return result; +#else + if ( +#if defined(__CUDA_ARCH__) + ::isnan(a) || ::isnan(b) +#else + std::isnan(a) || std::isnan(b) +#endif + ) { + return half_t(0); + } + maximum mx; + return mx(a * b + c, half_t(0)); +#endif + } +}; + /// Fused multiply-add template struct and_add { @@ -488,11 +603,99 @@ struct xor_add { } }; +namespace detail { + +// Whether namespace-unqualified conj(t) for t of type T is +// well-formed. This says whether the compiler can find +// namespace-unqualified conj(T) via argument-dependent lookup. +// If so, then CUTLASS assumes that conj(t) returns +// the complex conjugate of t. +template +struct has_unqualified_conj : cutlass::platform::false_type +{}; + +template +struct has_unqualified_conj< + T, + decltype(conj(cutlass::platform::declval()), void()) + > : cutlass::platform::true_type +{}; + +template +constexpr bool has_unqualified_conj_v = has_unqualified_conj::value; + +} // namespace detail + +// forward declaration (needed for conjugate below) +template +CUTLASS_HOST_DEVICE T conj(T const& z); + +namespace detail { + +// Whether cutlass::conj(t) for t of type T is well-formed. +// If so, then CUTLASS assumes that cutlass::conj(t) +// returns the complex conjugate of t. +template +struct has_cutlass_conj : cutlass::platform::false_type +{}; + +template +struct has_cutlass_conj< + T, + decltype(cutlass::conj(cutlass::platform::declval()), void()) + > : cutlass::platform::true_type +{}; + +template +constexpr bool has_cutlass_conj_v = has_cutlass_conj::value; + +} // namespace detail + +// Return the complex conjugate of the input. +// +// If the struct hasn't already been specialized for type T, then +// +// 1. for arithmetic types, return z; +// +// 2. for types where either (namespace-unqualified) conj(z) or +// cutlass::conj(z) is well formed, declare "using cutlass::conj;" +// and return conj(z); and +// +// 3. for everything else, return z. +// +// Regarding (1), the C++ Standard Library makes std::conj always +// return std::complex, even for (noncomplex) arithmetic types. +// cutlass::conj(T t) needs to return type T. This follows the +// convention of linear algebra software like the BLAS, where +// "conjugate transpose" means the same thing as "transpose" for a +// matrix of noncomplex numbers. +// +// Case (2) covers std::complex, cuda::std::complex, and non-Standard +// (including user-defined) complex number types (for which "conj(z)" +// is findable via argument-dependent lookup). cutlass::conj has a +// totally generic overload, but a more type-specific overload in any +// namespace will take precedence. +// +// Case (3) covers non-Standard non-complex number types. +// +// Users should not generally need to specialize this struct for their +// own custom complex or noncomplex types. The idiomatic way to +// identify a type T as "complex" is to make namespace-unqualified +// calls to conj(T) findable via argument-dependent lookup. template struct conjugate { CUTLASS_HOST_DEVICE - T operator()(T const &a) const { - return a; + T operator()(T const& z) const { + if constexpr (cutlass::platform::is_arithmetic_v) { + return z; + } + else if constexpr (detail::has_unqualified_conj_v || detail::has_cutlass_conj_v) { + using cutlass::conj; + return conj(z); + } + else { + return z; + } } }; @@ -649,7 +852,13 @@ struct atomic_maximum { CUTLASS_DEVICE float operator()(float *ptr, float value) const { #if defined(__CUDA_ARCH__) - return !signbit(value) ? + // In device code, make sure that we do NOT try to use + // std::signbit, as that won't work if building with NVRTC. + // Instead, prefix "::" to call signbit from the global namespace, + // which CUDA guarantees to work in device code without including + // any headers. + // + return ! ::signbit(value) ? __int_as_float(atomicMax((int*)ptr, __float_as_int(value))) : __uint_as_float(atomicMin((unsigned int*)ptr, __float_as_uint(value))); #else diff --git a/include/cutlass/gemm/collective/builders/sm90_common.inl b/include/cutlass/gemm/collective/builders/sm90_common.inl index 14ae739891..298793e886 100644 --- a/include/cutlass/gemm/collective/builders/sm90_common.inl +++ b/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -138,7 +138,7 @@ make_cp_async_gmem_tiled_copy() { if constexpr (cutlass::gemm::detail::is_k_major()) { // K major thread layout for K major gmem - constexpr int threads_major = TileSizeK / Alignment; + constexpr int threads_major = (ThreadCount >= TileSizeK / Alignment) ? (TileSizeK / Alignment) : ThreadCount; constexpr int threads_minor = ThreadCount / threads_major; static_assert(threads_major > 0); static_assert(ThreadCount % threads_major == 0); @@ -151,7 +151,7 @@ make_cp_async_gmem_tiled_copy() { } else if constexpr (cutlass::gemm::detail::is_mn_major()) { // MN major thread layout for MN major gmem - constexpr int threads_major = TileSizeMN / Alignment; + constexpr int threads_major = (ThreadCount >= TileSizeMN / Alignment) ? (TileSizeMN / Alignment) : ThreadCount; constexpr int threads_minor = ThreadCount / threads_major; static_assert(threads_major > 0); static_assert(ThreadCount % threads_major == 0); diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index d86fda53d5..532bfecfb1 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -31,62 +31,11 @@ #pragma once ///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_mma_decl.hpp" #include "cutlass/gemm/collective/collective_mma.hpp" -namespace cutlass::gemm::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// Used to specify stage counts or dispatch to automatic computation of stage count -template -struct StageCount { - static constexpr int value = num_stages; - - StageCount() = default; - explicit StageCount(cute::Int) {} -}; - -template -struct StageCountAutoCarveout { - static constexpr int bytes = carveout_bytes; - - StageCountAutoCarveout() = default; - explicit StageCountAutoCarveout(cute::Int) {} -}; - -using StageCountAuto = StageCountAutoCarveout<0>; - -// Used to automatically let the builder pick the kernel schedule. -// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp -struct KernelScheduleAuto {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class ArchTag, - class OpClass, - class ElementA, - class GmemLayoutA, - int AlignmentA, - class ElementB, - class GmemLayoutB, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - class KernelScheduleType, - class Enable = void -> -struct CollectiveBuilder { - static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - ///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder_decl.hpp" #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder_decl.hpp b/include/cutlass/gemm/collective/collective_builder_decl.hpp new file mode 100644 index 0000000000..c0570d37a9 --- /dev/null +++ b/include/cutlass/gemm/collective/collective_builder_decl.hpp @@ -0,0 +1,88 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify stage counts or dispatch to automatic computation of stage count +template +struct StageCount { + static constexpr int value = num_stages; + + StageCount() = default; + explicit StageCount(cute::Int) {} +}; + +template +struct StageCountAutoCarveout { + static constexpr int bytes = carveout_bytes; + + StageCountAutoCarveout() = default; + explicit StageCountAutoCarveout(cute::Int) {} +}; + +using StageCountAuto = StageCountAutoCarveout<0>; + +// Used to automatically let the builder pick the kernel schedule. +// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct KernelScheduleAuto final {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 91c801762a..7bcc075782 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -30,38 +30,8 @@ **************************************************************************************************/ #pragma once -#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class DispatchPolicy, - class TileShape, - class ElementA, - class StrideA, - class ElementB, - class StrideB, - class TiledMma, - class GmemTiledCopyA, - class SmemLayoutAtomA, - class SmemCopyAtomA, - class TransformA, - class GmemTiledCopyB, - class SmemLayoutAtomB, - class SmemCopyAtomB, - class TransformB -> -struct CollectiveMma { - static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma_decl.hpp b/include/cutlass/gemm/collective/collective_mma_decl.hpp new file mode 100644 index 0000000000..feef54962c --- /dev/null +++ b/include/cutlass/gemm/collective/collective_mma_decl.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class TiledMma, + class GmemTiledCopyA, + class SmemLayoutAtomA, + class SmemCopyAtomA, + class TransformA, + class GmemTiledCopyB, + class SmemLayoutAtomB, + class SmemCopyAtomB, + class TransformB +> +struct CollectiveMma { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + diff --git a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp index 57773d79df..3d9e03edff 100644 --- a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp +++ b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -38,10 +38,11 @@ #include "cute/algorithm/gemm.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/tensor_predicate.hpp" +#include "cutlass/gemm/collective/collective_mma_decl.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// - + namespace cutlass::gemm::collective { using namespace cute; @@ -163,7 +164,7 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, int thread_idx, - char *smem_buf) + char *smem_buf) { using namespace cute; @@ -252,9 +253,9 @@ struct CollectiveMma< while (k_tile_count > -1) { // Pipeline the outer products with a static for loop - for_each(make_int_sequence{}, [&] (auto k_block) + for_each(make_int_sequence{}, [&] (auto k_block) { - if (k_block == K_BLOCK_MAX - 1) + if (k_block == K_BLOCK_MAX - 1) { __syncthreads(); @@ -268,7 +269,7 @@ struct CollectiveMma< int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); - if (k_block == 0) + if (k_block == 0) { // Copy gmem to rmem copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tArA); @@ -406,7 +407,7 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, int thread_idx, - char *smem_buf) + char *smem_buf) { using namespace cute; @@ -549,9 +550,9 @@ struct CollectiveMma< while (k_tile_count > -1) { // Pipeline the outer products with a static for loop - for_each(make_int_sequence{}, [&] (auto k_block) + for_each(make_int_sequence{}, [&] (auto k_block) { - if (k_block == K_BLOCK_MAX - 1) + if (k_block == K_BLOCK_MAX - 1) { __syncthreads(); @@ -565,7 +566,7 @@ struct CollectiveMma< int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next)); copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next)); - if (k_block == 0) + if (k_block == 0) { if (k_tile_count <= 0) { clear(tApA); diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index a0038cddde..a129b56e3c 100644 --- a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -290,7 +290,7 @@ struct CollectiveMma< } CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) + while (k_tile_count > -(DispatchPolicy::Stages-1)) { // Pipeline the outer products with a static for loop. // @@ -318,6 +318,9 @@ struct CollectiveMma< copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); cp_async_fence(); + + // Advance the tile + --k_tile_count; if (k_tile_count > 0) { ++k_tile_iter; } // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) @@ -344,6 +347,7 @@ struct CollectiveMma< template < int Stages, + class ClusterShape_, class TileShape_, class ElementA_, class StrideA_, @@ -360,7 +364,9 @@ template < class TransformB_ > struct CollectiveMma< - MainloopSm80CpAsync, + MainloopSm80CpAsync< + Stages, + ClusterShape_>, TileShape_, ElementA_, StrideA_, @@ -380,7 +386,9 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm80CpAsync; + using DispatchPolicy = MainloopSm80CpAsync< + Stages, + ClusterShape_>; using TileShape = TileShape_; // Follow the change in TestSmall: TileShape switch to CtaShape // In legacy arch, it should be same @@ -490,8 +498,8 @@ struct CollectiveMma< // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) // This aligns the tensor with BLK_K for all but the 0th k_tile - gA.data() = &gA(0, get<2>(residue_mnk), 0); - gB.data() = &gB(0, get<2>(residue_mnk), 0); + gA = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA); + gB = cute::domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB); // Partition the copying of A and B tiles across the threads GmemTiledCopyA gmem_tiled_copy_A; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 16255d70fc..4f2837d17e 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -35,6 +35,7 @@ #include "cutlass/numeric_types.h" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" #include "cute/arch/cluster_sm90.hpp" #include "cute/arch/copy_sm90.hpp" @@ -94,10 +95,10 @@ struct CollectiveMma< using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; - using UnderlyingStrideA = cute::remove_pointer_t; + using InternalStrideA = cute::remove_pointer_t; using ElementB = ElementB_; using StrideB = StrideB_; - using UnderlyingStrideB = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; @@ -152,14 +153,14 @@ struct CollectiveMma< // Assumption: StrideA is congruent with Problem_MK using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(UnderlyingStrideA{}, int32_t(0)), UnderlyingStrideA{}), + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(UnderlyingStrideB{}, int32_t(0)), UnderlyingStrideB{}), + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any @@ -182,7 +183,7 @@ struct CollectiveMma< using TensorMapStorage = typename SharedStorage::TensorMapStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; // Host side kernel arguments struct Arguments { @@ -196,6 +197,7 @@ struct CollectiveMma< struct Params { TMA_A tma_load_a; TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; void* tensormaps; InternalElementA const** ptr_A; StrideA dA; @@ -222,14 +224,14 @@ struct CollectiveMma< // Batches/Groups are managed by using appropriate pointers to input matrices const uint32_t mock_L = 1; InternalElementA const* ptr_A_first_batch = reinterpret_cast(args.ptr_A); - InternalElementB const* ptr_B_first_batch = reinterpret_cast(args.ptr_B); + InternalElementB const* ptr_B_first_batch = reinterpret_cast(args.ptr_B); - UnderlyingStrideA stride_a; - UnderlyingStrideB stride_b; + InternalStrideA stride_a; + InternalStrideB stride_b; if constexpr (IsGroupedGemmKernel) { // Strides for Grouped Gemm will be replaced prior to the first access regardless. - stride_a = UnderlyingStrideA{}; - stride_b = UnderlyingStrideB{}; + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; } else { // Tensor shapes for Ptr-Array are initialized correctly only here. @@ -261,6 +263,7 @@ struct CollectiveMma< return { tma_load_a, tma_load_b, + TmaTransactionBytes, tensormaps, reinterpret_cast(args.ptr_A), args.dA, @@ -280,12 +283,12 @@ struct CollectiveMma< template static cutlass::Status - initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { return cutlass::Status::kSuccess; } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape problem_shapes, Arguments const& args) { @@ -299,8 +302,8 @@ struct CollectiveMma< for (int i = 0; i < problem_shapes.groups(); i++) { auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); auto [M,N,K,L] = problem_shape_MNKL; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), UnderlyingStrideA{}); - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), UnderlyingStrideB{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); } } @@ -480,8 +483,22 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -508,12 +525,9 @@ struct CollectiveMma< // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - + assert(k_tile_count >= 1); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - warpgroup_fence_operand(accum); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); @@ -534,6 +548,22 @@ struct CollectiveMma< ++smem_pipe_read; } + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + warpgroup_fence_operand(accum); // Mainloop GMMAs k_tile_count -= prologue_mma_count; @@ -552,13 +582,7 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); warpgroup_fence_operand(accum); warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) warpgroup_commit_batch(); /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed @@ -730,11 +754,12 @@ struct CollectiveMma< tensormaps_cp_fence_release ( TensorMapStorage& shared_tensormap, cute::tuple const& input_tensormaps) { - // Entire warp must do this (ie its aligned) + // Entire warp must do this (i.e. it's aligned) tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); } + // The entire warp must call this function collectively (that is, the instructions are aligned) template CUTLASS_DEVICE void diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp index bca824aea7..4b291db358 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -245,7 +245,7 @@ struct CollectiveMma< } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -445,14 +445,27 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); // Allocate fragments and descriptors - Tensor tCsA = thread_mma.partition_A(sA); - Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(gmma_sB); // (MMA,MMA_N,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = mma_thread_slice.partition_A(sA); + Tensor tCrA = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(gmma_sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) // // Copy Atom A retiling diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp index 25a3ca4b12..90e7acd38c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -172,7 +172,7 @@ struct CollectiveMma< } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -361,8 +361,22 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -389,13 +403,10 @@ struct CollectiveMma< // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - + assert(k_tile_count >= 1); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - warpgroup_fence_operand(accum); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { - + { // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -417,6 +428,26 @@ struct CollectiveMma< ++smem_pipe_read; } + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) { + + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + warpgroup_arrive(); + + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + warpgroup_fence_operand(accum); // Mainloop GMMAs @@ -433,14 +464,8 @@ struct CollectiveMma< warpgroup_fence_operand(accum); warpgroup_arrive(); - - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); warpgroup_commit_batch(); /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index dbac39a1d5..43e05afa07 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -110,7 +110,8 @@ struct CollectiveMma< using SmemCopyAtomB = SmemCopyAtomB_; using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8, Fp8 WGMMA) + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only + // (e.g. tf32, Fp32, Int8, Fp8 WGMMA) static constexpr bool IsLayoutAkBmn = cute::is_same_v, layout::RowMajor> && cute::is_same_v, layout::RowMajor>; @@ -235,21 +236,24 @@ struct CollectiveMma< // Device side kernel params struct Params { // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy( + using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{})); // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy( + using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{})); TMA_A tma_load_a; TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; }; // @@ -290,26 +294,33 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), dA)); Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), dB)); - typename Params::TMA_A tma_load_a = make_tma_copy( + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy( + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + return { tma_load_a, - tma_load_b + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk }; } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -330,9 +341,11 @@ struct CollectiveMma< } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr uint32_t TmaTransactionBytes = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) + + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) ; + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -375,7 +388,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, BlockCoord const& blk_coord, @@ -422,14 +435,14 @@ struct CollectiveMma< // Issue TmaLoads // Maps the tile -> block, value if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } } if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); } @@ -518,14 +531,27 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); // Allocate fragments and descriptors - Tensor tCsA = thread_mma.partition_A(sA); - Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(gmma_sB_position_dependent); // (MMA,MMA_N,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = mma_thread_slice.partition_A(sA); + Tensor tCrA = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(gmma_sB_position_dependent); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) // // Copy Atom A retiling diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 4613f7bf65..1f679c88ca 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -39,6 +39,7 @@ #include "cutlass/detail/layout.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/trace.h" #include "cutlass/detail/collective.hpp" @@ -50,9 +51,6 @@ #include "cute/algorithm/gemm.hpp" #include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" -#include "cutlass/detail/collective.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -128,7 +126,8 @@ struct CollectiveMma< using TileShape = TileShape_; static_assert(cute::is_tuple::value ^ cute::is_tuple::value, - "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in [] are optional."); + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; @@ -144,7 +143,8 @@ struct CollectiveMma< // These are always MN major using StrideScale = cute::Stride, int64_t, int64_t>; // For cases where we can't have a void scale, we can use this to allow the code to compile when the scale is void. - using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + using NonVoidStrideScale = cute::conditional_t< + cute::is_void_v, cute::Stride<_1, int64_t, int64_t>, StrideScale>; static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || (!IsATransformed && cutlass::gemm::detail::is_k_major()), @@ -303,11 +303,8 @@ struct CollectiveMma< // These methods use some the public members of the class. For that reason, we define them after the public section. static constexpr uint32_t - compute_tma_transaction_bytes() { - constexpr uint32_t a_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); - constexpr uint32_t b_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); - - constexpr uint32_t baseline_bytes = a_bytes + b_bytes; + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return baseline_bytes; @@ -333,6 +330,11 @@ struct CollectiveMma< } } + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + public: static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); @@ -421,6 +423,9 @@ struct CollectiveMma< TMA_Zero tma_load_zero; int64_t scale_k; int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; }; // @@ -478,7 +483,7 @@ struct CollectiveMma< typename Params::TMA_Scale tma_load_scale; typename Params::TMA_Zero tma_load_zero; if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0 }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; } else if constexpr (ModeHasScales) { auto scale_k = (K + args.group_size - 1) / args.group_size; @@ -493,7 +498,7 @@ struct CollectiveMma< _1{}); // mcast along N mode for this M load, if any if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; } else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS)); @@ -503,7 +508,7 @@ struct CollectiveMma< SmemLayoutScale{}(_,_,cute::Int<0>{}), ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any - return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size }; + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK }; } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } @@ -514,7 +519,7 @@ struct CollectiveMma< } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -564,7 +569,9 @@ struct CollectiveMma< } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr uint32_t TmaTransactionBytes = compute_tma_transaction_bytes(); + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -607,22 +614,22 @@ struct CollectiveMma< Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return cute::make_tuple(gA_mkl, gB_nkl); } else if constexpr (ModeHasScales) { auto scale_k = mainloop_params.scale_k; - Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) - Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) - Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); } else { @@ -668,10 +675,10 @@ struct CollectiveMma< int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { - Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) // // Prepare the TMA loads for A, B and Scales @@ -692,10 +699,10 @@ struct CollectiveMma< Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) uint16_t mcast_mask_a = 0; @@ -705,14 +712,14 @@ struct CollectiveMma< // Issue TmaLoads // Maps the tile -> block, value if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } } if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); } @@ -829,16 +836,29 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCsA = thread_mma.partition_A(sA); + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); // Allocate fragments and descriptors - Tensor tCrA_mma = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCrA_load = make_fragment_like(tCrA_mma); - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) // // Copy Atom A retiling @@ -846,7 +866,7 @@ struct CollectiveMma< auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) // Compute the max vector length that can be used to copy A. This will match the vector width of the // conversions used. It helps by allowing the compiler to convert using the same register that was used @@ -856,7 +876,7 @@ struct CollectiveMma< using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); // Partition of thread -> shared and thread -> RF - auto partitioned_extra_info = partition_extra_mma_info(thread_mma, shared_tensors); + auto partitioned_extra_info = partition_extra_mma_info(mma_thread_slice, shared_tensors); auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M @@ -1047,16 +1067,16 @@ struct CollectiveMma< int const l_coord) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return cute::tuple{}; + return cute::make_tuple(); } else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gS_mkl = get<2>(load_inputs); auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); - Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) - Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tSgS, tSsS); } @@ -1064,10 +1084,10 @@ struct CollectiveMma< Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gZ_mkl = get<3>(load_inputs); auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); - Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) - Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); } else { @@ -1083,25 +1103,25 @@ struct CollectiveMma< template CUTLASS_DEVICE auto partition_extra_mma_info( - ThreadMma const& thread_mma, + ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // noting to do - return cute::tuple{}; + // nothing to do + return cute::make_tuple(); } else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = thread_mma.partition_A(sS); - Tensor tCrS = make_tensor(thread_mma.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsZ = thread_mma.partition_A(sZ); - Tensor tCrZ = make_tensor(thread_mma.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { @@ -1122,8 +1142,8 @@ struct CollectiveMma< int const warp_group_thread_idx) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // noting to do - return cute::tuple{}; + // nothing to do + return cute::make_tuple(); } else if constexpr (ModeHasScales) { auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 32589cb6a8..daaed6210b 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -223,7 +223,7 @@ struct CollectiveMma< } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -399,8 +399,22 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -424,9 +438,7 @@ struct CollectiveMma< warpgroup_fence_operand(accum); // Prologue MMAs - CUTLASS_PRAGMA_UNROLL - for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - prologue_mma_count > 0; --prologue_mma_count) + assert(k_tile_count >= 1); { // WAIT on smem_pipe_read until it's data is available pipeline.consumer_wait(smem_pipe_read); @@ -443,6 +455,20 @@ struct CollectiveMma< ++smem_pipe_read; --k_tile_count; } + + CUTLASS_PRAGMA_UNROLL + for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count) - 1; + prologue_mma_count > 0; --prologue_mma_count) + { + // WAIT on smem_pipe_read until it's data is available + pipeline.consumer_wait(smem_pipe_read); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); + warpgroup_commit_batch(); + ++smem_pipe_read; + --k_tile_count; + } warpgroup_fence_operand(accum); // @@ -461,13 +487,8 @@ struct CollectiveMma< warpgroup_fence_operand(accum); warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,smem_pipe_read.index()), tCrB(_,_,k_block,smem_pipe_read.index()), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); warpgroup_commit_batch(); /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 6ee9bf2ba0..24af314d5f 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -173,21 +173,24 @@ struct CollectiveMma< // Device side kernel params struct Params { // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy( + using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{})); // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy( + using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{})); TMA_A tma_load_a; TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; }; // @@ -208,26 +211,34 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy( + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy( + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + return { tma_load_a, - tma_load_b + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk }; } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -249,9 +260,11 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytes = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -294,7 +307,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, BlockCoord const& blk_coord, @@ -354,8 +367,7 @@ struct CollectiveMma< // Mainloop CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { + for ( ; k_tile_count > 0; --k_tile_count) { // LOCK smem_pipe_write for _writing_ pipeline.producer_acquire(smem_pipe_write); @@ -422,8 +434,22 @@ struct CollectiveMma< // Define C accumulators and A/B partitioning // + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -450,12 +476,9 @@ struct CollectiveMma< // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - + assert(k_tile_count >= 1); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - warpgroup_fence_operand(accum); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); @@ -463,6 +486,7 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); warpgroup_arrive(); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { @@ -476,6 +500,25 @@ struct CollectiveMma< ++smem_pipe_read; } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + warpgroup_fence_operand(accum); // Mainloop GMMAs k_tile_count -= prologue_mma_count; @@ -494,13 +537,8 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); warpgroup_fence_operand(accum); warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); warpgroup_commit_batch(); /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index c0e2c90792..c281d4f5f7 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -167,21 +167,24 @@ struct CollectiveMma< // Device side kernel params struct Params { // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy( + using TMA_A = decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), SmemLayoutA{}(_,_,0), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + TileShape{}, + ClusterShape{})); // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy( + using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), SmemLayoutB{}(_,_,0), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{})); TMA_A tma_load_a; TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; uint32_t mma_promotion_interval = 4; }; @@ -203,27 +206,34 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy( + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy( + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,cute::Int<0>{}), - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + return { tma_load_a, tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, args.mma_promotion_interval }; } template - CUTLASS_HOST_DEVICE static bool + static bool can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { @@ -247,9 +257,11 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytes = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -321,8 +333,8 @@ struct CollectiveMma< // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -352,8 +364,7 @@ struct CollectiveMma< // Mainloop CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { + for ( ; k_tile_count > 0; --k_tile_count) { // LOCK smem_pipe_write for _writing_ pipeline.producer_acquire(smem_pipe_write); @@ -422,9 +433,23 @@ struct CollectiveMma< // // Define C accumulators and A/B partitioning // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) diff --git a/include/cutlass/gemm/device/gemm_sparse_universal.h b/include/cutlass/gemm/device/gemm_sparse_universal.h new file mode 100644 index 0000000000..b7d8cecfa7 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_sparse_universal.h @@ -0,0 +1,211 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_sparse_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + GemmSparseUniversal is a stateful, reusable Sparse GEMM handle. Once initialized for a given GEMM computation + (problem geometry and data references), it can be reused across different GEMM problems having the + geometry. (Once initialized, details regarding problem geometry and references to workspace memory + cannot be updated.) + + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class GemmSparseUniversal : + public GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversal< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + static_assert((platform::is_same::value), + "Epilogue of Ampere sparse GEMM must be row major for now."); + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversal< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h b/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h new file mode 100644 index 0000000000..a313ddc907 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_sparse_universal_with_absmax.h @@ -0,0 +1,202 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_sparse_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class GemmSparseUniversalWithAbsmax : + public GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversalWithAbsmax< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + static_assert((platform::is_same::value), + "Epilogue of Ada sparse GEMM must be row major for now."); + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmSparseUniversalWithAbsmax< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 8d045d6efa..ce7fd3203d 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -338,7 +338,8 @@ class GemmUniversalAdapter< static Status run(Params& params, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { CUTLASS_TRACE_HOST("GemmUniversal::run()"); dim3 const block = GemmKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); @@ -361,6 +362,11 @@ class GemmUniversalAdapter< CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } launch_result = cuda_adapter->launch(grid, cluster, block, @@ -378,7 +384,7 @@ class GemmUniversalAdapter< void const* kernel = (void const*) device_kernel; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params); + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); } } } @@ -424,12 +430,13 @@ class GemmUniversalAdapter< Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false ) { Status status = initialize(args, workspace, stream, cuda_adapter); if (Status::kSuccess == status) { - status = run(params_, stream, cuda_adapter); + status = run(params_, stream, cuda_adapter, launch_with_pdl); } return status; } @@ -440,20 +447,24 @@ class GemmUniversalAdapter< Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { - return run(args, workspace, stream, cuda_adapter); + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. Status - run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - return run(params_, stream, cuda_adapter); + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. Status - operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { - return run(params_, stream, cuda_adapter); + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); } }; diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 434f46d654..63da07b418 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -33,7 +33,6 @@ \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. */ - #pragma once #if defined(__CUDACC_RTC__) @@ -271,12 +270,33 @@ class GemmUniversalBase { { CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); - dim3 grid = get_grid_shape(args, cuda_adapter); + if (!kEnableCudaHostAdapter || cuda_adapter) { + + dim3 grid = get_grid_shape(args, cuda_adapter); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) + { + return Status::kErrorInvalidProblem; + } + } + else { + // + // With a null host adapter, a conservative grid shape is computed and required to conform to CUDA grid + // dimension limits. + // + + int64_t logicalGridM = (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + int64_t logicalGridN = (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + int32_t logicalGridL = args.batch_count; + + if ((int64_t(std::numeric_limits::max()) < logicalGridM) || + (int64_t(std::numeric_limits::max()) < logicalGridN) || + (int32_t(std::numeric_limits::max()) < logicalGridL)) { + + return Status::kErrorInvalidProblem; + } - if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) - { - return Status::kErrorInvalidProblem; } return GemmKernel::can_implement(args); diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 3694d0a87b..2e820b6136 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -68,6 +68,24 @@ enum class KernelInputTransformType { ////////////////////////////////////////////////////////////////////////////// +namespace kernel::detail { + +// Has_SwapAB::value will be true only if: +// class T has member SwapAB and T::SwapAB is true +template +struct Has_SwapAB { static constexpr bool value = false; }; + +template +struct Has_SwapAB > +{ static constexpr bool value = T::SwapAB; }; + +template +static constexpr bool Has_SwapAB_v = Has_SwapAB::value; + +} // namespace kernel::detail + +////////////////////////////////////////////////////////////////////////////// + // // Kernel schedule policies (the base class tags, one for each kernel layer file) // @@ -137,12 +155,15 @@ struct MainloopSm80CpAsyncUnpredicated { }; // n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads -template +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1> +> struct MainloopSm80CpAsync { constexpr static int Stages = Stages_; - using ArchTag = arch::Sm80; + using ArchTag = cute::conditional_t<(size(ClusterShape_{}) > 1), arch::Sm90, arch::Sm80>; using Schedule = KernelMultistage; - using ClusterShape = Shape<_1,_1,_1>; + using ClusterShape = ClusterShape_; }; // n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule diff --git a/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h b/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h new file mode 100644 index 0000000000..250a0e7b29 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_sparse_universal.h @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level Sparse GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_sparse_universal.h" +#include "cutlass/gemm/kernel/default_gemm_sparse.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator +> +struct DefaultGemmSparseUniversal { + + using DefaultGemmKernel = typename kernel::DefaultSparseGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + true, + Operator + >::GemmKernel; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = kernel::GemmSparseUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h b/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h new file mode 100644 index 0000000000..0193909217 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_sparse_universal_with_absmax.h @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level Sparse GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_absmax.h" +#include "cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h" +#include "cutlass/gemm/kernel/default_gemm_sparse.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator +> +struct DefaultGemmSparseUniversalWithAbsmax { + + using GemmBase = typename DefaultSparseGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, LayoutC, ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + false, // SplitKSerial + Operator + >::GemmKernel; + + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithAbsMax< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementAuxOutput, + ElementC, + EpilogueOutputOp, + GemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + using GemmKernel = kernel::GemmSparseUniversalWithAbsmax< + typename GemmBase::Mma, Epilogue, ThreadblockSwizzle>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h b/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h index ec93f8fc92..9d7f2c6f7e 100644 --- a/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h +++ b/include/cutlass/gemm/kernel/default_gemm_sparse_with_visitor.h @@ -167,7 +167,7 @@ struct DefaultSparseGemmWithVisitor::ThreadblockMma; - static constexpr int kAlignmentC = 128 / sizeof_bits::value;; + static constexpr int kAlignmentC = 128 / sizeof_bits::value; using ElementEpilogue = ElementAccumulator; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h index 92466af7ad..ed7951be58 100644 --- a/include/cutlass/gemm/kernel/default_gemm_universal.h +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -30,10 +30,10 @@ **************************************************************************************************/ /*! \file - \brief + \brief Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue. - + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are accommodated by exchanging A and B operands and assuming transposed layouts. Partial specializations here choose 'device::GemmTransposed' to implement this functionality. diff --git a/include/cutlass/gemm/kernel/gemm_sparse_universal.h b/include/cutlass/gemm/kernel/gemm_sparse_universal.h new file mode 100644 index 0000000000..c5420c72d9 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_sparse_universal.h @@ -0,0 +1,804 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_universal_base.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { +namespace detail { + +template < + typename LayoutA, + typename LayoutB, + typename LayoutC, + typename LayoutE +> +struct SparseUniversalArgumentsBase : UniversalArgumentsBase { + // + // Data members + // + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + void const * ptr_E; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_E; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; + typename LayoutE::Stride::LongIndex lde; + + // + // Methods + // + + SparseUniversalArgumentsBase(): + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_E(nullptr) + {} + + /// constructs an arguments structure + SparseUniversalArgumentsBase( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void const * ptr_E, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_E, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + typename LayoutC::Stride::LongIndex lde) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_E(ptr_E), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), + batch_stride_E(batch_stride_E), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), lde(lde) + { + CUTLASS_TRACE_HOST("SparseUniversalArgumentsBase::Arguments() - problem_size: " << problem_size); + } +}; + +template < + typename Mma, + typename Epilogue, + typename Arguments, + typename ThreadblockSwizzle, + typename ThreadblockShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutA, + typename LayoutB +> +struct SparseUniversalParamsBase : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB> { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB>; + + // + // Data members + // + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Mma::IteratorE::Params params_E; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + void * ptr_E; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_E; + + // + // Host dispatch API + // + + /// Default constructor + SparseUniversalParamsBase() = default; + + /// Constructor + SparseUniversalParamsBase( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + params_E(args.lde), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + ptr_E(const_cast(args.ptr_E)), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_E(args.batch_stride_E) + {} + + /// Lightweight update given a subset of arguments. + void update(Arguments const &args) + { + CUTLASS_TRACE_HOST("SparseUniversalParamsBase::update()"); + + // Update input/output pointers + this->ptr_A = const_cast(args.ptr_A); + this->ptr_B = const_cast(args.ptr_B); + this->ptr_C = const_cast(args.ptr_C); + this->ptr_D = args.ptr_D; + this->ptr_E = const_cast(args.ptr_E); + + this->batch_stride_A = args.batch_stride_A; + this->batch_stride_B = args.batch_stride_B; + this->batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + this->batch_stride_E = args.batch_stride_E; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +class GemmSparseUniversal { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + static int const kSparse = Mma::kSparse; + static int const kMetaSizeInBits = Mma::kMetaSizeInBits; + static int const kMaxID2 = Mma::kMaxID2; + static int const kElementsPerElementE = Mma::kElementsPerElementE; + + using ElementE = typename Mma::ElementE; + using LayoutE = typename Mma::LayoutE; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments : detail::SparseUniversalArgumentsBase< + LayoutA, + LayoutB, + LayoutC, + LayoutE + > { + using Base = detail::SparseUniversalArgumentsBase< + LayoutA, + LayoutB, + LayoutC, + LayoutE + >; + + typename EpilogueOutputOp::Params epilogue; + + Arguments() {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void const * ptr_E, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_E, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + typename LayoutC::Stride::LongIndex lde) + : + Base( + mode, problem_size, batch_count, + ptr_A, ptr_B, ptr_C, ptr_D, ptr_E, + batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D, batch_stride_E, + lda, ldb, ldc, ldd, lde + ), + epilogue(epilogue) + { + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + }; + + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params : detail::SparseUniversalParamsBase< + Mma, + Epilogue, + Arguments, + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB> + { + using ParamsBase = detail::SparseUniversalParamsBase< + Mma, + Epilogue, + Arguments, + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB>; + + typename EpilogueOutputOp::Params output_op; + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), + output_op(args.epilogue) + {} + + /// Lightweight update given a subset of arguments. + void update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + + // Update input/output pointers + this->ptr_A = const_cast(args.ptr_A); + this->ptr_B = const_cast(args.ptr_B); + this->ptr_C = const_cast(args.ptr_C); + this->ptr_D = args.ptr_D; + this->ptr_E = const_cast(args.ptr_E); + + this->batch_stride_A = args.batch_stride_A; + this->batch_stride_B = args.batch_stride_B; + this->batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + this->batch_stride_E = args.batch_stride_E; + + output_op = args.epilogue; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + +public: + + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + GemmUniversalMode mode, + int split_k_count) + { + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + + static int const kAlignmentA = (cute::is_same>::value) + ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (cute::is_same>::value) + ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (cute::is_same>::value) + ? 32 + : (cute::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + static int const kAlignmentE = Mma::IteratorE::AccessType::kElements; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + bool isEMisaligned = false; + + if (cute::is_same::value) { + isAMisaligned = (problem_size.k() / kSparse) % kAlignmentA; + } else if (cute::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (cute::is_same>::value + || cute::is_same>::value) { + isAMisaligned = (problem_size.k() / kSparse) % kAlignmentA; + } + + if (cute::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (cute::is_same::value) { + isBMisaligned = (problem_size.k() / kSparse) % kAlignmentB; + } else if (cute::is_same>::value + || cute::is_same>::value) { + isBMisaligned = (problem_size.k() / kSparse) % kAlignmentB; + } + + if (cute::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (cute::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (cute::is_same>::value + || cute::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + isEMisaligned = (problem_size.m() % kAlignmentE) + || ((problem_size.k() / kSparse) % kAlignmentE); + + // The k dimension has to be the multiple of the Threadblock k because out + // of bound meta data would be initialized to 0 by acync.zfill but 0 is not + // a valid meta data. + if (problem_size.k() % Mma::Shape::kK) { + isEMisaligned = true; + } + + if (mode == GemmUniversalMode::kGemm + || mode == GemmUniversalMode::kGemmSplitKParallel) { + if ((problem_size.k() / split_k_count) % Mma::Shape::kK) { + isEMisaligned = true; + } + } + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the row reordering of operand E + static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16; + + if (problem_size.m() % kAlignmentM) { + isEMisaligned = true; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + if (isEMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for E operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size, args.mode, args.batch_count); + } + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmSparseUniversal op; + op(params, shared_storage); + } + + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + ElementE *ptr_E = static_cast(params.ptr_E); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A / kSparse; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + ptr_E += threadblock_tile_offset.k() * params.batch_stride_E / kSparse; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + ptr_E = static_cast(params.ptr_E)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k / kSparse, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_E{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k / kSparse / kElementsPerElementE, + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k / kSparse}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorE iterator_E( + params.params_E, + ptr_E, + {params.problem_size.m(), problem_size_k / kSparse / kElementsPerElementE}, + thread_idx, + tb_offset_E); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_E, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h b/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h new file mode 100644 index 0000000000..47b76a171d --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_sparse_universal_with_absmax.h @@ -0,0 +1,609 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_universal_base.h" +#include "cutlass/gemm/kernel/gemm_sparse_universal.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +class GemmSparseUniversalWithAbsmax { +public: + using Base = GemmSparseUniversal; + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + static int const kSparse = Mma::kSparse; + static int const kMetaSizeInBits = Mma::kMetaSizeInBits; + static int const kMaxID2 = Mma::kMaxID2; + static int const kElementsPerElementE = Mma::kElementsPerElementE; + + using ElementE = typename Mma::ElementE; + using LayoutE = typename Mma::LayoutE; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using ElementAux = typename Epilogue::AuxOutputTileIterator::Element; + using LayoutAux = typename Epilogue::AuxOutputTileIterator::Layout; + using ElementVector = typename Epilogue::ElementVector; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments : detail::SparseUniversalArgumentsBase< + LayoutA, + LayoutB, + LayoutC, + LayoutE + > { + using Base = detail::SparseUniversalArgumentsBase< + LayoutA, + LayoutB, + LayoutC, + LayoutE + >; + + void const* ptr_Aux; + void const* ptr_Vector; + int64_t batch_stride_Aux; + int64_t batch_stride_Vector; + typename LayoutAux::Stride::LongIndex ldaux; + int64_t ldvector; + + typename EpilogueOutputOp::Params epilogue; + + Arguments() {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void const * ptr_E, + void const * ptr_Aux, + void const * ptr_Vector, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_E, + int64_t batch_stride_Aux, + int64_t batch_stride_Vector, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + typename LayoutC::Stride::LongIndex lde, + typename LayoutAux::Stride::LongIndex ldaux, + int64_t ldvector + ) + : + Base( + mode, problem_size, batch_count, + ptr_A, ptr_B, ptr_C, ptr_D, ptr_E, + batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D, batch_stride_E, + lda, ldb, ldc, ldd, lde + ), + ptr_Aux(ptr_Aux), + ptr_Vector(ptr_Vector), + batch_stride_Aux(batch_stride_Aux), + batch_stride_Vector(batch_stride_Vector), + ldaux(ldaux), + ldvector(ldvector), + epilogue(epilogue) + { } + }; + + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params : detail::SparseUniversalParamsBase< + Mma, + Epilogue, + Arguments, + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB> + { + using ParamsBase = detail::SparseUniversalParamsBase< + Mma, + Epilogue, + Arguments, + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC, + LayoutA, + LayoutB>; + + typename Epilogue::AuxOutputTileIterator::Params params_Aux; + int64_t ldvector; + + void* ptr_Aux; + void* ptr_Vector; + + int64_t batch_stride_Aux; + int64_t batch_stride_Vector; + typename EpilogueOutputOp::Params output_op; + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), + params_Aux(args.ldaux), + ldvector(args.ldvector), + ptr_Aux(const_cast(args.ptr_Aux)), + ptr_Vector(const_cast(args.ptr_Vector)), + batch_stride_Aux(args.batch_stride_Aux), + batch_stride_Vector(args.batch_stride_Vector), + output_op(args.epilogue) + {} + + /// Lightweight update given a subset of arguments. + void update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + + // Update input/output pointers + this->ptr_A = const_cast(args.ptr_A); + this->ptr_B = const_cast(args.ptr_B); + this->ptr_C = const_cast(args.ptr_C); + this->ptr_D = args.ptr_D; + this->ptr_E = const_cast(args.ptr_E); + ptr_Aux = const_cast(args.ptr_Aux); + ptr_Vector = const_cast(args.ptr_Vector); + + this->batch_stride_A = args.batch_stride_A; + this->batch_stride_B = args.batch_stride_B; + this->batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + this->batch_stride_E = args.batch_stride_E; + this->batch_stride_Aux = args.batch_stride_Aux; + batch_stride_Vector = args.batch_stride_Vector; + + output_op = args.epilogue; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + +public: + + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + GemmUniversalMode mode, + int split_k_count) { + return Base::can_implement(problem_size, mode, split_k_count); + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size, args.mode, args.batch_count); + } + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmSparseUniversalWithAbsmax op; + op(params, shared_storage); + } + + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + ElementE *ptr_E = static_cast(params.ptr_E); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A / kSparse; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + ptr_E += threadblock_tile_offset.k() * params.batch_stride_E / kSparse; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + ptr_E = static_cast(params.ptr_E)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k / kSparse, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_E{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k / kSparse / kElementsPerElementE, + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k / kSparse}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorE iterator_E( + params.params_E, + ptr_E, + {params.problem_size.m(), problem_size_k / kSparse / kElementsPerElementE}, + thread_idx, + tb_offset_E); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_E, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + ElementAux * ptr_Aux = static_cast(params.ptr_Aux); + ElementVector * ptr_Vector = static_cast(params.ptr_Vector); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + if (ptr_Aux) { + ptr_Aux += threadblock_tile_offset.k() * params.batch_stride_Aux; + } + if (ptr_Vector) { + ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; + } + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + if (ptr_Aux) { + ptr_Aux = static_cast(params.ptr_Aux)[threadblock_tile_offset.k()]; + } + if (ptr_Vector) { + ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; + } + } + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldvector; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to auxiliary destination tensor. + typename Epilogue::AuxOutputTileIterator iterator_Aux( + params.params_Aux, + // Only the final block writes the auxiliary tensor + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Aux, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + // Only the final block uses Vector + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + iterator_Aux, + params.problem_size.mn(), + threadblock_offset); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index d1ad5288f2..b682be867d 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -30,40 +30,13 @@ **************************************************************************************************/ #pragma once +#include "cutlass/gemm/kernel/gemm_universal_decl.h" #include "cutlass/gemm/kernel/tile_scheduler.hpp" //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { -//////////////////////////////////////////////////////////////////////////////// - -/* - * Stateless universal device GEMM kernel type that treats GEMM as - * a composition of a collective mainloop and a collective epilogue. - * - * Supports both the 2.x and 3.x APIs based on whether the first type is - * a cute::tuple<> or not. - * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h - * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp - * - * In the following declaration, the name preceding the 'Or' refers to - * 3.x API type argument order, and the name succeeding the 'Or' refers to - * 2.x API type argument order. Template arguments without two names - * belong to the 3.x API only. -**/ -template < - class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) - class CollectiveMainloopOrEpilogue_, - class CollectiveEpilogueOrThreadblockSwizzle_, - class TileScheduler_ = void, - class Enable = void -> -class GemmUniversal; - - -//////////////////////////////////////////////////////////////////////////////// - // In cases where ProblemShape is not a tuple, this is used to check if the // underlying problem shape type is aliased within or not. // Used for dispatching GemmUniversal to 2.x API or 3.x API diff --git a/include/cutlass/gemm/kernel/gemm_universal_decl.h b/include/cutlass/gemm/kernel/gemm_universal_decl.h new file mode 100644 index 0000000000..73426db5b7 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_universal_decl.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cutlass::gemm::kernel { + + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. +**/ +template < + class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) + class CollectiveMainloopOrEpilogue_, + class CollectiveEpilogueOrThreadblockSwizzle_, + class TileScheduler_ = void, + class Enable = void +> +class GemmUniversal; + + +} // namespace cutlass::gemm::kernel + diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index 954c9cbb6c..eb271b5184 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -196,10 +196,7 @@ static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializin // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); + auto [M,N,K,L] = problem_shape_MNKL; // Preconditions static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 1c314aadac..76ff6d81ef 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -40,8 +40,9 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" #include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" #include "cutlass/trace.h" @@ -79,9 +80,9 @@ class GemmUniversal< using ArchTag = typename CollectiveMainloop::ArchTag; using ElementA = typename CollectiveMainloop::ElementA; using StrideA = typename CollectiveMainloop::StrideA; - using UnderlyingStrideA = typename CollectiveMainloop::UnderlyingStrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; using ElementB = typename CollectiveMainloop::ElementB; - using UnderlyingStrideB = typename CollectiveMainloop::UnderlyingStrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using Schedule = typename DispatchPolicy::Schedule; @@ -94,18 +95,18 @@ class GemmUniversal< using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; using StrideC = typename CollectiveEpilogue::StrideC; - using UnderlyingStrideC = typename CollectiveEpilogue::UnderlyingStrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; using ElementD = typename CollectiveEpilogue::ElementD; using StrideD = typename CollectiveEpilogue::StrideD; - using UnderlyingStrideD = typename CollectiveEpilogue::UnderlyingStrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; static_assert(ArchTag::kMinComputeCapability >= 90); static_assert(cute::is_void_v, "Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler."); - - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; using TileScheduler = cute::conditional_t { using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + alignas(128) MainloopTensorMapStorage mainloop; + alignas(128) EpilogueTensorMapStorage epilogue; } tensormaps; }; @@ -211,7 +215,7 @@ class GemmUniversal< workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue); + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* mainloop_workspace = workspace_ptr + workspace_offset; @@ -230,7 +234,7 @@ class GemmUniversal< else { scheduler = TileScheduler::to_underlying_arguments( problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); - } + } return { args.mode, @@ -243,8 +247,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = true; if constexpr (IsGroupedGemmKernel) { @@ -272,7 +275,7 @@ class GemmUniversal< args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); // Get SM count if needed, otherwise use user supplied SM count @@ -298,7 +301,7 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, cuda_adapter); workspace_offset += TileScheduler::template get_workspace_size( args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -307,10 +310,10 @@ class GemmUniversal< } status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream); + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -336,7 +339,7 @@ class GemmUniversal< } else { grid_shape = TileScheduler::get_grid_shape(params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args); - } + } return grid_shape; } @@ -361,10 +364,10 @@ class GemmUniversal< static_assert(size<0>(TileShape{}) >= 128, "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - static_assert(cute::rank(UnderlyingStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(UnderlyingStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(UnderlyingStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(UnderlyingStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ enum class WarpGroupRole { @@ -406,7 +409,7 @@ class GemmUniversal< } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = size(TiledMma{}); - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline @@ -421,7 +424,9 @@ class GemmUniversal< epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -464,18 +469,23 @@ class GemmUniversal< auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups return; } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - // Prepare and partition the input tensors. Expects a tuple of tensors where: // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) @@ -489,16 +499,12 @@ class GemmUniversal< // Get pipeline stage increments from tensor shapes auto k_tile_count = size<3>(gA_mkl); - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - if (warp_group_role == WarpGroupRole::Producer) { cutlass::arch::warpgroup_reg_dealloc(); // Mainloop Producer Warp if (producer_warp_role == ProducerWarpRole::Mainloop) { int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; - int32_t next_batch = curr_batch; int32_t const mock_l_coord = 0; int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); int32_t const sm_count = params.hw_info.sm_count; @@ -513,18 +519,19 @@ class GemmUniversal< params.mainloop, input_tensormaps, problem_shape_MNKL, - next_batch + curr_batch ); - // Ensure warp is converged before issuing tensor replace + // Ensure warp is converged before issuing tensormap fence release __syncwarp(); - // Entire warp must do this (ie its aligned) + // Entire warp must do this (i.e. it's aligned) collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); } bool do_load_order_arrive = true; + bool did_batch_change = true; while (work_tile_info.is_valid()) { if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); continue; } @@ -538,7 +545,9 @@ class GemmUniversal< auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - collective_mainloop.tensormaps_fence_acquire(input_tensormaps); + if (did_batch_change) { + collective_mainloop.tensormaps_fence_acquire(input_tensormaps); + } collective_mainloop.load( params.mainloop, @@ -563,16 +572,17 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx - - if (work_tile_info.is_valid() && next_batch != curr_batch ) { + work_tile_info = scheduler.fetch_next_work(work_tile_info); + auto next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx + did_batch_change = next_batch != curr_batch; + if (work_tile_info.is_valid() && did_batch_change) { + curr_batch = next_batch; if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(next_batch), Int<1>{}); + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), Int<1>{}); } // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates // Since this state is waiting for loads to finish, it must start in the inverted phase. - typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state = + typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state = {mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()}; mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state); collective_mainloop.tensormaps_perform_update( @@ -580,13 +590,12 @@ class GemmUniversal< params.mainloop, input_tensormaps, problem_shape_MNKL, - next_batch + curr_batch ); // Ensure warp is converged before issuing tensor replace __syncwarp(); - // Entire warp must do this (ie its aligned) + // Entire warp must do this (i.e. it's aligned) collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); - curr_batch = next_batch; } // Advance the producer state for the last remaining stage that was being waited for above mainloop_pipe_producer_state.advance(1); @@ -598,19 +607,49 @@ class GemmUniversal< // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + + auto epi_load_tensormap = get<0>(collective_epilogue.load_init(params.epilogue, sm_count, sm_idx)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + if (work_tile_info.is_valid()) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + work_tile_info.L_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate); + } + + load_order_barrier.wait(); while (work_tile_info.is_valid()) { - if (!TileScheduler::requires_separate_reduction(params.scheduler)) { - load_order_barrier.wait(); - } - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + int32_t curr_batch = work_tile_info.L_idx; + + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + if (compute_epilogue) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - epi_load_pipe_producer_state = - collective_epilogue.load( + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + } + + epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, @@ -619,17 +658,40 @@ class GemmUniversal< tiled_mma, lane_idx, shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx() + epi_load_tensormap, + work_tile_info.reduction_subtile_idx(), + true // return state prior to last advance ); + } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - if constexpr (IsGroupedGemmKernel) { - if (work_tile_info.is_valid()) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); - } + work_tile_info = scheduler.fetch_next_work(work_tile_info); + did_batch_change = curr_batch != work_tile_info.L_idx; + + if (work_tile_info.is_valid() && did_batch_change) { + // Wait for TMA load to finish before updating + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = + {epi_load_pipe_producer_state.index(), !epi_load_pipe_producer_state.phase(), epi_load_pipe_producer_state.count()}; + + epi_load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); + + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + work_tile_info.L_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate); } + + if(compute_epilogue) { + epi_load_pipe_producer_state.advance(1); + } + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon @@ -640,9 +702,36 @@ class GemmUniversal< else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it bool do_store_tail = false; + // Get a copy of tensormaps + auto epi_store_tensormap = get<0>(collective_epilogue.store_init(params.epilogue, sm_count, sm_idx)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + + if (work_tile_info.is_valid()) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + work_tile_info.L_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate); + } + while (work_tile_info.is_valid()) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + + int32_t curr_batch = work_tile_info.L_idx; + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); @@ -683,6 +772,11 @@ class GemmUniversal< params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + + if (did_batch_change) { + collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + } + // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = collective_epilogue.store( @@ -697,6 +791,7 @@ class GemmUniversal< tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, + epi_store_tensormap, work_tile_info.reduction_subtile_idx() ); epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; @@ -705,12 +800,22 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - if constexpr (IsGroupedGemmKernel) { - if (work_tile_info.is_valid()) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); - } + work_tile_info = scheduler.fetch_next_work(work_tile_info); + + did_batch_change = curr_batch != work_tile_info.L_idx; + if (work_tile_info.is_valid() && did_batch_change) { + collective_epilogue.tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + work_tile_info.L_idx + ); + + // Converge before issuing tensormap fence release since fence is aligned + __syncwarp(); + collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, lane_predicate); } + } // Scheduler work fetch loop if (do_store_tail) { @@ -725,24 +830,6 @@ class GemmUniversal< #endif } -private: - // Kernel helper function to get next work unit - CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo - fetch_next_work( - typename TileScheduler::WorkTileInfo& work_tile_info, - TileScheduler& scheduler) const { - // Check whether we should continue on with the current work unit. If this is the case, - // the work unit will have been updated in continue_current_work to reflect the new - // tile to be computed. - if (scheduler.continue_current_work(work_tile_info)) { - return work_tile_info; - } - - // Get next work tile - scheduler.advance_to_next_work(); - return scheduler.get_current_work(); - } }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 916f92d707..8281561594 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -38,7 +38,9 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/trace.h" #include "cute/tensor.hpp" @@ -46,19 +48,6 @@ namespace cutlass::gemm::kernel { -namespace detail { - -// IF_SWAP_AB::value will be true only if: -// class T has member SwapAB and T::SwapAB is true -template -struct IF_SWAP_AB { static constexpr bool value = false; }; - -template -struct IF_SWAP_AB > -{ static constexpr bool value = T::SwapAB; }; - -} // namespace - /////////////////////////////////////////////////////////////////////////////// template < @@ -151,7 +140,7 @@ class GemmUniversal< to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -164,8 +153,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); @@ -285,15 +273,13 @@ class GemmUniversal< ); constexpr int BLK_M_RANK = cute::rank<0>(blk_shape); - bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); + return get(M) - get<0,i>(blk_shape) * get(m_coord); })); constexpr int BLK_N_RANK = cute::rank<1>(blk_shape); - bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); + return get(N) - get<1,i>(blk_shape) * get(n_coord); })); auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 2ddaef13b3..124cc43770 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -157,7 +157,7 @@ class GemmUniversal< to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -170,8 +170,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); @@ -220,8 +219,11 @@ class GemmUniversal< using namespace cute; using X = Underscore; +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); #else @@ -267,7 +269,7 @@ class GemmUniversal< } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline @@ -282,7 +284,9 @@ class GemmUniversal< epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -388,7 +392,7 @@ class GemmUniversal< ); collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); } - } + } } else if (warp_group_role == WarpGroupRole::Consumer) { Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index b2077db027..6f2cff2601 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -44,6 +44,7 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" #include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" /////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -116,14 +117,6 @@ class GemmUniversal< // Kernel level shared memory storage struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - struct PipelineStorage : cute::aligned_struct<16> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; @@ -132,6 +125,14 @@ class GemmUniversal< alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; } pipelines; + + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); @@ -168,7 +169,7 @@ class GemmUniversal< CUTLASS_TRACE_HOST("to_underlying_arguments():"); auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -219,8 +220,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); @@ -250,7 +250,7 @@ class GemmUniversal< } static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) { Status status = Status::kSuccess; uint8_t* workspace_ptr = reinterpret_cast(workspace); @@ -258,7 +258,7 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, cuda_adapter); workspace_offset += TileScheduler::template get_workspace_size( args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -299,9 +299,12 @@ class GemmUniversal< using namespace cute; using X = Underscore; +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); #else // Preconditions @@ -358,7 +361,7 @@ class GemmUniversal< } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = size(TiledMma{}); - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline @@ -373,7 +376,9 @@ class GemmUniversal< epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -419,11 +424,10 @@ class GemmUniversal< auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.get_current_work(); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); // Prepare and partition the input tensors. Expects a tuple of tensors where: // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) @@ -435,21 +439,20 @@ class GemmUniversal< Tensor gA_mkl = get<0>(load_inputs); Tensor gB_nkl = get<1>(load_inputs); - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - // Wait for all thread blocks in the Cluster cluster_wait_fn(); if (warp_group_role == WarpGroupRole::Producer) { cutlass::arch::warpgroup_reg_dealloc(); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + // Mainloop Producer Warp if (producer_warp_role == ProducerWarpRole::Mainloop) { bool do_load_order_arrive = true; while (work_tile_info.is_valid()) { if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); continue; } @@ -485,19 +488,21 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + + if (!TileScheduler::requires_separate_reduction(params.scheduler) && work_tile_info.is_valid()) { + load_order_barrier.wait(); + } while (work_tile_info.is_valid()) { - if (!TileScheduler::requires_separate_reduction(params.scheduler)) { - load_order_barrier.wait(); - } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); @@ -520,7 +525,7 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon @@ -531,6 +536,8 @@ class GemmUniversal< else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it bool do_store_tail = false; while (work_tile_info.is_valid()) { @@ -596,7 +603,7 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); } // Scheduler work fetch loop if (do_store_tail) { @@ -611,24 +618,6 @@ class GemmUniversal< #endif } -private: - // Kernel helper function to get next work unit - CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo - fetch_next_work( - typename TileScheduler::WorkTileInfo& work_tile_info, - TileScheduler& scheduler) const { - // Check whether we should continue on with the current work unit. If this is the case, - // the work unit will have been updated in continue_current_work to reflect the new - // tile to be computed. - if (scheduler.continue_current_work(work_tile_info)) { - return work_tile_info; - } - - // Get next work tile - scheduler.advance_to_next_work(); - return scheduler.get_current_work(); - } }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 7924b0131d..a694d06a4e 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -41,6 +41,8 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/trace.h" @@ -119,27 +121,31 @@ class GemmUniversal< static constexpr uint32_t StagesPerMathWarpGroup = 2; using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< StagesPerMathWarpGroup, NumMmaWarpGroups>; + using MathWarpGroupOrderBarrierSharedStorage = + cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage< + MathWarpGroupOrderBarrier::SequenceDepth, + MathWarpGroupOrderBarrier::SequenceLength>; // Kernel level shared memory storage struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - struct PipelineStorage : cute::aligned_struct<16> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + using MathWarpGroupOrderBarrierStorage = MathWarpGroupOrderBarrierSharedStorage; alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; } pipelines; + + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); @@ -176,7 +182,7 @@ class GemmUniversal< (void) workspace; auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -219,8 +225,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); @@ -256,7 +261,7 @@ class GemmUniversal< size_t workspace_offset = 0; status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, 1, cuda_adapter); workspace_offset += TileScheduler::template get_workspace_size( args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -350,7 +355,7 @@ class GemmUniversal< } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline @@ -365,7 +370,9 @@ class GemmUniversal< epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -446,7 +453,7 @@ class GemmUniversal< epi_load_pipe_consumer_state.advance(c_tile_count); epi_store_pipe_producer_state.advance(d_tile_count); } - auto work_tile_info = scheduler.get_current_work(); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); // Wait for all thread blocks in the Cluster cluster_wait_fn(); @@ -493,10 +500,12 @@ class GemmUniversal< // Make sure all Consumer Warp Groups have been waited upon collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + load_order_barrier.wait(); while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index edff1d8e18..a2bc4dbf25 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -161,7 +161,7 @@ class GemmUniversal< to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -174,8 +174,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 877d2c1ddf..0570208e8d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -164,7 +164,7 @@ class GemmUniversal< CUTLASS_TRACE_HOST("to_underlying_arguments():"); auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -195,8 +195,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); @@ -225,7 +224,7 @@ class GemmUniversal< CudaHostAdapter* cuda_adapter = nullptr) { TileScheduler t; return t.template initialize_workspace( - args.scheduler, workspace, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, workspace, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, 1, cuda_adapter); } // Computes the kernel launch grid shape based on runtime parameters @@ -340,7 +339,7 @@ class GemmUniversal< Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.get_current_work(); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; @@ -402,7 +401,7 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon @@ -478,7 +477,7 @@ class GemmUniversal< } // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); + work_tile_info = scheduler.fetch_next_work(work_tile_info); } // Scheduler work fetch loop if (do_store_tail) { @@ -493,24 +492,6 @@ class GemmUniversal< #endif } -private: - // Kernel helper function to get next work unit - CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo - fetch_next_work( - typename TileScheduler::WorkTileInfo& work_tile_info, - TileScheduler& scheduler) const { - // Check whether we should continue on with the current work unit. If this is the case, - // the work unit will have been updated in continue_current_work to reflect the new - // tile to be computed. - if (scheduler.continue_current_work(work_tile_info)) { - return work_tile_info; - } - - // Get next work tile - scheduler.advance_to_next_work(); - return scheduler.get_current_work(); - } }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index abf79e842a..280de66358 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -40,6 +40,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/kernel/gemm_universal_decl.h" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/trace.h" @@ -175,7 +176,7 @@ class GemmUniversal< (void) workspace; auto problem_shape = args.problem_shape; - if constexpr (detail::IF_SWAP_AB::value) { + if constexpr (detail::Has_SwapAB_v) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); @@ -206,8 +207,7 @@ class GemmUniversal< }; } - CUTLASS_HOST_DEVICE static - bool + static bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); @@ -367,7 +367,7 @@ class GemmUniversal< epi_load_pipe_consumer_state.advance(c_tile_count); epi_store_pipe_producer_state.advance(d_tile_count); } - auto work_tile_info = scheduler.get_current_work(); + auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index d55b7d7695..68ea45c0b6 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -128,10 +128,22 @@ public StaticPersistentTileScheduler { template static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, - uint32_t, const uint32_t = 1) { + uint32_t, const uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto + fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return work_tile_info; + } + + advance_to_next_work(); + return get_current_work(); + } + }; } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index 49e1d1e51c..ebeb0434b5 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -204,7 +204,6 @@ class PersistentTileSchedulerSm90Group { ); } - CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { return true; @@ -408,7 +407,7 @@ class PersistentTileSchedulerSm90Group { template static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, - uint32_t, const uint32_t = 1) { + uint32_t, const uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } @@ -480,6 +479,27 @@ class PersistentTileSchedulerSm90Group { requires_separate_reduction(Params const& params) { return false; } + + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto + fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return work_tile_info; + } + + advance_to_next_work(); + return get_current_work(); + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape) { + return get_current_work(); + } + }; } // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index d981e29a2d..c65d6e1d7e 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -226,7 +226,6 @@ class PersistentTileSchedulerSm90StreamK { return params; } - CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { // Split count > 1 is only valid for heuristic and split-K decomposition modes @@ -263,7 +262,7 @@ class PersistentTileSchedulerSm90StreamK { // for the fact that we have splits_ peers per output tile, we multiply this // value by splits_. For stream-K, this multiplication ends up being a no-op // because splits_ is set to 1 for stream-K. - if(linear_idx >= (params.units_per_problem_ * params.splits_ + params.separate_reduction_units_)) { + if(linear_idx >= (params.units_per_problem_ * params.divmod_splits_.divisor + params.separate_reduction_units_)) { // Invalid work. Return an empty result. return WorkTileInfo::invalid_work_tile(); } @@ -423,7 +422,7 @@ class PersistentTileSchedulerSm90StreamK { using BlockStripedReduceT = BlockStripedReduce; AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); - AccumulatorArrayT* accumulator_array = reinterpret_cast(&accumulators); + AccumulatorArrayT* accumulator_array = reinterpret_cast(accumulators.data()); int barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; @@ -434,7 +433,7 @@ class PersistentTileSchedulerSm90StreamK { // note that, in the split-K case, the units_per_problem_ member of Params will be // the total number of output tiles. uint32_t reduction_tiles = 0; - if (params.splits_ > 1) { + if (params.divmod_splits_.divisor > 1) { reduction_tiles = params.units_per_problem_; } else if (params.requires_separate_reduction()) { @@ -583,7 +582,8 @@ class PersistentTileSchedulerSm90StreamK { ProblemShape const& problem_shape, KernelHardwareInfo const& hw_info, uint32_t mma_warp_groups, - const uint32_t epilogue_subtile = 1) { + const uint32_t epilogue_subtile = 1, + CudaHostAdapter* cuda_adapter = nullptr) { auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); @@ -608,7 +608,9 @@ class PersistentTileSchedulerSm90StreamK { mma_warp_groups, sizeof_bits::value, sizeof_bits::value, - epilogue_subtile + epilogue_subtile, + 1, + cuda_adapter ); } @@ -625,6 +627,25 @@ class PersistentTileSchedulerSm90StreamK { return work_tile_info.K_idx; } + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto + fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return work_tile_info; + } + + advance_to_next_work(); + return get_current_work(); + } + + // Returns the initial work tile info that will be computed over + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape) { + return get_current_work(); + } + private: // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining @@ -636,8 +657,11 @@ class PersistentTileSchedulerSm90StreamK { uint64_t linear_idx, WorkTileInfo& work_tile_info) { + auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = cute::block_id_in_cluster(); + uint64_t cta_m_in_cluster = static_cast(cta_m_in_cluster_); + uint64_t cta_n_in_cluster = static_cast(cta_n_in_cluster_); uint64_t output_tile_id = linear_idx; - if (linear_idx >= params.units_per_problem_ * params.splits_) { + if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) { // Separate-reduction work auto cluster_size = params.get_cluster_size(); // Divide up the linearized separate reduction units into clusters @@ -649,7 +673,7 @@ class PersistentTileSchedulerSm90StreamK { work_tile_info.setup_separate_reduction(epi_subtile_idx); } - else if (linear_idx >= params.sk_units_ && params.splits_ == 1) { + else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { // Data-parallel work output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; work_tile_info.K_idx = 0; @@ -697,11 +721,11 @@ class PersistentTileSchedulerSm90StreamK { uint64_t split; params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); - bool is_split_k = params.splits_ > 1; + bool is_split_k = params.divmod_splits_.divisor > 1; auto big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; auto big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; auto linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; - auto k_tiles_per_split = is_split_k ? params.k_tiles_per_sk_unit_ : k_tiles_per_unit_in_group; + auto k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; // Determine the starting k iteration computed by this stream-K work unit uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + @@ -744,6 +768,15 @@ class PersistentTileSchedulerSm90StreamK { unit_iter_start += adjustment_tiles; k_tiles_in_my_split -= adjustment_tiles; } + else if (params.ktile_start_alignment_count == 2 && start_tile_k_tile % 2 != 0) { + // ktile for each SM start from even number + // If start from odd number ktile within the output tile + // now start at the ktile one before my initial ktile start (take one ktile from prev sm) + // if end on odd number ktile within the output tile + // now end at ktile that one before my ktile end (give one ktile to next sm) + unit_iter_start -= 1; + k_tiles_in_my_split += 1; + } } if (work_tile_info.k_tile_count == 0) { @@ -773,6 +806,14 @@ class PersistentTileSchedulerSm90StreamK { // Adjust our work to take on these K tiles. k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); } + else if (params.ktile_start_alignment_count == 2 && end_tile_k_tile % 2 != 0) { + // ktile for each SM start from even number + // If start from odd number ktile within the output tile + // now start at the ktile one before my initial ktile start (take one ktile from prev sm) + // If end on odd number ktile within the output tile, + // now end at ktile that one before my ktile end (give one ktile to next sm) + k_tiles_in_my_split -= 1; + } } work_tile_info.k_tile_remaining = k_tiles_in_my_split; @@ -801,8 +842,6 @@ class PersistentTileSchedulerSm90StreamK { // Bring the linearized tile ID back into the space of tiles, rather than clusters output_tile_id *= params.get_cluster_size(); - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); - // The final linearized tile ID is in units of the cluster dimension over which we rasterize. if (params.raster_order_ == RasterOrder::AlongN) { output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; @@ -853,7 +892,7 @@ class PersistentTileSchedulerSm90StreamK { auto tile_idx_in_cluster_path = params.div_cluster_size(tile_idx); auto start_k_tile = params.divmod_tiles_per_output_tile_.divisor * tile_idx_in_cluster_path; auto end_k_tile = start_k_tile + params.divmod_tiles_per_output_tile_.divisor - 1; - auto big_unit_k_tiles = params.big_units_ * (params.k_tiles_per_sk_unit_ + 1); + auto big_unit_k_tiles = params.big_units_ * (params.divmod_k_tiles_per_sk_unit_.divisor + 1); auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t k_tiles_per_unit) { auto unit_k_start = unit_idx * k_tiles_per_unit; @@ -881,16 +920,14 @@ class PersistentTileSchedulerSm90StreamK { auto find_unit = [&](uint32_t k_tile) { if (k_tile < big_unit_k_tiles) { // The tile is within the "big unit range" - auto k_tiles_per_unit = params.k_tiles_per_sk_unit_ + 1; - auto unit_idx = k_tile / k_tiles_per_unit; - return static_cast(adjust_unit(k_tile, unit_idx, k_tiles_per_unit)); + auto unit_idx = params.divmod_k_tiles_per_sk_big_unit_.divide(k_tile); + return static_cast(adjust_unit(k_tile, unit_idx, params.divmod_k_tiles_per_sk_big_unit_.divisor)); } else { // The tile is after the "big unit range." Account for this by finding the "normal unit" // that it belongs to, and then offsetting by the number of big units - auto k_tiles_per_unit = params.k_tiles_per_sk_unit_; - auto unit_idx = ((k_tile - big_unit_k_tiles) / params.k_tiles_per_sk_unit_) + (params.big_units_); - return static_cast(adjust_unit(k_tile, unit_idx, k_tiles_per_unit)); + auto unit_idx = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles) + params.big_units_; + return static_cast(adjust_unit(k_tile, unit_idx, params.divmod_k_tiles_per_sk_unit_.divisor)); } }; diff --git a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp index 0e1c210f32..b0af23c43c 100644 --- a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp @@ -127,7 +127,7 @@ class StaticPersistentTileScheduler { CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { - return true; + return args.max_swizzle_size >= 1; } CUTLASS_HOST_DEVICE @@ -206,18 +206,18 @@ class StaticPersistentTileScheduler { int32_t log_swizzle_size, RasterOrder raster_order) { - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); - uint64_t minor_work_idx, major_work_idx, cluster_minor_offset; if (raster_order == RasterOrder::AlongN) { minor_work_idx = static_cast(tile_m); major_work_idx = static_cast(tile_n); - cluster_minor_offset = cta_m_in_cluster; + uint64_t cluster_m = divmod_cluster_shape_minor.divide(tile_m) * divmod_cluster_shape_minor.divisor; + cluster_minor_offset = tile_m - cluster_m; } else { major_work_idx = static_cast(tile_m); minor_work_idx = static_cast(tile_n); - cluster_minor_offset = cta_n_in_cluster; + uint64_t cluster_n = divmod_cluster_shape_minor.divide(tile_n) * divmod_cluster_shape_minor.divisor; + cluster_minor_offset = tile_n - cluster_n; } uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset; @@ -248,21 +248,6 @@ class StaticPersistentTileScheduler { cta_m, cta_n ); } - // Kernel helper function to get next work ID - template - CUTLASS_DEVICE - auto - fetch_next_work( - WorkTileInfo work_tile_info, - WorkIdPipeline& work_id_pipeline, - WorkIdPipelineState work_id_pipe_consumer_state) { - WorkTileInfo new_work_tile_info; - advance_to_next_work(); - new_work_tile_info = get_current_work(); - - // Return true to indicate that the WorkID pipeline state should be advanced - return cute::make_tuple(new_work_tile_info, true); - } CUTLASS_DEVICE static auto diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp index 41cf056459..9835e37fc8 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -35,6 +35,7 @@ \brief Utilities for selecting default tile schedulers */ +#include "cutlass/arch/arch.h" #include "cutlass/detail/dependent_false.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 1630583f6c..36888a29fa 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -168,7 +168,8 @@ struct PersistentTileSchedulerSm90Params { KernelHardwareInfo hw_info, int max_swizzle_size, RasterOrderOptions raster_order_option, - bool truncate_by_problem_size=true) { + bool truncate_by_problem_size=true + ) { dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape); return get_grid_shape( @@ -192,7 +193,8 @@ struct PersistentTileSchedulerSm90Params { KernelHardwareInfo hw_info, int max_swizzle_size, RasterOrderOptions raster_order_option, - bool truncate_by_problem_size=true) { + bool truncate_by_problem_size=true + ) { int const sm_count = hw_info.sm_count; @@ -238,6 +240,7 @@ struct PersistentTileSchedulerSm90Params { } } else { + int cta_per_device = sm_count; /* * Optimal grid size calculation is based on * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU @@ -248,15 +251,16 @@ struct PersistentTileSchedulerSm90Params { auto cluster_size = cluster_shape.m() * cluster_shape.n(); int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); - int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; + cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; // The calculation below allows for larger grid size launch for different GPUs. int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); cta_per_device += max_cta_occupancy_per_residual_gpc; - cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; - + if (sm_count < cta_per_device) { + cta_per_device = sm_count; + } if (raster_order == RasterOrder::AlongN) { launch_grid.y = possibly_truncate( cta_per_device / cluster_shape.m(), @@ -420,7 +424,7 @@ struct PersistentTileSchedulerSm90StreamKParams { // The splitting factor to be used in a split-K decomposition of the problem. // If this is set to a value greater than 1, stream-K decomposition logic // is bypassed in favor of a split-K decomposition. - uint32_t splits_ = 1; + FastDivmod divmod_splits_{}; // Number of stream-K or split-K work units that compute an extra k iteration. // This is done to handle residuals in dividing up the k iteration space. @@ -442,7 +446,10 @@ struct PersistentTileSchedulerSm90StreamKParams { // Number of tiled k iterations computed by each stream-K work unit. This // can potentially cover more than one output tile. - uint32_t k_tiles_per_sk_unit_ = 0; + FastDivmod divmod_k_tiles_per_sk_unit_{}; + // Number of tiled k iterations computed by each "big" stream-K units, which + // processes one more K chunk than a "normal" stream-K unit. + FastDivmod divmod_k_tiles_per_sk_big_unit_{}; // Strategy to use when reducing between collaborating CTAs ReductionMode reduction_mode_ = ReductionMode::Deterministic; @@ -459,6 +466,9 @@ struct PersistentTileSchedulerSm90StreamKParams { // Maximum number of groups of stream-K units static constexpr uint32_t max_sk_groups_ = 8u; + // ktile start from even for each cta + uint32_t ktile_start_alignment_count { 1u }; + // Divides dividend by the cluster size CUTLASS_HOST_DEVICE uint64_t @@ -585,6 +595,14 @@ struct PersistentTileSchedulerSm90StreamKParams { splits = k_tiles_per_output_tile; } + // If splits == k_tiles_per_output_tiles, there will be one k_tile per cta + // and this violate k_tile start from even requirements. Thus we need to + // reduce the number of splits. + if (ktile_start_alignment_count > 1u && + static_cast(splits) == k_tiles_per_output_tile) { + splits = k_tiles_per_output_tile / ktile_start_alignment_count; + } + set_params_basic( underlying_params, problem_blocks_m, @@ -686,7 +704,8 @@ struct PersistentTileSchedulerSm90StreamKParams { auto sk_splits_too_small = [&](uint32_t g) { // Check whether the number of K tiles computed per stream-K unit is less // than min_iters_per_sk_unit_ - auto total_sk_k_tiles = (sk_tiles / g) * k_tiles_per_output_tile; + auto total_sk_cluster_tiles = (sk_cluster_tiles / g) * cluster_size; + auto total_sk_k_tiles = total_sk_cluster_tiles * k_tiles_per_output_tile; auto k_tiles_per_sk_unit = total_sk_k_tiles / (sk_units / g); return k_tiles_per_sk_unit < min_iters_per_sk_unit_; }; @@ -725,13 +744,12 @@ struct PersistentTileSchedulerSm90StreamKParams { // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) // Both total_tiles and sm_count are multiples of cluster size due to padding added // prior to kernel launch. - uint64_t sk_clustered_tiles = sk_tiles / cluster_size; - uint64_t sk_clustered_tiles_per_group = sk_clustered_tiles / groups; - uint64_t sk_tiles_per_group = sk_clustered_tiles_per_group * cluster_size; + uint64_t sk_cluster_tiles_per_group = sk_cluster_tiles / groups; + uint64_t sk_tiles_per_group = sk_cluster_tiles_per_group * cluster_size; // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which // are stream-K units within a group that process an extra K chunk. - uint64_t sk_big_groups = sk_clustered_tiles % groups; + uint64_t sk_big_groups = sk_cluster_tiles % groups; uint64_t k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; @@ -777,7 +795,7 @@ struct PersistentTileSchedulerSm90StreamKParams { // This setting ensures that the use of this divmod for stream-K decompositions // is essentially a no-op. divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); - splits_ = 1; + divmod_splits_ = FastDivmod(1); log_swizzle_size_ = underlying_params.log_swizzle_size_; units_per_problem_ = static_cast(dp_units + sk_units); raster_order_ = underlying_params.raster_order_; @@ -790,7 +808,8 @@ struct PersistentTileSchedulerSm90StreamKParams { reduction_workspace_ = reduction_workspace; sk_tiles_ = sk_tiles; sk_units_ = static_cast(sk_units); - k_tiles_per_sk_unit_ = static_cast(k_tiles_per_sk_unit); + divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); + divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); reduction_mode_ = reduction_mode; divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); separate_reduction_units_ = reduction_units; @@ -923,19 +942,19 @@ struct PersistentTileSchedulerSm90StreamKParams { // Calculates the size of the workspace needed for holding reduction barriers CUTLASS_HOST_DEVICE - static int + static size_t get_barrier_workspace_size(uint64_t num_tiles, uint32_t mma_warp_groups, uint32_t barrier_bits) { - auto workspace_bits = num_tiles * mma_warp_groups * barrier_bits; - return round_up_to_l2_alignment(bits_to_bytes(static_cast(workspace_bits))); + size_t workspace_bits = num_tiles * static_cast(mma_warp_groups) * static_cast(barrier_bits); + return round_up_to_l2_alignment(bits_to_bytes(workspace_bits)); } // Calculates the size of the workspace needed for holding partial outputs from splits CUTLASS_HOST_DEVICE - static int + static size_t get_reduction_workspace_size(uint64_t num_tiles, GemmCoord tile_shape, uint32_t accumulator_bits, uint32_t num_accumulator_mtxs = 1) { - auto output_tile_size = tile_shape.m() * tile_shape.n(); - auto workspace_bits = accumulator_bits * output_tile_size * num_tiles * num_accumulator_mtxs; - return round_up_to_l2_alignment(bits_to_bytes(static_cast(workspace_bits))); + size_t output_tile_size = tile_shape.m() * tile_shape.n(); + size_t workspace_bits = accumulator_bits * output_tile_size * num_tiles * num_accumulator_mtxs; + return round_up_to_l2_alignment(bits_to_bytes(workspace_bits)); } #if !defined(__CUDACC_RTC__) @@ -945,8 +964,8 @@ struct PersistentTileSchedulerSm90StreamKParams { uint32_t k_tiles_per_output_tile, GemmCoord tile_shape, GemmCoord cluster_shape, - int& barrier_workspace_size, - int& reduction_workspace_size, + size_t& barrier_workspace_size, + size_t& reduction_workspace_size, KernelHardwareInfo const& hw_info, int splits, int max_swizzle, @@ -970,8 +989,8 @@ struct PersistentTileSchedulerSm90StreamKParams { barrier_workspace_size = 0; reduction_workspace_size = 0; } - else if (decomposition_mode == DecompositionMode::SplitK || - (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { + else if (splits > 1 && + (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic)) { // Basic split-K variant requires workspace for all output tiles barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); @@ -1094,8 +1113,8 @@ struct PersistentTileSchedulerSm90StreamKParams { uint32_t epilogue_subtile = 1, uint32_t num_accumulator_mtxs = 1) { - int barrier_workspace_size = 0; - int reduction_workspace_size = 0; + size_t barrier_workspace_size = 0; + size_t reduction_workspace_size = 0; #if !defined(__CUDACC_RTC__) get_workspace_component_sizes( @@ -1138,7 +1157,8 @@ struct PersistentTileSchedulerSm90StreamKParams { uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, - uint32_t epilogue_subtile) { + uint32_t epilogue_subtile, + CudaHostAdapter* cuda_adapter = nullptr) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -1158,7 +1178,9 @@ struct PersistentTileSchedulerSm90StreamKParams { mma_warp_groups, barrier_bits, element_accumulator_bits, - epilogue_subtile + epilogue_subtile, + 1, + cuda_adapter ); } @@ -1182,11 +1204,12 @@ struct PersistentTileSchedulerSm90StreamKParams { uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1) { + uint32_t num_accumulator_mtxs = 1, + CudaHostAdapter* cuda_adapter = nullptr) { #if !defined(__CUDACC_RTC__) - int barrier_workspace_size = 0; - int reduction_workspace_size = 0; + uint64_t barrier_workspace_size = 0; + uint64_t reduction_workspace_size = 0; get_workspace_component_sizes( problem_blocks, @@ -1215,7 +1238,7 @@ struct PersistentTileSchedulerSm90StreamKParams { // Only the barrier workspace needs to be cleared for stream-K. // Barrier workspace follows reduction workspace. uint8_t* barrier_workspace = reinterpret_cast(workspace) + reduction_workspace_size; - return zero_workspace(static_cast(barrier_workspace), barrier_workspace_size, stream); + return zero_workspace(static_cast(barrier_workspace), barrier_workspace_size, stream, cuda_adapter); } #endif // !defined(__CUDACC_RTC__) @@ -1240,7 +1263,7 @@ struct PersistentTileSchedulerSm90StreamKParams { divmod_sk_groups_ = FastDivmodU64(1u); auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * underlying_params.divmod_cluster_shape_minor_.divisor; divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size); - splits_ = splits; + divmod_splits_ = FastDivmod(splits); divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; log_swizzle_size_ = underlying_params.log_swizzle_size_; units_per_problem_ = blocks_m * blocks_n * blocks_l; @@ -1248,7 +1271,8 @@ struct PersistentTileSchedulerSm90StreamKParams { big_units_ = k_tiles_per_output_tile % splits; reduction_workspace_ = reduction_workspace; reduction_mode_ = reduction_mode; - k_tiles_per_sk_unit_ = k_tiles_per_output_tile / splits; + divmod_k_tiles_per_sk_unit_ = FastDivmod(k_tiles_per_output_tile / splits); + divmod_k_tiles_per_sk_big_unit_ = FastDivmod(k_tiles_per_output_tile / splits + 1); // No stream-K work is performed for "basic" data-parallel and split-K decompositions sk_tiles_ = 0; @@ -1260,9 +1284,9 @@ struct PersistentTileSchedulerSm90StreamKParams { private: // Round up number of bytes to the nearest multiple of L2 cache line alignment CUTLASS_HOST_DEVICE - static int - round_up_to_l2_alignment(int bytes) { - constexpr static uint32_t L2CacheLineSizeBytes = 128; + static size_t + round_up_to_l2_alignment(size_t bytes) { + constexpr size_t L2CacheLineSizeBytes = 128u; return (bytes + L2CacheLineSizeBytes - 1) / L2CacheLineSizeBytes * L2CacheLineSizeBytes; } }; diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h index 4059110540..985693ce6d 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h @@ -191,16 +191,34 @@ struct DefaultSparseMmaCore::value), 8); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + platform::min(Shape::kN / (kAccessSizeInBits / sizeof_bits::value), 8); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + // // Shared memory layouts // + static int const Crosswise_A = platform::min(int(128 / sizeof(ElementA)), + Shape::kM); using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, int(128 / sizeof(ElementA))>; + sizeof_bits::value, Crosswise_A>; // Shared memory layout + static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), + Shape::kN); + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< - sizeof_bits::value, int(128 / sizeof(ElementB))>; + sizeof_bits::value, Crosswise_B>; // // Iterators to write to shared memory @@ -209,7 +227,8 @@ struct DefaultSparseMmaCore, kThreads, - layout::PitchLinearShape<8, 4>, + layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to A operand @@ -220,7 +239,8 @@ struct DefaultSparseMmaCore, kThreads, - layout::PitchLinearShape<8, 4>, + layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to B operand @@ -547,6 +567,16 @@ struct DefaultSparseMmaCore::value), 8); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + // Warp thread arrangement // crosswise cannot be larger than 1024 bit. static int const kCrosswiseB = @@ -565,7 +595,7 @@ struct DefaultSparseMmaCore::value, int(128 / sizeof(ElementA))>; + sizeof_bits::value, Crosswise_A>; // Shared memory layout using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< @@ -578,7 +608,8 @@ struct DefaultSparseMmaCore, kThreads, - layout::PitchLinearShape<8, 4>, + layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to A operand @@ -734,6 +765,16 @@ struct DefaultSparseMmaCore::value), 8); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + static int const Crosswise_B = platform::min(int(128 / sizeof(ElementB)), + Shape::kN); + + // // Shared memory layouts // @@ -743,7 +784,7 @@ struct DefaultSparseMmaCore::value, int(128 / sizeof(ElementB))>; + sizeof_bits::value, Crosswise_B>; // // Iterators to write to shared memory @@ -764,7 +805,8 @@ struct DefaultSparseMmaCore, kThreads, - layout::PitchLinearShape<8, 4>, + layout::PitchLinearShape, kAccessSizeInBits / sizeof_bits::value>; /// Shared memory iterator to B operand diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 004fc749a1..46690bf1ba 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -1040,6 +1040,15 @@ class MmaTensorOpMultiplicandTileIterator< partition_contiguous_idx = (lane_id % Layout::kFactor); access_contiguous_idx = (quad_quad + i * 2) ^ (lane_in_quad_pair / Layout::kFactor); access_strided_idx = (lane_in_quad_quad / Layout::kFactor); + } else if (Policy::LdsmShape::kContiguous == 1) { + // Matrix multiply 16832.SP B + // Q0 + // Q1 + // Q2 + // Q3 + partition_contiguous_idx = (lane_id % Layout::kFactor); + access_contiguous_idx = (lane_in_quad_pair / Layout::kFactor) ^ i; + access_strided_idx = lane_id / Layout::kFactor; } int access_contiguous = @@ -1432,7 +1441,21 @@ class MmaTensorOpMultiplicandTileIterator< access_contiguous_idx = ((lane_in_pair * 2 + quad_quad) ^ access_strided_idx); - } + } else if (Policy::LdsmShape::kContiguous == 1) { + // Matrix multiply 16832.SP B + // Q0 + // Q1 + // Q2 + // Q3 + int factor_in_partition = + (Layout::PartitionShape::kContiguous * Layout::kFactor / + Layout::TileShape::kContiguous); + + partition_contiguous_idx = lane_in_quad / factor_in_partition; + access_contiguous_idx = ((lane_in_pair * factor_in_partition) ^ + (lane_in_quad_quad / Layout::kFactor) ^ i); + access_strided_idx = lane_id / Layout::kFactor; + } int access_contiguous = partition_contiguous_idx * Layout::PartitionShape::kContiguous + diff --git a/include/cutlass/half.h b/include/cutlass/half.h index c203e6cb07..e5dcd71eb9 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -631,9 +631,10 @@ struct numeric_limits { } // namespace std #endif +namespace cutlass { namespace platform { -/// std::numeric_limits +/// Forward Declaration template struct numeric_limits; @@ -696,6 +697,7 @@ struct numeric_limits { static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } }; } // namespace platform +} // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 1a9728e7ab..b84d322dbb 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -48,53 +48,79 @@ namespace cutlass { -/////////////////////////////////////////////////////////////////////////////////////////////////// - template struct integer_subbyte { - /// Storage type using Storage = uint8_t; - /// Number of bits static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte"); - /// External type - using xint_t = typename platform::conditional::type; + // "External type"; the integer type for which + // integer_subbyte has a conversion-to operator + using xint_t = typename cutlass::platform::conditional::type; - /// Bitmask for truncation from larger integers + // Bitmask for truncation from larger integers static constexpr Storage bits_mask_ = Storage(Storage(-1) >> (8 - Bits)); - /// Bitmask for the sign bit + // Bitmask for the sign bit static constexpr Storage sign_mask_ = Storage((Signed ? 1 : 0) << (Bits - 1)); - // - // Data members - // - + // Where the bits are stored Storage storage; - // - // Methods - // - - /// No operation + // Default construction does NOT zero-initialize integer_subbyte() = default; - /// Conversion from integer type + // Implicit conversion is DEPRECATED. + // Please use one of the two explicit constructors below. + template> + > + [[deprecated("Implicit conversion is deprecated; please use explicit construction instead")]] + CUTLASS_HOST_DEVICE + integer_subbyte(T value) + : integer_subbyte(static_cast(value)) {} + + // CUTLASS code commonly converts both signed and unsigned integers + // into integer_subbyte, so the class provides both explicit + // conversions. + + // Precondition: If the external type is unsigned int, then value + // fits in unsigned int (is nonnegative). CUTLASS_HOST_DEVICE explicit integer_subbyte(int value) - : storage(reinterpret_cast(value) & bits_mask_) {} + : storage(reinterpret_cast(value) & bits_mask_) + { + if constexpr (Signed) { + [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); + [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; + assert(value >= lower_bound); + assert(value < upper_bound); + } + else { + [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; + assert(value >= 0); + assert(value < static_cast(upper_bound)); + } + } + // Precondition: If the external type is (signed) int, then value + // fits in int. CUTLASS_HOST_DEVICE explicit integer_subbyte(unsigned value) - : storage(reinterpret_cast(value) & bits_mask_) {} - - CUTLASS_HOST_DEVICE explicit - integer_subbyte(double value) { - xint_t tmp = static_cast(value); - storage = reinterpret_cast(tmp) & bits_mask_; + : storage(reinterpret_cast(value) & bits_mask_) + { + if constexpr (Signed) { + [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); + [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; + assert(value >= lower_bound); + assert(value < upper_bound); + } + else { + [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; + assert(value < upper_bound); + } } - /// Convert to int or unsigned + // Convert to the "external" integer type (int or unsigned) CUTLASS_HOST_DEVICE operator xint_t() const { if (sign_mask_ & storage) { // Sign extend @@ -104,49 +130,56 @@ struct integer_subbyte { } } - /// Equality CUTLASS_HOST_DEVICE bool operator==(integer_subbyte const& rhs) const { return storage == rhs.storage; } - /// Inequality CUTLASS_HOST_DEVICE bool operator!=(integer_subbyte const& rhs) const { return storage != rhs.storage; } - /// Less than or equal CUTLASS_HOST_DEVICE - bool operator<=(integer_subbyte const& rhs) const { - if (sign_mask_ & storage) { - return !(rhs.storage < storage); - } else { - return storage <= rhs.storage; + bool operator<(integer_subbyte const& rhs) const { + if ((sign_mask_ & storage) == (sign_mask_ & rhs.storage)) { + // If both *this and rhs have the same sign, compare storage directly. + return storage < rhs.storage; + } + else { + // If *this and rhs don't have the same sign, + // then return whether *this is negative. + return sign_mask_ & storage; } } - /// Less than CUTLASS_HOST_DEVICE - bool operator<(integer_subbyte const& rhs) const { - if (sign_mask_ & storage) { - return !(rhs.storage <= storage); - } else { - return storage < rhs.storage; + bool operator<=(integer_subbyte const& rhs) const { + if ((sign_mask_ & storage) == (sign_mask_ & rhs.storage)) { + // If both *this and rhs have the same sign, compare storage directly. + return storage <= rhs.storage; + } + else { + // If *this and rhs don't have the same sign, + // then return whether *this is negative. + return sign_mask_ & storage; } } - /// Greater than or equal CUTLASS_HOST_DEVICE bool operator>=(integer_subbyte const& rhs) const { return !(*this < rhs); } - /// Greater than CUTLASS_HOST_DEVICE bool operator>(integer_subbyte const& rhs) const { return !(*this <= rhs); } + + CUTLASS_HOST_DEVICE friend integer_subbyte + conj(integer_subbyte const& x) { + return x; + } }; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -186,83 +219,62 @@ struct sizeof_bits { namespace platform { -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE static - cutlass::int4b_t const lowest() noexcept { return int4b_t{-8};} +/// Forward Declaration +template +struct numeric_limits; + +// Specialization for signed integer_subbyte +template +struct numeric_limits> { +private: + using value_type = cutlass::integer_subbyte; + +public: + CUTLASS_HOST_DEVICE static value_type lowest() noexcept { + return value_type{ + -(1 << (NumBits - 1)) + }; + } - CUTLASS_HOST_DEVICE static - cutlass::int4b_t const max() noexcept { return int4b_t{7};} + CUTLASS_HOST_DEVICE static value_type max() noexcept { + return value_type{ + (1 << (NumBits - 1)) - 1 + }; + } - CUTLASS_HOST_DEVICE static - cutlass::int4b_t const min() noexcept { return lowest();} + CUTLASS_HOST_DEVICE static value_type const min() noexcept { + return lowest(); + } static constexpr bool is_integer = true; static constexpr bool is_signed = true; + static constexpr bool has_infinity = false; }; -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE static - cutlass::uint4b_t const lowest() noexcept { return uint4b_t{0};} - - CUTLASS_HOST_DEVICE static - cutlass::uint4b_t const max() noexcept { return uint4b_t{15};} - - CUTLASS_HOST_DEVICE static - cutlass::uint4b_t const min() noexcept { return lowest();} - - static constexpr bool is_integer = true; - static constexpr bool is_signed = false; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE static - cutlass::uint1b_t const lowest() noexcept { return uint1b_t{0};} - - CUTLASS_HOST_DEVICE static - cutlass::uint1b_t const max() noexcept { return uint1b_t{1};} - - CUTLASS_HOST_DEVICE static - cutlass::uint1b_t const min() noexcept { return lowest();} - - static constexpr bool is_integer = true; - static constexpr bool is_signed = false; -}; +// Specialization for unsigned integer_subbyte +template +struct numeric_limits> { +private: + using value_type = cutlass::integer_subbyte; -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE static - cutlass::int2b_t lowest() noexcept { return int2b_t{-2}; } - - CUTLASS_HOST_DEVICE static - cutlass::int2b_t min() noexcept { return lowest(); } - - CUTLASS_HOST_DEVICE static - cutlass::int2b_t max() noexcept { return int2b_t{1}; } - - static constexpr bool is_integer = true; - static constexpr bool is_signed = true; -}; - -template <> -struct numeric_limits { - CUTLASS_HOST_DEVICE static - cutlass::uint2b_t const lowest() noexcept { return uint2b_t{0}; } +public: + CUTLASS_HOST_DEVICE static value_type lowest() noexcept { + return value_type{0u}; + } - CUTLASS_HOST_DEVICE static - cutlass::uint2b_t const min() noexcept { return lowest(); } + CUTLASS_HOST_DEVICE static value_type max() noexcept { + return value_type{ + (1u << NumBits) - 1u + }; + } - CUTLASS_HOST_DEVICE static - cutlass::uint2b_t const max() noexcept { return uint2b_t{3}; } + CUTLASS_HOST_DEVICE static value_type const min() noexcept { + return lowest(); + } static constexpr bool is_integer = true; static constexpr bool is_signed = false; }; -/////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace platform } // namespace cutlass - diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 409f6ab97e..8374fe31d0 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -60,14 +60,19 @@ namespace layout { // ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag used for 3-D NWC tensors for 1D conv, only used in 3.x API +/// Tag used for 3-D NWC tensors for 1-D convolutions; only used in 3.x API class TensorNWC {}; -/// Tag used for n-D KCSRT tensors for nD conv, only used in 3.x API for wgrad output layouts +/// Tag used for n-D KCSRT tensors for n-D convolutions; only used in 3.x API for wgrad output layouts class TensorKCS {}; class TensorKCSR {}; class TensorKCSRT {}; +/// Tag used for n-D CSRTK tensors for n-D convolutions; only used in 3.x API for wgrad output layouts +class TensorCSK {}; +class TensorCSRK {}; +class TensorCSRTK {}; + /// Mapping function for 4-D NHWC tensors. class TensorNHWC { public: @@ -639,14 +644,5 @@ class TensorNDHWC { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag used for linearized tensors with shape (NW, C) for 1D conv, only used in 3.x API -class TensorLinearizedNWC {}; -/// Tag used for linearized tensors with shape (NHW, C) for 2D conv, only used in 3.x API -class TensorLinearizedNHWC : public TensorNHWC {}; -/// Tag used for linearized tensors with shape (NDHW, C) for 3D conv, only used in 3.x API -class TensorLinearizedNDHWC : public TensorNDHWC {}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace layout } // namespace cutlass diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 2a3a09549c..2e74afa8e4 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -104,7 +104,6 @@ struct NumericConverter { CUTLASS_DEVICE static result_type convert(source_type const & s) { - return __float2int_rn(s); } @@ -221,6 +220,50 @@ struct NumericConverter { } }; +template <> +struct NumericConverter { + + using result_type = uint8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + CUTLASS_DEVICE + static result_type convert(source_type const & s) { + + int32_t intermediate; + asm volatile("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); + + return static_cast(intermediate); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template <> +struct NumericConverter { + + using result_type = uint8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + CUTLASS_DEVICE + static result_type convert(source_type const & s) { + + int32_t intermediate; + asm volatile("cvt.rzi.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); + + return static_cast(intermediate); + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + #elif !defined(__CUDACC_RTC__) template <> @@ -273,8 +316,118 @@ struct NumericConverter { } }; +template <> +struct NumericConverter { + + using result_type = uint8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; + + static result_type convert(source_type const & s) { + std::fesetround(FE_TONEAREST); + int32_t intermediate = (int32_t)std::nearbyint(s); + + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + + return static_cast(intermediate); + } + + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template <> +struct NumericConverter { + + using result_type = uint8_t; + using source_type = float; + static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + + static result_type convert(source_type const & s) { + std::fesetround(FE_TOWARDZERO); + int32_t intermediate = (int32_t)std::nearbyint(s); + + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + + return static_cast(intermediate); + } + + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + #endif +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for float => integer_subbyte +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct NumericConverter, float, Round> { +private: + static constexpr bool result_is_signed = true; + +public: + using result_type = integer_subbyte; + using source_type = float; + static constexpr FloatRoundStyle round_style = Round; + + CUTLASS_HOST_DEVICE static result_type + convert(source_type const& src) { + using middle_type = int; + static_assert(8 * sizeof(middle_type) > Bits, "This conversion " + "requires that integer_subbyte have fewer representation bits " + "than the number of bits in int."); + + auto middle = NumericConverter::convert(src); + return NumericConverter::convert(middle); + } + + CUTLASS_HOST_DEVICE result_type + operator()(source_type const& s) const { + return convert(s); + } +}; + +template +struct NumericConverter, float, Round> { +private: + static constexpr bool result_is_signed = false; + +public: + using result_type = integer_subbyte; + using source_type = float; + static constexpr FloatRoundStyle round_style = Round; + + CUTLASS_HOST_DEVICE static result_type + convert(source_type const& src) { + using middle_type = unsigned; + static_assert(8 * sizeof(middle_type) > Bits, "This conversion " + "requires that integer_subbyte have fewer representation bits " + "than the number of bits in unsigned int."); + + auto middle = NumericConverter::convert(src); + return NumericConverter::convert(middle); + } + + CUTLASS_HOST_DEVICE result_type + operator()(source_type const& s) const { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for float <= cutlass::half_t @@ -706,8 +859,8 @@ struct NumericConverterClamp { CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { NumericConverter convert_op; - result_type const kClamp_max = platform::numeric_limits::max(); - result_type const kClamp_min = platform::numeric_limits::lowest(); + result_type const kClamp_max = cutlass::platform::numeric_limits::max(); + result_type const kClamp_min = cutlass::platform::numeric_limits::lowest(); if (s < (source_type)kClamp_min) return kClamp_min; if (s > (source_type)kClamp_max) @@ -814,7 +967,7 @@ struct NumericArrayConverter { } else { result_type result; for (int i = 0; i < N; ++i) { - result[i] = conj(source[i]); + result[i] = conj(static_cast(source[i])); } return result; } @@ -2317,7 +2470,6 @@ struct NumericArrayConverter { CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - // Convert to int to int8_t NumericConverter destination_converter; result_type result; result[0] = destination_converter(source[0]); @@ -2330,6 +2482,29 @@ struct NumericArrayConverter { } }; +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericConverter destination_converter; + result_type result; + result[0] = destination_converter(source[0]); + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + // To convert a FP32 to Int that has less than 32 bits, we need to convert it to int32 first. template < typename T, @@ -2342,7 +2517,7 @@ struct NumericArrayFP32ToIntConverter { using source_type = Array; static FloatRoundStyle const round_style = Round; - static_assert(platform::numeric_limits::is_integer, "the dest type has to be int."); + static_assert(cutlass::platform::numeric_limits::is_integer, "the dest type has to be int."); CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { diff --git a/include/cutlass/numeric_size.h b/include/cutlass/numeric_size.h index 42bc418a40..4ff83bab88 100644 --- a/include/cutlass/numeric_size.h +++ b/include/cutlass/numeric_size.h @@ -60,11 +60,12 @@ struct sizeof_bits { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns the number of bytes required to hold a specified number of bits +template CUTLASS_HOST_DEVICE -CUTLASS_CONSTEXPR_IF_CXX17 -int -bits_to_bytes(int bits) { - return (bits + 7) / 8; +constexpr +R +bits_to_bytes(T bits) { + return (R(bits) + R(7)) / R(8); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 2ab7ae0455..68bb04b0d8 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -49,7 +49,6 @@ using namespace cute; enum class BarrierStatus : uint32_t { WaitAgain = 0u, WaitDone = 1u, - WaitOnly = 2u }; class ArrivalToken { @@ -62,7 +61,7 @@ class ArrivalToken { CUTLASS_HOST_DEVICE BarrierStatus get() const { - return barrier_status_;; + return barrier_status_; } CUTLASS_HOST_DEVICE @@ -241,7 +240,7 @@ public : , full_barrier_ptr_(&storage.full_barrier_[0]) , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); if (warp_idx == 0 && lane_predicate == 1) { @@ -344,7 +343,7 @@ public : CUTLASS_DEVICE void producer_tail(PipelineState state) { for (int count = 0; count < Stages; ++count) { - producer_acquire(state, {BarrierStatus::WaitOnly}); + empty_barrier_ptr_[state.index()].wait(state.phase()); ++state; } } @@ -394,7 +393,7 @@ private : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = empty_barrier_ptr_[stage].try_wait(phase); + bool barrier_status = empty_barrier_ptr_[stage].try_wait(phase); return {static_cast(barrier_status)}; } @@ -403,9 +402,6 @@ private : if (barrier_token != BarrierStatus::WaitDone) { empty_barrier_ptr_[stage].wait(phase); } - if (barrier_token == BarrierStatus::WaitOnly) { - return; - } if (params_.is_leader) { full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); @@ -456,7 +452,7 @@ private : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = full_barrier_ptr_[stage].try_wait(phase); + bool barrier_status = full_barrier_ptr_[stage].try_wait(phase); return {static_cast(barrier_status)}; } @@ -465,7 +461,7 @@ private : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = full_barrier_ptr_[stage].test_wait(phase); + bool barrier_status = full_barrier_ptr_[stage].test_wait(phase); return {static_cast(barrier_status)}; } @@ -659,7 +655,7 @@ public : , full_barrier_ptr_(storage.full_barrier_.data()) , empty_barrier_ptr_(storage.empty_barrier_.data()) { - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); // Barrier FULL, EMPTY init @@ -761,7 +757,7 @@ public : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = empty_barrier_ptr_[stage].try_wait(phase); + bool barrier_status = empty_barrier_ptr_[stage].try_wait(phase); return {static_cast(barrier_status)}; } @@ -793,7 +789,7 @@ public : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = full_barrier_ptr_[stage].try_wait(phase); + bool barrier_status = full_barrier_ptr_[stage].try_wait(phase); return {static_cast(barrier_status)}; } @@ -802,7 +798,7 @@ public : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = full_barrier_ptr_[stage].test_wait(phase); + bool barrier_status = full_barrier_ptr_[stage].test_wait(phase); return {static_cast(barrier_status)}; } @@ -824,20 +820,31 @@ public : // Simple producer-consumer async Pipeline class // /////////////////////////////////////////////////////////////////////////////////////////////////// -template -class PipelineAsync { -public : - using FullBarrier = cutlass::arch::ClusterBarrier; - using EmptyBarrier = cutlass::arch::ClusterBarrier; - using ProducerBarrierType = FullBarrier::ValueType; - using ConsumerBarrierType = EmptyBarrier::ValueType; - static constexpr uint32_t Stages = Stages_; - using PipelineState = cutlass::PipelineState; - struct SharedStorage { +namespace PipelineDetail { + template + using PipelineAsyncPipelineState = cutlass::PipelineState; + + template + struct PipelineAsyncSharedStorage { + using FullBarrier = cutlass::arch::ClusterBarrier; + using EmptyBarrier = cutlass::arch::ClusterBarrier; + FullBarrier full_barrier_[Stages]; EmptyBarrier empty_barrier_[Stages]; }; +}; + +template +class PipelineAsync { +public : + static constexpr uint32_t Stages = Stages_; + using SharedStorage = PipelineDetail::PipelineAsyncSharedStorage; + using FullBarrier = typename SharedStorage::FullBarrier; + using EmptyBarrier = typename SharedStorage::EmptyBarrier; + using ProducerBarrierType = typename FullBarrier::ValueType; + using ConsumerBarrierType = typename EmptyBarrier::ValueType; + using PipelineState = PipelineDetail::PipelineAsyncPipelineState; enum class ThreadCategory { NonParticipant, @@ -867,7 +874,7 @@ public : full_barrier_ptr_(&storage.full_barrier_[0]), empty_barrier_ptr_(&storage.empty_barrier_[0]) { - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); // Barrier FULL, EMPTY init @@ -960,6 +967,11 @@ public : consumer_release(state.index()); } + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(uint32_t stage) { + return reinterpret_cast(&full_barrier_ptr_[stage]); + } + private: Params params_; FullBarrier *full_barrier_ptr_; @@ -970,7 +982,7 @@ public : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = empty_barrier_ptr_[stage].try_wait(phase); + bool barrier_status = empty_barrier_ptr_[stage].try_wait(phase); return {static_cast(barrier_status)}; } @@ -986,17 +998,12 @@ public : full_barrier_ptr_[stage].arrive(); } - CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(uint32_t stage) { - return reinterpret_cast(&full_barrier_ptr_[stage]); - } - CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = full_barrier_ptr_[stage].try_wait(phase); + bool barrier_status = full_barrier_ptr_[stage].try_wait(phase); return {static_cast(barrier_status)}; } @@ -1005,13 +1012,13 @@ public : if (skip_wait) { return {BarrierStatus::WaitDone}; } - uint32_t barrier_status = full_barrier_ptr_[stage].test_wait(phase); + bool barrier_status = full_barrier_ptr_[stage].test_wait(phase); return {static_cast(barrier_status)}; } CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase) { - uint32_t done = full_barrier_ptr_[stage].test_wait(phase); + bool done = full_barrier_ptr_[stage].test_wait(phase); if (!done) { full_barrier_ptr_[stage].wait(phase); } @@ -1040,14 +1047,24 @@ public : // /////////////////////////////////////////////////////////////////////////////////////////////////// +namespace PipelineDetail { + template -class OrderedSequenceBarrier { -public : +struct OrderedSequenceBarrierSharedStorage { using Barrier = cutlass::arch::ClusterBarrier; + Barrier barrier_[SequenceDepth][SequenceLength]; +}; - struct SharedStorage { - Barrier barrier_[SequenceDepth][SequenceLength]; - }; +} // namespace PipelineDetail + +template +class OrderedSequenceBarrier { +public: + static constexpr int SequenceDepth = SequenceDepth_; + static constexpr int SequenceLength = SequenceLength_; + using SharedStorage = + PipelineDetail::OrderedSequenceBarrierSharedStorage; + using Barrier = typename SharedStorage::Barrier; struct Params { uint32_t group_id; @@ -1077,7 +1094,7 @@ private : barrier_ptr_(&storage.barrier_[0][0]), // Group 0 - starts with an opposite phase stage_({0, params.group_id == 0, 0}) { - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); // Barrier FULL, EMPTY init @@ -1122,4 +1139,31 @@ private : //////////////////////////////////////////////////////////////////////////////////////////////////// +// Synchronization call. Blocks until barriers are initialized in shared memory. +CUTLASS_DEVICE +void +pipeline_init_wait(int cluster_size) { + if (cluster_size > 1) { + cute::cluster_wait(); + } + else { + __syncthreads(); + } +} + +// Used to guarantee that the Pipeline init is visible +// to all producers and consumer threadblocks in the cluster +CUTLASS_DEVICE +void +pipeline_init_arrive_relaxed(int cluster_size) { + if (cluster_size > 1) { + cute::cluster_arrive_relaxed(); + } + else { + __syncthreads(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // end namespace cutlass diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index ba74ae723b..e6a445ea2d 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -323,23 +323,9 @@ using std::pair; #endif -/// The type used as a compile-time boolean with true value. -typedef integral_constant true_type; - -/// The type used as a compile-time boolean with false value. -typedef integral_constant false_type; - -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) - -/// std::bool_constant -template -struct bool_constant : platform::integral_constant {}; - -#else - -using std::bool_constant; - -#endif +using CUTLASS_STL_NAMESPACE::bool_constant; +using CUTLASS_STL_NAMESPACE::true_type; +using CUTLASS_STL_NAMESPACE::false_type; #if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700)) @@ -356,125 +342,52 @@ using std::nullptr_t; // Conditional metaprogramming //----------------------------------------------------------------------------- -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201700L)) || (defined(_MSC_VER) && (_MSC_VER < 1600)) - -/// std::enable_if (true specialization) -template -struct enable_if { - typedef T type; -}; - -/// std::enable_if (false specialization) -template -struct enable_if {}; - -/// std::conditional (true specialization) -template -struct conditional { - typedef T type; -}; - -/// std::conditional (false specialization) -template -struct conditional { - typedef F type; -}; - -template -using void_t = void; - -#else - -using std::enable_if; -using std::conditional; -using std::void_t; - -#endif - -#if (201703L <=__cplusplus) -/// std::conditional_t +using CUTLASS_STL_NAMESPACE::conditional; using CUTLASS_STL_NAMESPACE::conditional_t; -#endif +using CUTLASS_STL_NAMESPACE::enable_if; +using CUTLASS_STL_NAMESPACE::enable_if_t; +using CUTLASS_STL_NAMESPACE::void_t; //----------------------------------------------------------------------------- // Const/volatility specifiers //----------------------------------------------------------------------------- -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201703L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) - -/// std::remove_const (non-const specialization) -template -struct remove_const { - typedef T type; -}; - -/// std::remove_const (const specialization) -template -struct remove_const { - typedef T type; -}; - -/// std::remove_volatile (non-volatile specialization) -template -struct remove_volatile { - typedef T type; -}; +using CUTLASS_STL_NAMESPACE::remove_const; +using CUTLASS_STL_NAMESPACE::remove_const_t; +using CUTLASS_STL_NAMESPACE::remove_cv; +using CUTLASS_STL_NAMESPACE::remove_cv_t; +using CUTLASS_STL_NAMESPACE::remove_reference; +using CUTLASS_STL_NAMESPACE::remove_reference_t; +using CUTLASS_STL_NAMESPACE::remove_volatile; +using CUTLASS_STL_NAMESPACE::remove_volatile_t; -/// std::remove_volatile (volatile specialization) -template -struct remove_volatile { - typedef T type; -}; +// remove_cvref and remove_cvref_t are C++20 features, +// but CUTLASS finds them useful enough to back-port. +#if defined(__cpp_lib_remove_cvref) -/// std::remove_cv -template -struct remove_cv { - typedef typename remove_volatile::type>::type type; -}; +using CUTLASS_STL_NAMESPACE::remove_cvref; +using CUTLASS_STL_NAMESPACE::remove_cvref_t; #else -using std::remove_const; -using std::remove_volatile; -using std::remove_cv; - -#endif - -#if (201703L <=__cplusplus) - -/// std::remove_cv_t -using CUTLASS_STL_NAMESPACE::remove_cv_t; -/// std::remove_reference_t -using CUTLASS_STL_NAMESPACE::remove_reference_t; - -// C++20 -// using std::remove_cvref; template struct remove_cvref { using type = remove_cv_t>; }; -// C++20 -// using std::remove_cvref_t; template using remove_cvref_t = typename remove_cvref::type; #endif - //----------------------------------------------------------------------------- // Type relationships //----------------------------------------------------------------------------- -#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) +using CUTLASS_STL_NAMESPACE::is_same; +using CUTLASS_STL_NAMESPACE::is_same_v; -/// std::is_same (false specialization) -template -struct is_same : false_type {}; - -/// std::is_same (true specialization) -template -struct is_same : true_type {}; +#if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) /// Helper for std::is_base_of template @@ -507,7 +420,6 @@ struct is_base_of #else -using std::is_same; using std::is_base_of; #endif @@ -516,6 +428,11 @@ using std::is_base_of; // Type properties //----------------------------------------------------------------------------- +using CUTLASS_STL_NAMESPACE::is_arithmetic; +using CUTLASS_STL_NAMESPACE::is_arithmetic_v; +using CUTLASS_STL_NAMESPACE::is_void; +using CUTLASS_STL_NAMESPACE::is_void_v; + #if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) /// std::is_volatile @@ -536,10 +453,6 @@ struct is_pointer_helper : true_type {}; template struct is_pointer : is_pointer_helper::type> {}; -/// std::is_void -template -struct is_void : is_same::type> {}; - /// std::is_integral template struct is_integral : false_type {}; @@ -579,11 +492,6 @@ struct is_floating_point (is_same::type>::value || is_same::type>::value)> {}; -/// std::is_arithmetic -template -struct is_arithmetic - : integral_constant::value || is_floating_point::value)> {}; - /// std::is_fundamental template struct is_fundamental @@ -595,10 +503,8 @@ struct is_fundamental using std::is_volatile; using std::is_pointer; -using std::is_void; using std::is_integral; using std::is_floating_point; -using std::is_arithmetic; using std::is_fundamental; #endif @@ -635,6 +541,12 @@ using CUTLASS_STL_NAMESPACE::is_unsigned_v; #endif +//----------------------------------------------------------------------------- +// +//----------------------------------------------------------------------------- + +using CUTLASS_STL_NAMESPACE::declval; + //----------------------------------------------------------------------------- // bit_cast //----------------------------------------------------------------------------- @@ -649,6 +561,12 @@ constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept return reinterpret_cast(src); } +//----------------------------------------------------------------------------- +// Convertable +//----------------------------------------------------------------------------- +using CUTLASS_STL_NAMESPACE::is_convertible; +using CUTLASS_STL_NAMESPACE::is_convertible_v; + //----------------------------------------------------------------------------- // Alignment and layout utilities //----------------------------------------------------------------------------- @@ -892,6 +810,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr int32_t max() noexcept { return 2147483647;} static constexpr bool is_integer = true; + static constexpr bool has_infinity = false; }; template <> @@ -901,6 +820,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr int16_t max() noexcept { return 32767;} static constexpr bool is_integer = true; + static constexpr bool has_infinity = false; }; template <> @@ -910,6 +830,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr int8_t max() noexcept { return 127;} static constexpr bool is_integer = true; + static constexpr bool has_infinity = false; }; @@ -920,6 +841,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr uint32_t max() noexcept { return 4294967295U;} static constexpr bool is_integer = true; + static constexpr bool has_infinity = false; }; template <> @@ -929,6 +851,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr uint16_t max() noexcept { return 65535U;} static constexpr bool is_integer = true; + static constexpr bool has_infinity = false; }; template <> @@ -938,16 +861,41 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr uint8_t max() noexcept { return 255U;} static constexpr bool is_integer = true; + static constexpr bool has_infinity = false; }; template <> struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} + CUTLASS_HOST_DEVICE + static constexpr float max() noexcept { return bit_cast(0x7f7fffff);} static constexpr bool is_integer = false; static constexpr bool has_infinity = true; }; +/// Returns a value that curries the `std::maximum()` function into the identity +/// function. No value will compare < than this value. +template +constexpr T identity_for_maximum() { + if constexpr (numeric_limits::has_infinity) { + return -numeric_limits::infinity(); + } else { + return numeric_limits::lowest(); + } +} + +/// Returns a value that curries the `std::minimum()` function into the identity +/// function. No value will compare > than this value. +template +constexpr T identity_for_minimum() { + if constexpr (numeric_limits::has_infinity) { + return numeric_limits::infinity(); + } else { + return numeric_limits::max(); + } +} + /// std::float_round_style using CUTLASS_STL_NAMESPACE::float_round_style; using CUTLASS_STL_NAMESPACE::round_indeterminate; diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 81e80281b9..26b7c66b19 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -35,14 +35,15 @@ #pragma once #include "numeric_types.h" +#include "complex.h" namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template CUTLASS_HOST_DEVICE -bool relatively_equal(T a, T b, T epsilon, T nonzero_floor); +bool relatively_equal(T a, T b, U epsilon, U nonzero_floor); ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -220,6 +221,55 @@ bool relatively_equal(double a, double b, double epsilon, double nonzero return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); } +template +CUTLASS_HOST_DEVICE +bool relatively_equal(complex a, complex b, T epsilon, T nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + T diff = abs(a - b); + complex zero = complex{T{}, T{}}; + + if (a == b) { + return true; + } + else if (a == zero || b == zero || diff < nonzero_floor) { + return diff < epsilon * nonzero_floor; + } + + return diff < epsilon * (abs_A + abs_B); +} + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(complex a, complex b, complex epsilon, complex nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + complex diff = a - b; + T abs_diff = abs(diff); + complex zero = complex{T{}, T{}}; + + if (a == b) { + return true; + } + else if (a == zero || b == zero || abs_diff < abs(nonzero_floor)) { + return abs_diff < abs(epsilon * nonzero_floor); + } + + return abs_diff < abs(epsilon) * (abs_A + abs_B); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/subbyte_reference.h b/include/cutlass/subbyte_reference.h index 694962b11b..af697f62f8 100644 --- a/include/cutlass/subbyte_reference.h +++ b/include/cutlass/subbyte_reference.h @@ -39,6 +39,43 @@ namespace cutlass { +namespace detail { +// This is an implementation detail of cutlass::SubbyteReference and. +// cutlass::HostTensor. For a given logical element type Element, +// and its corresponding storage (physical) element type StorageUnit, +// it computes quantities that help with managing allocations. +// +// CUTLASS uses a hidden "ContainerUnitType" or StorageUnit type to support +// packed arrays of subbyte types such as int4. Element is the "logical" type +// for computations, while CUTLASS uses StorageUnit as the element type +// of a packed array of Element. If Element is not a subbyte type, +// then the corresponding StorageUnit type is just Element itself. +// +// The ContainerType is always calculated as an array StorageUnit type (the StorageUnit +// is always a byte for subbyte types), +// and its number of bits is the lcm of the subbyte type's number of bits and 8. +// Below are some examples for different subbyte types. +// +// * Subbyte Type=int2, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) +// * Subbyte Type=int4, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) +template +struct StorageContainerCalculator { + // kContainerTypeNumBits: The number of bits needed for ContainerType + static constexpr int kContainerTypeNumBits = (sizeof_bits::value < 8) ? cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value) : sizeof_bits::value; + static_assert(kContainerTypeNumBits % sizeof_bits::value == 0, "The bits of ContainerType should be divisible by the element's number of bits"); + // kContainerTypeNumLogicalElements: The number of logical Element instance(s) that can be stored per ContainerType instance + static constexpr int kContainerTypeNumLogicalElements = kContainerTypeNumBits / sizeof_bits::value; + /// 3. kContainerTypeNumBytes: The number of bytes per ContainerType instance + static constexpr int kContainerTypeNumBytes = kContainerTypeNumBits / 8; + /// 4. kContainerTypeNumBytes: The number of base StorageUnit in the ContainerType + static constexpr int kContainerTypeNumStorageUnit = kContainerTypeNumBits / sizeof_bits::value; + + static_assert(kContainerTypeNumBits != 0, "kContainerTypeNumBits can not be zero"); + static_assert(kContainerTypeNumLogicalElements != 0, "kContainerTypeNumLogicalElements can not be zero"); + static_assert(kContainerTypeNumBytes != 0, "kContainerTypeNumBytes can not be zero"); +}; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// /// This class provides a mechanism for packing and unpacking elements smaller than one byte. It @@ -623,12 +660,16 @@ class SubbyteReference::value, sizeof_bits::value); - static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits::value; +private: + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; +public: + static int const kBitsStoredVec = StorageContainerCalculator::kContainerTypeNumBits; + static int const kNumStorageUnitPerStoredVec = StorageContainerCalculator::kContainerTypeNumStorageUnit; using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; using StorageVecPointer = StorageVec *; diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 2666d921c1..8e7ab884cf 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -61,7 +61,19 @@ struct alignas(4) tfloat32_t { // // Methods // + private: + CUTLASS_HOST_DEVICE + static uint32_t float_to_storage(float s) { + #if defined(__CUDA_ARCH__) + uint32_t result = reinterpret_cast(s); + #else + uint32_t result; + std::memcpy(&result, &s, sizeof(float)); + #endif + return result; + } + public: /// Constructs from an unsigned int CUTLASS_HOST_DEVICE static tfloat32_t bitcast(uint32_t x) { @@ -73,7 +85,7 @@ struct alignas(4) tfloat32_t { /// Emulated rounding is fast in device code CUTLASS_HOST_DEVICE static tfloat32_t round_half_ulp_truncate(float const &s) { - uint32_t x = reinterpret_cast(s); + uint32_t x = float_to_storage(s); #if defined(__CUDA_ARCH__) if (::isfinite(s)) { @@ -88,24 +100,19 @@ struct alignas(4) tfloat32_t { return tfloat32_t::bitcast(x); } - /// Default constructor tfloat32_t() = default; /// Floating-point conversion - round toward nearest even CUTLASS_HOST_DEVICE -// explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } - tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } + explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { } - /// Floating-point conversion - round toward nearest even + // Conversion from double (this rounds twice) CUTLASS_HOST_DEVICE -// explicit tfloat32_t(double x): tfloat32_t(float(x)) { - tfloat32_t(double x): tfloat32_t(float(x)) { - } + explicit tfloat32_t(double x): tfloat32_t(float(x)) { } /// Integer conversion - round toward zero CUTLASS_HOST_DEVICE -// explicit tfloat32_t(int x) { - tfloat32_t(int x) { + explicit tfloat32_t(int x) { float flt = static_cast(x); #if defined(__CUDA_ARCH__) storage = reinterpret_cast(flt); @@ -114,7 +121,7 @@ struct alignas(4) tfloat32_t { #endif } - /// Converts to float + // Conversion to float CUTLASS_HOST_DEVICE operator float() const { @@ -122,7 +129,7 @@ struct alignas(4) tfloat32_t { // of the mantissa. unsigned bits = (storage & ~0x1fffu); - #if defined(__CUDA_ARCH__) + #if defined(__CUDA_ARCH__) return reinterpret_cast(bits); #else float flt; @@ -131,7 +138,7 @@ struct alignas(4) tfloat32_t { #endif } - /// Converts to float + /// Converts to double CUTLASS_HOST_DEVICE explicit operator double() const { return double(float(*this)); @@ -253,11 +260,11 @@ cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { CUTLASS_HOST_DEVICE tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { - uint32_t a_mag = (reinterpret_cast(a) & 0x7fffffff); - uint32_t b_sign = (reinterpret_cast(b) & 0x80000000); + uint32_t a_mag = (a.raw() & 0x7fffffff); + uint32_t b_sign = (b.raw() & 0x80000000); uint32_t result = (a_mag | b_sign); - return reinterpret_cast(result); + return tfloat32_t::bitcast(result); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -372,13 +379,7 @@ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { CUTLASS_HOST_DEVICE tfloat32_t operator-(tfloat32_t const& lhs) { - union u_tff32 { - float val_f32; - tfloat32_t val_tf; - CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { } - }; - union u_tff32 x; x.val_f32 = -reinterpret_cast(lhs); - return x.val_tf; + return tfloat32_t::bitcast(0x80000000 ^ lhs.raw()); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/transform/device/transform_universal_adapter.hpp b/include/cutlass/transform/device/transform_universal_adapter.hpp new file mode 100644 index 0000000000..5fc5ab2d94 --- /dev/null +++ b/include/cutlass/transform/device/transform_universal_adapter.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Transform Kernel Universal adapter +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/device_kernel.h" +#include "cutlass/cuda_host_adapter.hpp" + +namespace cutlass::transform::device { + +template +class TransformUniversalAdapter +{ +public: + using TransformKernel = TransformKernel_; + using Arguments = typename TransformKernel::Arguments; + using Params = typename TransformKernel::Params; + +private: + Params params_; + static constexpr bool const EnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + +public: + Params const& params() const { + return this->params_; + } + + static Status + can_implement(Arguments const& args) { + return TransformKernel::can_implement(args); + } + + static size_t + get_workspace_size(Arguments const& args) { + return TransformKernel::get_workspace_size(args); + } + + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = TransformKernel::to_underlying_arguments(args, workspace); + return TransformKernel::get_grid_shape(tmp_params); + } + + static dim3 + get_grid_shape(Params const& params) { + return TransformKernel::get_grid_shape(params); + } + + Status + initialize( + Arguments & args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("TransformUniversalAdapter::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = TransformKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + this->params_ = TransformKernel::to_underlying_arguments(args, workspace); + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (EnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = TransformKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + + return Status::kSuccess; + } + + static Status + run( + Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0) { + + CUTLASS_TRACE_HOST("TransformUniversalAdapter::run()"); + dim3 const block = TransformKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + // Currently only support 1x1x1 for transform kernel. + dim3 const cluster = {1,1,1}; + + // configure smem size and carveout + int smem_size = TransformKernel::SharedStorageSize; + + Status launch_result; + + // Use extended launch API only for mainloops that use it + if constexpr(TransformKernel::ArchTag::kMinComputeCapability >= 90) { + void* kernel_params[] = {¶ms}; + + if constexpr (EnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + launch_result = cuda_adapter->launch( + grid, cluster, block, smem_size, stream, kernel_params, kernel_index); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + } + else { + launch_result = Status::kSuccess; + if constexpr (EnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + device_kernel<<>>(params); + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + Status + run( + Arguments & args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + if (Status::kSuccess == status) { + status = run(this->params_, stream, cuda_adapter, kernel_index); + } + return status; + } + + Status + operator()( + Arguments & args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(args, workspace, stream, cuda_adapter); + } + + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(this->params_, stream, cuda_adapter); + } + + Status + operator()( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(this->params_, stream, cuda_adapter); + } +}; + +} // namespace cutlass::transform::device diff --git a/include/cutlass/transform/kernel/filter_format_transformer.hpp b/include/cutlass/transform/kernel/filter_format_transformer.hpp new file mode 100644 index 0000000000..7538f2f432 --- /dev/null +++ b/include/cutlass/transform/kernel/filter_format_transformer.hpp @@ -0,0 +1,205 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief Convolution filter format transformation kernel. +*/ + +#pragma once + +#include +#include + +#include "cutlass/coord.h" +#include "cutlass/arch/arch.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/int_tuple.hpp" +#include "cute/tensor.hpp" +#include "cute/config.hpp" + +namespace cutlass::transform::kernel { + +using namespace cute; + +enum class FilterFormat { + CKTRS, + CTRSK, + KTRSC +}; + +template < + FilterFormat SrcFormat, + FilterFormat DstFormat, + int NumDimensions, + class Element, + int AlignmentBytes = 16 +> +struct ConvFilterFormatTransformer { + static_assert(SrcFormat == FilterFormat::CKTRS, "Currently only source format of CKTRS is supported"); + static_assert(DstFormat == FilterFormat::CTRSK || DstFormat == FilterFormat::KTRSC, "Currently only destination format of CTRSK/KTRSC is supported"); + static_assert(AlignmentBytes % static_cast(sizeof(Element)) == 0, "Invalid alignment setting"); + + // In ktrsc order. + using FilterExtent = array; + + // Default cta tile shape: 32x32 + static constexpr auto CTATileShape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); + // Default thread layout: (4, 32) + static constexpr auto ThreadLayout = make_layout(make_shape(Int<4>{}, Int<32>{})); + + static constexpr uint32_t MaxThreadsPerBlock = 128; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + using ArchTag = arch::Sm90; + + // Default ctor + CUTLASS_HOST_DEVICE + ConvFilterFormatTransformer() {} + + struct Arguments { + const void *src_ptr; + void *dst_ptr; + FilterExtent filter_extent; + }; + + struct Params { + using TensorSrc = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(take<0,NumDimensions>(FilterExtent{})))); + using TensorDst = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(make_shape(int32_t(0), int32_t(0))))); + + TensorSrc src; + TensorDst dst; + }; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static Status + can_implement(Arguments const& args) { + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static dim3 + get_block_shape() { + return dim3(size(shape(ThreadLayout)), 1, 1); + } + + static dim3 + get_grid_shape(Params const& params) { + auto dim_m = ceil_div(size<0>(shape(params.dst)), get<0>(CTATileShape)); + auto dim_n = ceil_div(size<1>(shape(params.dst)), get<1>(CTATileShape)); + + return dim3(dim_m, dim_n, 1); + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static Params + to_underlying_arguments(Arguments & args, void* workspace) { + auto k = args.filter_extent[0]; + auto c = args.filter_extent[NumDimensions - 1]; + auto srt = reverse(take<1,NumDimensions - 1>(args.filter_extent)); + + // source shape (s,r,t,k,c) + auto shape_src = flatten(make_shape(srt, k, c)); + auto shape_dst = DstFormat == FilterFormat::CTRSK ? make_shape(k, c * product(srt)) : make_shape(c, k * product(srt)); + + auto src = make_tensor(make_gmem_ptr(recast_ptr(args.src_ptr)), make_layout(shape_src)); + auto dst = make_tensor(make_gmem_ptr(recast_ptr(args.dst_ptr)), make_layout(shape_dst)); + + return Params{src, dst}; + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *smem_buf) { + // Tile the input tensor into blocks + auto block_coord = make_coord(blockIdx.x, blockIdx.y); + auto block_shape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); + // Default thread layout: (4, 32) + auto thread_layout = make_layout(make_shape(Int<4>{}, Int<32>{})); + auto vec_layout = make_layout(make_shape(Int(sizeof(Element))>{}, Int<1>{})); + + Tensor tile_D = local_tile(params.dst, block_shape, block_coord); + + // Construct tiled copy + using AccessType = cutlass::AlignedArray; + using Atom = Copy_Atom, Element>; + + auto tiled_copy = make_tiled_copy(Atom{}, thread_layout, vec_layout); + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + Tensor thr_tile_D = thr_copy.partition_D(tile_D); + + // shape (s, r, t) + auto shape_trs = take<0, NumDimensions - 2>(shape(params.src)); + // strided_c = c for format CTRSK, strided_c = k for format KTRSC + auto strided_c = DstFormat == FilterFormat::CTRSK ? get(shape(params.src)) : get(shape(params.src)); + // shape (s, r, t, c) for format CTRSK and shape (s, r, t, k) for format KTRSC + auto shape_ctrs = append(shape_trs, strided_c); + auto srtc_coord = idx2crd(int(blockIdx.y * get<1>(block_shape) + threadIdx.x / size<0>(thread_layout)), shape_ctrs); + // index of k for format CTRSK and index of c for format KTRSC + auto n_layout = make_layout(make_shape(gridDim.x, size<0>(thread_layout)), make_stride(size<0>(block_shape), size<0>(vec_layout))); + int n_idx = n_layout(make_coord(blockIdx.x, threadIdx.x % size<0>(thread_layout))); + + // Fragment to load from S and store to D + auto frag = make_fragment_like(thr_tile_D); + // Predicate tensor. + Tensor thr_tile_P = make_tensor(shape(thr_tile_D)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(frag); ++i) { + auto srt_coord = take<0, NumDimensions - 2>(srtc_coord); + auto kc_coord = DstFormat == FilterFormat::CTRSK ? + make_coord(n_idx+i, get(srtc_coord)) : + make_coord(get(srtc_coord), n_idx+i); + auto coord = flatten(make_coord(srt_coord, kc_coord)); + frag(i) = params.src(coord); + thr_tile_P(i) = elem_less(coord, shape(params.src)); + } + + // Copy from RMEM to GMEM + copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D); + } +}; + +} // namespace cutlass::transform::kernel diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 8133e69759..bcfce6c3a7 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -36,7 +36,7 @@ #define CUTLASS_MAJOR 3 #define CUTLASS_MINOR 5 -#define CUTLASS_PATCH 0 +#define CUTLASS_PATCH 1 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/include/cutlass/workspace.h b/include/cutlass/workspace.h index 6dc0141cfb..6f1c3254c6 100644 --- a/include/cutlass/workspace.h +++ b/include/cutlass/workspace.h @@ -43,6 +43,7 @@ #include "cutlass.h" #include "cutlass/cuda_host_adapter.hpp" + namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -51,7 +52,7 @@ static constexpr int MinWorkspaceAlignment = 16; #if !defined(__CUDACC_RTC__) static Status -zero_workspace(void* workspace, size_t workspace_size, cudaStream_t stream = nullptr) { +zero_workspace(void* workspace, size_t workspace_size, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { if (workspace_size > 0) { if (workspace == nullptr) { CUTLASS_TRACE_HOST(" error: device workspace must not be null"); @@ -59,12 +60,28 @@ zero_workspace(void* workspace, size_t workspace_size, cudaStream_t stream = nul } CUTLASS_TRACE_HOST(" clearing workspace"); + +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } + } + else { + return Status::kErrorInternal; + } +#else cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_size, stream); if (cudaSuccess != result) { result = cudaGetLastError(); // to clear the error bit CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); return Status::kErrorInternal; } +#endif } return Status::kSuccess; @@ -83,20 +100,14 @@ fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t st } CUTLASS_TRACE_HOST(" filling workspace"); - CUdeviceptr d_workspace = reinterpret_cast(workspace); #if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER - // // Use the cuda host adapter // CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - Status status = Status::kErrorInternal; - - status = cuda_adapter->memsetDevice(workspace, fill_value, fill_count, stream); - - if (status!=Status::kSuccess) { + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, fill_value, fill_count, stream)) { return Status::kErrorInternal; } } @@ -104,6 +115,7 @@ fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t st return Status::kErrorInternal; } #else + CUdeviceptr d_workspace = reinterpret_cast(workspace); CUresult result = CUDA_SUCCESS; if (sizeof(T) == 4) { result = cuMemsetD32Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); diff --git a/media/docs/build/building_in_windows_with_visual_studio.md b/media/docs/build/building_in_windows_with_visual_studio.md index bcd258ae1a..2c69e1ac5c 100644 --- a/media/docs/build/building_in_windows_with_visual_studio.md +++ b/media/docs/build/building_in_windows_with_visual_studio.md @@ -50,18 +50,6 @@ before attempting to clone or build CUTLASS. [This Microsoft help article](https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation?tabs=registry) explains different ways to change the registry setting. -# Limitations - -Currently, it's possible to build examples and tests. -Building the CUTLASS library (e.g., for profiling) with default settings does not currently work, -because Visual Studio's linker cannot handle more than 65535 symbols in a library. -(The symptom of this issue is a LNK1189 linker error.) -The known way to work around this Visual Studio limitation is to disable building CUTLASS's library, -by setting the CMake option `CUTLASS_ENABLE_LIBRARY` to `OFF`. -Another approach may be to limit the number of kernels in the library -by setting the CMake option `CUTLASS_LIBRARY_KERNELS` -so that CUTLASS tries to put fewer kernels in the library. - # Set up build environment 1. Run "git bash" to get a familiar command-line interface @@ -72,7 +60,7 @@ so that CUTLASS tries to put fewer kernels in the library. 4. Create the `build` subdirectory in the CUTLASS clone directory, and run CMake in it, specifying whatever CMake options are desired, e.g., - `cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_ENABLE_LIBRARY=OFF` + `cmake .. -DCUTLASS_NVCC_ARCHS=90a` Alternate approaches may rely on the CMake GUI and/or Windows' native command line. @@ -91,3 +79,12 @@ Unlike with CMake's Makefile or Ninja generators, `CMAKE_BUILD_TYPE` has no effect on the Visual Studio generator, because the Visual Studio generator creates all build configurations. +# Tips + +With Windows builds, one may find that CMake reruns unnecessarily. +For example, cancelling a build and starting it again may rerun CMake. +This will in turn touch build files that result in unnecessary rebuilds. +One work-around is to set the CMake option `CMAKE_SUPPRESS_REGENERATION=ON`. +However, this turns off CMake's ability to detect on its own when it needs to rerun. +As a result, one will need to know when to rerun CMake by hand. + diff --git a/media/docs/build/building_with_clang_as_host_compiler.md b/media/docs/build/building_with_clang_as_host_compiler.md index 54b2c78e1f..c53500609b 100644 --- a/media/docs/build/building_with_clang_as_host_compiler.md +++ b/media/docs/build/building_with_clang_as_host_compiler.md @@ -9,7 +9,7 @@ Clang as both host and device compiler ("CUDA Clang"). # Software prerequisites -1. Clang (regularly tested with Clang 14; +1. Clang (regularly tested with Clang 17; occasionally tested with Clang 10 and greater) 2. CUDA Toolkit (tested with 12.2; other versions likely work) diff --git a/media/docs/cute/03_tensor.md b/media/docs/cute/03_tensor.md index 35c2e6f28b..f2412d1189 100644 --- a/media/docs/cute/03_tensor.md +++ b/media/docs/cute/03_tensor.md @@ -166,10 +166,10 @@ The `make_tensor_like` function makes an owning Tensor of register memory with t Calling `print` on each of the above tensors produces similar output ``` -rmem_4x8_col : ptr[32b](0x7ff1c8fff820) o (_4,_8):(_1,_4) -rmem_4x8_row : ptr[32b](0x7ff1c8fff8a0) o (_4,_8):(_8,_1) -rmem_4x8_pad : ptr[32b](0x7ff1c8fff920) o (_4,_8):(_32,_2) -rmem_4x8_like : ptr[32b](0x7f4158fffc60) o (_4,_8):(_8,_1) +rmem_4x8_col : ptr[32b](0x7fff48929460) o (_4,_8):(_1,_4) +rmem_4x8_row : ptr[32b](0x7fff489294e0) o (_4,_8):(_8,_1) +rmem_4x8_pad : ptr[32b](0x7fff489295e0) o (_4,_8):(_32,_2) +rmem_4x8_like : ptr[32b](0x7fff48929560) o (_4,_8):(_8,_1) ``` and we can see that each pointer address is unique indicating that each `Tensor` is a unique array-like allocation. @@ -195,7 +195,7 @@ For example, we can read and write to `Tensor`s using natural coordinates, using ```c++ Tensor A = make_tensor(Shape ,Int<13>>{}, - Stride,_64>{}); + Stride, _64>{}); float* b_ptr = ...; Tensor B = make_tensor(b_ptr, make_shape(13, 20)); @@ -317,7 +317,7 @@ Another common partitioning strategy is called a thread-value partitioning. In t // to 1D coordinates within a 4x8 tensor // (T8,V4) -> (M4,N8) auto tv_layout = Layout,Shape <_2, _2>>, - Stride,Stride<_4,_16>>>{}; // (8,4) + Stride,Stride<_4,_16>>>{}; // (8,4) // Construct a 4x8 tensor with any layout Tensor A = make_tensor(Shape<_4,_8>{}, LayoutRight{}); // (4,8) diff --git a/media/docs/cute/0x_gemm_tutorial.md b/media/docs/cute/0x_gemm_tutorial.md index 7fe5f81c92..533d4b4be0 100644 --- a/media/docs/cute/0x_gemm_tutorial.md +++ b/media/docs/cute/0x_gemm_tutorial.md @@ -195,7 +195,7 @@ As is evident, these smem layouts can be almost anything. Inside the kernel, the CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K ``` -Use of static layouts has a few advantages. +Use of static layouts has a few advantages. * Static layouts let us statically allocate shared memory as shown below. * Static layouts are often more efficient and allow CuTe to dispatch to optimized implementations. * Static layouts makes it easier to prove correctness of the algorithm and provide checks like the above -- the smem layout sizes are the same as the CTA tile sizes. @@ -227,7 +227,7 @@ if (thread0()) { ``` This would work, but we have lots of threads to use inside this CTA, so let's use them! -If we partition the two tiles of data across the threads in the CTA, then each thread can copy its own subtensor of data. There are lots of ways this partitioning could occur, however. +If we partition the two tiles of data across the threads in the CTA, then each thread can copy its own subtensor of data. There are lots of ways this partitioning could occur, however. The `gemm_nt` function defines two layouts of *threads* as ```c++ @@ -295,7 +295,7 @@ if (thread0()) { ``` This would work, but we have lots of threads to use inside this CTA, so let's use them! -If we partition the output tile `gC` across the threads in the CTA, then each thread can compute its own subtensor. There are lots of ways this partitioning could occur, however. +If we partition the output tile `gC` across the threads in the CTA, then each thread can compute its own subtensor. There are lots of ways this partitioning could occur, however. The `gemm_nt` and `gemm_tn` functions define one more layout of *threads*: ```cpp @@ -332,7 +332,7 @@ These thread layouts are then used to partition the tiles of data in global memo CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB)); // THR_N CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB)); // BLK_K ``` -where we've used the same projection-style interface to avoid applying the `N`-mode of `tC` to the `(BLK_M,BLK_K)` shape of `sA` and avoid applying the `M`-mode of `tC` to the `(BLK_N,BLK_K)` shape of `sB`. +where we've used the same projection-style interface to avoid applying the `N`-mode of `tC` to the `(BLK_M,BLK_K)` shape of `sA` and avoid applying the `M`-mode of `tC` to the `(BLK_N,BLK_K)` shape of `sB`.

tC_partitioning.png diff --git a/media/docs/ide_setup.md b/media/docs/ide_setup.md new file mode 100644 index 0000000000..90e5dc2957 --- /dev/null +++ b/media/docs/ide_setup.md @@ -0,0 +1,122 @@ +[README](../../README.md#documentation) > **IDE Setup for CUTLASS Development** + +# IDE Setup for CUTLASS Development + +This document outlines instructions and tips for setting up a local editor for CUTLASS development, including support +for intellisense, go-to-definition, code formatting, and so on. + +## Overview +In order for any intellisense tool to work with CUTLASS, the following things need to be configured with it: +* Include paths, i.e. where the compiler (or in this case, the intellisense tool) should look for header files +* Compiler flags; especially the C++ standard (`--std`) +* Preprocessor variables; especially CUDA-related ones + +One usually needs to configure the above variables in a settings file. Below, two config approaches are described: +for VSCode, and for any editor that uses the clangd language server, which includes +Vim, Emacs, NeoVim, Sublime Text, and so on. Note that VSCode can also be configured to use clangd. +It might be worth setting up clangd for VSCode rather than the default intellisense, +and you might see faster responses and more stable performance with clangd. + +## VSCode Setup + +1. Install the [Official C/C++ extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools) +1. Open settings... + 1. `Ctrl+Shift+P` to open the command palette + 1. Enter "C/C++" to filter results + 1. Select "C/C++ Edit Configurations (UI)" (or "... (JSON)" if you feel like editing the raw JSON) + 1. View the documentation for these settings + [here](https://code.visualstudio.com/docs/cpp/c-cpp-properties-schema-reference) +1. Edit "Include Path" to set up **include paths**. For CUTLASS, this includes the following: + * `${workspaceFolder}/include` + * `${workspaceFolder}/tools/util/include` + * `${workspaceFolder}/examples/common` + * ...others, depending on which files you edit +1. Edit C++ standard to be `c++17`, `gnu++17`, or equivalent. +1. Edit `defines` to define preprocessor variables. See +[Global Config below](#Global-Config) for examples. The important + ones include `__CUDACC_VER_MAJOR__`, `__CUDA_ARCH__`, `__CUDA_ARCH_FEAT_SM90_ALL__`. But configure + them according to your target architecture. +1. ...and possible edit any other fields for your specific setup. + +## clangd Setup + +`clangd` is a C++ language server that is part of the LLVM project. You must first set it up your specific IDE: +* `clangd` official [documentation](https://clangd.llvm.org/installation#editor-plugins) for editor setup. +* NeoVim setup is possible through [lsp](https://neovim.io/doc/user/lsp.html) and either manually installing clangd or +using an installation manager like Mason. + +Then, one needs to edit the config ([documentation](https://clangd.llvm.org/config)). One typically has a +**global** and a **per-project** config. + +### Global Config + +Here is one example for a global config. +On linux this is usually located at `~/.config/clangd/config.yaml` . Here is one example config for CUDA projects on SM90. +The key settings here are the preprocessor vars (`-D__CUDACC_VER_MAJOR__` , `-D__CUDA_ARCH__`) + +``` +CompileFlags: + Compiler: /usr/local/cuda/bin/nvcc + Add: + - --cuda-path=/usr/local/cuda + - --cuda-gpu-arch=sm_90a + - -I/usr/local/cuda/include + - "-xcuda" + # report all errors + - "-ferror-limit=0" + - --cuda-gpu-arch=sm_90a + - --std=c++17 + - "-D__INTELLISENSE__" + - "-D__CLANGD__" + - "-DCUDA_12_0_SM90_FEATURES_SUPPORTED" + - "-DCUTLASS_ARCH_MMA_SM90_SUPPORTED=1" + - "-D_LIBCUDACXX_STD_VER=12" + - "-D__CUDACC_VER_MAJOR__=12" + - "-D__CUDACC_VER_MINOR__=3" + - "-D__CUDA_ARCH__=900" + - "-D__CUDA_ARCH_FEAT_SM90_ALL" + - "-Wno-invalid-constexpr" + Remove: + # strip CUDA fatbin args + - "-Xfatbin*" + # strip CUDA arch flags + - "-gencode*" + - "--generate-code*" + # strip CUDA flags unknown to clang + - "-ccbin*" + - "--compiler-options*" + - "--expt-extended-lambda" + - "--expt-relaxed-constexpr" + - "-forward-unknown-to-host-compiler" + - "-Werror=cross-execution-space-call" +Hover: + ShowAKA: No +InlayHints: + Enabled: No +Diagnostics: + Suppress: + - "variadic_device_fn" + - "attributes_not_allowed" +``` + +### Local Config +Local config is needed to specify per-project settings, especially include paths. An example is: +``` +CompileFlags: + Add: + - -I/include/ + - -I/tools/util/include/ + - -I/cutlass/examples/common/ +``` + +Note that absolute paths are needed since clangd doesn't support relative paths. + +### Note on compile_commands.json +For typical C++ projects, clangd can *automatically* configure itself by parsing the `compile_commands.json` +generated by your CMake build. The path to such a file is by default `build/compile_commands.json` and is +configured by the `CompilationDatabase` config. + +This is usually a convenient way to configure projects, but it's not as simple for CUDA/nvcc projects, since +clang doesn't understand many of the compiler flags used by nvcc. Hence, for now, we don't recommend using +`compile_commands.json` to configure your CUDA project. + diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 34282925db..35106f26ba 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -210,6 +210,8 @@ GEMM [int] --inst_k,--instruction-shape::k Math instruction shape in the K dimension [int] --min_cc,--minimum-compute-capability Minimum device compute capability [int] --max_cc,--maximum-compute-capability Maximum device compute capability + [enum] --raster_order={H|M|N} If supported by kernel, sets the tile raster direction + [int] --swizzle_size If supported by kernel, sets the 2D tile swizzle extent Examples: Profile a particular problem size: @@ -229,6 +231,9 @@ Using various input value distribution: $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3 $ cutlass_profiler --operation=Gemm --dist=sequential,start:0,delta:1 +Using CUTLASS 3.x GEMM kernel with a tile scheduler that supports runtime tile remapping and raster mode order: + $ cutlass_profiler --operation=Gemm --m=2048 --n=2048 --k=2048 --raster_order=M --swizzle_size=2 + Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size): $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect diff --git a/media/docs/programming_guidelines.md b/media/docs/programming_guidelines.md index 392c390d06..d395c87e13 100644 --- a/media/docs/programming_guidelines.md +++ b/media/docs/programming_guidelines.md @@ -92,9 +92,13 @@ for (int idx = 0; idx < kN; ++idx) { // Loop has constant number of iterati // direct register access. } ``` - ## Style +### If you see an issue in code formatting, fix it + +You are empowered to reformat code. +Please, however, consider making reformatting changes separately from content-related changes. + ### No automatic code formatting Do not use any kind of automatic code formatting, @@ -128,48 +132,111 @@ and we should always strive to eliminate them. * [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) +#### C is not a subset of C++ + +C is not a subset of C++. +Some valid C is not valid C++, and some valid "C-looking" C++ is not valid C. +See e.g., the informative C++ Standard Committee (WG21) document +[P2735R0](https://isocpp.org/files/papers/P2735R0.pdf), +which explains ways in which the same code has different behavior in C vs. C++. +In some cases, code that compiles in both C and C++, +and is correct in C, has undefined behavior (can crash or worse) in C++. +The "type.punning" section of P2735R0 specifically relates to unions. + #### Spacing and line length * Use spaces, not tabs. * Use 2 spaces to indent. -* Max 100 characters per line. +* Use at most 100 characters per line. +(Right-align tensor shape layout comments at column 120. +Please see below.) Lines longer than 100 characters typically wrap unfavorably when viewed in Github's pretty printer. -#### Function indentation +#### Formatting function declarations and definitions -When calling a function or function object with a long name, -break the line right after the invoking open parenthesis. -Here is an example. +Short function headers can go on one line. + +Do not insert a newline between the parenthesis +that closes the function's parameters and +the curly bracket that opens the function's body. ```c++ -detail::very_long_function_object_name{}( - params.long_parameter_name, some_operator.another_long_function_name()); +int short_name(int x, int y) { + return x + y; +} ``` -When declaring functions, indent function parameters like this. +If the function name and its parameters are too long to fit on one line, +break the line immediately after the opening parenthesis +that starts the parameter list. Then, double-indent the parameters +to distinguish them from the body of the function. ```c++ -void possibly_an_unusually_long_function_name( - std::uint32_t foo +void indeed_my_fellowbeings_this_function_name_is_unusually_long( + std::uint32_t foo, // parameters are double-indented std::uint32_t const* bar, TypeA a, TypeB b, - TypeC c) { - // ... the function's body ... + TypeC c) { // the ) and { go on the same line still + auto d = body_of_the_function(a, b, c); // body is single-indented + // ... more code ... } ``` -A newline should not be inserted between the parenthesis -that closes the function's parameters and the curly bracket -that opens the function's body. Note the double indent for function parameters. +For a constructor with a long parameter list, +break the line after the parentheses, just as with other functions. +Align the colon that starts the constructor's initializer list +flush with the comma on the next line. + +As with functions, double-indent the parameters +to distinguish them from the constructor body. +Here is an example. + +```c++ +class YesTheCommunityAgreesThatTheNameOfThisClassIsIndeedExtremelyLong { +public: + CUTLASS_HOST_DEVICE + YesTheCommunityAgreesThatTheNameOfThisClassIsIndeedExtremelyLong( + int this_is_the_first_parameter_and_its_name_is_long, + int this_is_the_second_parameter_and_its_name_is_also_long, + int this_is_the_third_parameter_and_its_name_is_long_too) + : x_(this_is_the_first_parameter_and_its_name_is_long) + , y_(this_is_the_second_parameter_and_its_name_is_also_long) + , z_(this_is_the_third_parameter_and_its_name_is_long_too) { + // constructor body + // more of the constructor body + } + +private: + int x_ = 0; + int y_ = 0; + int z_ = 0; +}; +``` + +#### Formatting function calls + +When calling a function or function object with a long name, +break the line right after the invoking open parenthesis. +Here are some examples. + +```c++ +detail::very_long_function_object_name{}( + params.long_parameter_name, some_operator.another_long_function_name()); + +detail::an_even_longer_function_object_name{}( + params.long_parameter_name, some_operator.long_member_function_name(), + another_operator.another_long_member_function_name(x, y, z)); +``` #### If-else brackets and spacing -* Always use braces with conditionals such as `if`. +* Always use braces with conditionals such as `if`, + even if the body is a single line. * Use a space after control flow keywords such as `if`, `for`, and `while`. @@ -181,13 +248,14 @@ that opens the function's body. Note the double indent for function parameters. of an `if` branch, and the `else` keyword. ```c++ -if (condition) { +if (condition) { // space after if, and between ) and { // ... code ... -} +} // newline after } else { // ... other code ... } +// space after keyword for for (int k = 0; k < num_iters; ++k) { // ... still more code ... } @@ -244,7 +312,6 @@ and not this. int const &var; int const *var; ``` - #### Avoid calling functions "fast" or "optimized" Putting words like "fast" or "optimized" @@ -395,6 +462,9 @@ Sometimes a function needs to return multiple values. In that case, consider th for all the types that work in `std::tuple`. CuTe's documentation explains.) +3. Resort to "returning" multiple values by output references + only if performance requires it. + Here is an example of the struct approach for named values. For a comparable example in the C++ Standard, please see [`std::allocate_at_least`](https://en.cppreference.com/w/cpp/memory/allocate_at_least), @@ -655,6 +725,158 @@ private: }; ``` +#### For code reuse, prefer composition over inheritance + +* [C++ Core Guidelines C.129](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c129-when-designing-a-class-hierarchy-distinguish-between-implementation-inheritance-and-interface-inheritance): "When designing a class hierarchy, distinguish between implementation inheritance and interface inheritance" +* [C++ Core Guidelines ES.63](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Res-slice): "Don't slice" + +Suppose that a class hierarchy exists entirely for implementation convenience, so that implementers can reuse code and "program by difference" (changing or adding only what's different from the base class). In the example below, both `PipelineA` and `PipelineB` are used by themselves. `PipelineB` inherits from `PipelineA` just to avoid duplicating code. There are no virtual member functions, and users don't expect to rely on run-time polymorphism. + +```c++ +class PipelineA { +public: + PipelineA(Arg0 arg0, Arg1 arg1) + : arg0_(arg0), arg1_(arg1) + {} + + void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + // ... implementation ... + } + + void consumer_release(uint32_t stage, uint32_t skip) { + // ... implementation ... + } + +private: + Arg0 arg0_; + Arg1 arg1_; +}; + +class PipelineB : public PipelineA { +public: + PipelineB(Arg0 arg0, Arg1 arg1, Arg2 arg2) : + PipelineA(arg0, arg1), arg2_(arg2) + {} + + // Reuse PipelineA::producer_acquire via inheritance + + // Override PipelineA::consumer_release + void consumer_release(uint32_t stage, uint32_t skip) { + // ... some other implementation, not invoking parent ... + } + +private: + Arg2 arg2_; +}; +``` + +The problem with public inheritance here is that `PipelineB` is NOT a (versus "is-a," i.e., substitutable-as) `PipelineA`. In particular, the following code would be incorrect. + +```c++ +void consume_and_release_pipeline(PipelineA* parent) { + // ... code ... + parent->consumer_release(stage, skip); + // ... code ... +} + +void use_pipeline( /* other args */ ) { + // ... code ... + PipelineB child{arg0, arg1, arg2}; + // ... code ... + + // WRONG!!! SLICES CHILD TO PARENT!!! + consume_and_release_pipeline(&child); // BAD + + // ... code ... +} +``` + +`PipelineA::consumer_release` is not a virtual member function, so `consume_and_release_pipeline` would not actually be polymorphic, as callers might have expected from an interface that takes a base class pointer. What's worse is that the resulting slicing could violate `PipelineB`'s invariants, thus putting it in an incorrect state. + +The most straightforward way to reuse code would be by changing from inheritance (is-a) to composition (has-a). + +```c++ +namespace detail { + +// Implementation class; not for users +class PipelineImpl { +public: + PipelineImpl(Arg0 arg0, Arg1 arg1) + : arg0_(arg0), arg1_(arg1) + {} + + void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + // ... implementation ... + } + + void consumer_release(uint32_t stage, uint32_t skip) { + // ... implementation ... + } + +private: + Arg0 arg0_; + Arg1 arg1_; +}; + +} // namespace detail + +class PipelineA { +public: + PipelineA(Arg0 arg0, Arg1 arg1) : + impl_(arg0, arg1) + {} + + void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + impl_.producer_acquire(stage, phase, skip_wait); + } + + void consumer_release(uint32_t stage, uint32_t skip) { + impl_.consumer_release(stage, skip); + } + +private: + detail::PipelineImpl impl_; +}; + +// A second kind of pipeline. +// Note that this does NOT inherit from PipelineB! +// The two pipeline classes have the same compile-time interface +// (for compile-time polymorphism), but do not belong in an +// inheritance hierarchy (as would imply run-time polymorphism). +class PipelineB { +public: + PipelineB(Arg0 arg0, Arg1 arg1, Arg2 arg2) : + impl_(arg0, arg1), otherTwo_(arg2) + {} + + void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + impl_.producer_acquire(stage, phase, skip_wait); + } + + void consumer_release(uint32_t stage, uint32_t skip) { + // this class doesn't actually use impl_ here + otherTwo_.other_action(stage, skip); + // ... some other code not using impl_ ... + } + +private: + detail::PipelineImpl impl_; + OtherTwo otherTwo_; + // ... other member data ... +}; +``` + +This design prevents users at compile time from incorrectly assuming that `PipelineB` is a `PipelineA`. Implementers continue to get compile-time polymorphism, as long as `PipelineA` and `PipelineB` implement the same compile-time interface. + +##### Behavioral subtyping + +Another reason to avoid public inheritance would be if the public member functions of `PipelineA` and `PipelineB` have different behavior, such that the invariants satisfied by the member functions of the base class `PipelineA` are not satisfied by the correspondingly named member functions of the subclass `PipelineB`. For example, suppose that both classes have a public `producer_arrive` member function. However, for `PipelineA`, this issues a producer arrival only for its own block, whereas for `PipelineB`, this issues a producer arrival for all blocks in the cluster. Again, PipelineB "is-not-a" PipelineA. The child class doesn't just add behavior onto the parent class; it has completely different behavior. Thus, it fails to satisfy behavioral subtyping: invariants of the parent class's member functions are not satisfied by the child class. Behavioral subtyping is especially important when reasoning about already difficult things like parallel synchronization. The inheritance design would give developers the false impression that `PipelineB` just adds behavior atop `PipelineA`, whereas in fact, developers would need to understand both pipeline classes completely to build a correct mental model about their behavior. + +The fix is the same: Use composition, not inheritance. As [C++ Core Guidelines C.120](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c120-use-class-hierarchies-to-represent-concepts-with-inherent-hierarchical-structure-only) explains: "Use class hierarchies to represent concepts with inherent hierarchical structure (only)." + +1. "Make sure the idea represented in the base class exactly matches all derived types and there is not a better way to express it than using the tight coupling of inheritance." +2. "Do not use inheritance when simply having a data member will do." + #### Use scoped enums Use scoped enums (a C++11 feature) for enumerated types. @@ -765,18 +987,119 @@ Use `#pragma once` to guard all headers. ### CuTe Layout Comments -* Right align CuTe layout comments at column 120. +* Right-align tensor shape layout comments at column 120. * If layout comment is too long do your best to align it. -* If layout comment is too long and there are many related tensors that reader should read together, try to align the layout comments of related tensors. +* If layout comment is too long and there are many related tensors + that the reader should read together, + try to align the layout comments of related tensors. + +Here are a couple examples. ```c++ - Tensor my_tensor = make_tensor(Layout{}, Stride<_1,_2>>{}); // (2,2):(1,2) +Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N), params.dC); // (M,N) +Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N), params.dD); // (M,N) +Tensor mAux = make_tensor(make_gmem_ptr(params.ptr_Aux), make_shape(M,N), params.dAux); // (M,N) + +auto thr_mma = tiled_mma.get_thread_slice(thread_idx); +Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) +Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) +Tensor tCgAux = thr_mma.partition_C(gAux); // (VEC,THR_M,THR_N) +``` + +```c++ +Tensor my_tensor = make_tensor(Layout{}, Stride<_1,_2>>{}); // (2,2):(1,2) - // Related tensors - Tensor my_tensor1 = make_tensor(ThisIsAVeryComplicatedLayoutWithAVeryLongName); // ((Mode0_0,Mode0_1,Mode0_2),Mode1,Mode2,Mode3) - Tensor my_tensor2_related = make_tensor(ThisIsAVeryComplicatedLayoutWithAVeryLongName); // ((Mode0_0,Mode0_1,Mode0_2),Mode1,Mode2,Mode3) +// Related tensors +Tensor my_tensor1 = make_tensor(ThisIsAVeryComplicatedLayoutWithAVeryLongName); // ((Mode0_0,Mode0_1,Mode0_2),Mode1,Mode2,Mode3) +Tensor my_tensor2_related = make_tensor(ThisIsAVeryComplicatedLayoutWithAVeryLongName); // ((Mode0_0,Mode0_1,Mode0_2),Mode1,Mode2,Mode3) +``` + +### Warnings + +CUTLASS code aims to build free of warnings. + +#### Spurious warnings + +Some compilers, or some versions of a compiler, emit spurious warnings, that is, "false positives" for perfectly fine code. While such code is correct, the warnings can obscure errors. Users also may report warnings as bugs, and processing those bugs takes developer time away from other tasks. Thus, it's good to try to "fix" the warnings, if doing so wouldn't make the code worse. + +#### Missing return statement + +GCC 10 (but not 7.5, 9.4.0, or 11) has trouble deducing that a function with `auto` return type and all of its returns in an `if constexpr` ... `else` statement must actually return. As a result, GCC emits spurious "missing return statement" build warnings. Such functions have one of two forms: `if constexpr` ... `else` where `else` returns, and `if constexpr` ... `else` where `else` is meant to fail at compile time. Here is an example of the first form. + +```c++ +template +constexpr auto first_form(T t) { + if constexpr (some_condition_v) { + return some_function(t); + } + else if constexpr (another_condition_v) { + return another_function(t); + } + else { + return yet_another_function(t); + } +} +``` + +In this form, the `if constexpr` ... `else` sequence of branches covers all possibilities. Here is an example of the second form. + +```c++ +template +constexpr auto second_form(T t) { + if constexpr (some_condition_v) { + return some_function(t); + } + else if constexpr (another_condition_v) { + return another_function(t); + } + else { + static_assert(sizeof(T) < 0, "This branch always fails"); + } +} +``` + +In this form, the `else` branch had a `static_assert` that was meant always to fail if the `else` branch were taken, such as `static_assert(sizeof(T) < 0)`. (Note that we cannot use `static_assert(false)` here, because it will ALWAYS fail at compile time, even if the `else` branch is not taken. C++23 fixes this behavior, but CUTLASS currently requires that its code be compatible with C++17. As a result, CUTLASS includes a `dependent_false` library function that you can use in place of the always-`false` test `sizeof(T) < 0`.) + +One can suppress "missing return statement" warnings for both forms by invoking CUTLASS' function-like macro `CUTE_GCC_UNREACHABLE()`. When building with GCC, this invokes the GCC-specific built-in function `__builtin_unreachable()`. Actually calling this function is undefined behavior, so using this lets the programmer declare that the code path calling that function will never be taken. (C++23 introduces the `std::unreachable()` function, which achieves the same goal. Again, though, CUTLASS cannot currently use C++23 library functions.) Here is an example of how to use `CUTE_GCC_UNREACHABLE()`. + +```c++ +template +constexpr auto second_form(T t) { + if constexpr (some_condition_v) { + return some_function(t); + } + else if constexpr (another_condition_v) { + return another_function(t); + } + else { + static_assert(sizeof(T) < 0, "This branch always fails"); + } + CUTE_GCC_UNREACHABLE(); +} +``` + +This macro should only be used if it is needed to suppress spurious warnings. Also, this function should not be used if the developer is not sure whether the code exhaustively tests all possibilities. For example, some functions may look like this. + +```c++ +template +constexpr auto possibly_nonexhaustive(T t) { + if constexpr (some_condition_v) { + return some_function(t); + } + else if constexpr (another_condition_v) { + return another_function(t); + } + + // NOTE lack of unadorned "else" here +} ``` +This is a good opportunity to review the function. If the branches are obviously meant to be exhaustive, you can add an `else` branch with a `static_assert` (see above for how to express this). If you're not sure, leave it alone and let the compiler issue warnings. + +#### Unused variable + +Some compilers may emit spurious unused warnings for some variable declarations, where the variable was only being used inside a `decltype` in an `if constexpr` test. Marking the variables as `[[maybe_unused]]` (a standard C++17 attribute) suppresses these warnings. Again, please only do this if you're sure that the code is right. + ### CUDA C++ style #### CUDA Built-in Variables diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 804d0ec39e..7faad445d9 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -232,7 +232,7 @@ int main() { ## Launching a GEMM kernel in CUDA -**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores. +**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores. _Note, this example uses CUTLASS Utilities. Be sure `tools/util/include` is listed as an include path._ ```c++ @@ -289,7 +289,7 @@ int main() { // // Launch GEMM on the device // - + status = gemm_op({ {M, N, K}, {ptrA, lda}, // TensorRef to A device tensor @@ -315,7 +315,7 @@ Note, the above could be simplified as follows using helper methods defined in ` // // Use the TensorRef returned by HostTensor::device_ref(). - // + // status = gemm_op({ {M, N, K}, @@ -329,7 +329,7 @@ Note, the above could be simplified as follows using helper methods defined in ` ## Launching a GEMM kernel using CUTLASS 3.0 or newer -**Example:** launch a mixed-precision GEMM targeting Hopper Tensor Cores. +**Example:** launch a mixed-precision GEMM targeting Hopper Tensor Cores. ```c++ #include "cutlass/cutlass.h" @@ -367,7 +367,7 @@ int main(int argc, char const **args) { using TilesShape = Shape<_128,_128,_64>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size - using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -425,10 +425,10 @@ int main(int argc, char const **args) { StrideC stride_C; StrideD stride_D; - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})); - stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})); - stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); block_A.reset(M * K); block_B.reset(K * N); @@ -438,7 +438,7 @@ int main(int argc, char const **args) { // // Launch GEMM on the device // - + status = gemm_op({ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K}, @@ -462,9 +462,9 @@ int main(int argc, char const **args) { The [CUTLASS Library](/tools/library) defines an API for managing and executing collections of compiled kernel instances and launching them from host code without template instantiations in client code. -The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its -kernel selection procedure is intended only to be functionally sufficient. It may not launch the -optimal tile size for a given problem. It chooses the first available kernel whose data types, +The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its +kernel selection procedure is intended only to be functionally sufficient. It may not launch the +optimal tile size for a given problem. It chooses the first available kernel whose data types, layouts, and alignment constraints satisfy the given problem. Kernel instances and a data structure describing them are completely available to client applications which may choose to implement their own selection logic. @@ -479,12 +479,12 @@ by several SDK examples. * [11_planar_complex_array](/examples/11_planar_complex_array/planar_complex_array.cu) The CUTLASS Library defines enumerated types describing numeric data types, matrix and tensor -layouts, math operation classes, complex transformations, and more. +layouts, math operation classes, complex transformations, and more. Client applications should specify [`tools/library/include`](/tools/library/include) in their include paths and link against libcutlas_lib.so. -The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies +The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies its dependency on the CUTLASS Library with the following CMake command. ``` target_link_libraries( @@ -534,7 +534,7 @@ int main() { // // CUTLASS Library call to execute device GEMM // - + cutlass::library::Handle handle; // @@ -571,7 +571,7 @@ int main() { ptrD, // pointer to D matrix in device memory ldd // leading dimension of D matrix ); - + if (status != cutlass::Status::kSuccess) { return -1; } @@ -580,27 +580,27 @@ int main() { } ``` -# Example CMake Commands +# Example CMake Commands -To instantiate all operations supporting all tile sizes, data types, and alignment constraints, specify +To instantiate all operations supporting all tile sizes, data types, and alignment constraints, specify `-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`. ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=all ``` -The above command line generates about twenty thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures. -Compiling thousands of kernels for three different architectures is time-consuming. Additionally, this would also result +The above command line generates about twenty thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures. +Compiling thousands of kernels for three different architectures is time-consuming. Additionally, this would also result in a large binary size and on some platforms linker to fail on building the library. -Enabling the "unity build" instantiates multiple kernel instances in each compilation unit, thereby reducing binary size +Enabling the "unity build" instantiates multiple kernel instances in each compilation unit, thereby reducing binary size and avoiding linker limitations on some platforms. ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON ``` -It is advised to only compile CUTLASS kernels for NVIDIA architectures one plans on running. Furthermore, kernels -can be selectively included in the CUTLASS Library by specifying filter strings and wildcard characters when executing CMake. +It is advised to only compile CUTLASS kernels for NVIDIA architectures one plans on running. Furthermore, kernels +can be selectively included in the CUTLASS Library by specifying filter strings and wildcard characters when executing CMake. -Several examples are defined below for convenience. They may be combined as a comma-delimited list. +Several examples are defined below for convenience. They may be combined as a comma-delimited list. Compling only the kernels desired reduces compilation time. @@ -646,7 +646,7 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS='50;60;61;70;75;80' -DCUTLASS_LIBRARY_KERNELS=sf $ cmake .. -DCUTLASS_NVCC_ARCHS='80' -DCUTLASS_LIBRARY_KERNELS=s16816fprop_*_f16 ``` -**Example.** All backward weight gradient (wgrad) convolution kernels with FP32 accumulation, FP16 input, and optimized global memory iterator +**Example.** All backward weight gradient (wgrad) convolution kernels with FP32 accumulation, FP16 input, and optimized global memory iterator targeting NVIDIA Ampere, Turing, and Volta Tensor Core operations ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=tensorop*s*wgrad_optimized_f16 diff --git a/pyproject.toml b/pyproject.toml index b537d90f23..61c371a23d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "3.5.0.0" +version = "3.5.1.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index d04b038e72..dfc9b40509 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -121,7 +121,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '3.5.0' +this.__version__ = '3.5.1' from cutlass.backend import create_memory_pool from cutlass.emit.pytorch import pytorch diff --git a/python/cutlass/backend/evt/backend/sm90_nodes.py b/python/cutlass/backend/evt/backend/sm90_nodes.py index f68b6859a8..477aab9dab 100644 --- a/python/cutlass/backend/evt/backend/sm90_nodes.py +++ b/python/cutlass/backend/evt/backend/sm90_nodes.py @@ -154,20 +154,6 @@ def type_decl(self): class Sm90RowBroadcastImpl(RowBroadcastImpl): - - @property - def descriptor(self) -> str: - """ - Descriptor for Aux Load - """ - return f"{self.name_camel}Descriptor" - - def decl_descriptor(self) -> str: - """ - Declare the descriptor type - """ - return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::RowBroadcastDescriptor;\n" - @property def type_decl(self): """ @@ -176,22 +162,14 @@ def type_decl(self): if self._type_decl is not None: return self._type_decl - self._type_decl = self.decl_descriptor() - self._type_decl += f""" + self._type_decl = f""" using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast< - {self.descriptor}::Stages, typename EpilogueDescriptor::TileShape, - typename {self.descriptor}::Element, {self.stride_mnl} + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, + {self.stride_mnl} >; """ return self._type_decl - def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): - """ - Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d - """ - stages = (stages_c + epi_tiles - 1) // epi_tiles + 1 - return (DataTypeSize[self.element] * cta_tile_mnk[1] * stages // 8, 16) - class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl): diff --git a/python/cutlass_library/conv2d_operation.py b/python/cutlass_library/conv2d_operation.py index 10e3922bea..1cfe7f6eb6 100644 --- a/python/cutlass_library/conv2d_operation.py +++ b/python/cutlass_library/conv2d_operation.py @@ -150,7 +150,7 @@ def configuration_name(self): else: group_conv_name = "" - if self.stride_support == StrideSupport.Unity: + if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad: configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}" else: configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}" diff --git a/python/cutlass_library/conv3x_emitter.py b/python/cutlass_library/conv3x_emitter.py index 3088bbc879..84d42a3ad5 100644 --- a/python/cutlass_library/conv3x_emitter.py +++ b/python/cutlass_library/conv3x_emitter.py @@ -69,8 +69,8 @@ def __init__(self): typename cutlass::epilogue::collective::CollectiveBuilder< ${arch}, ${opcode_class_epi}, - ${tile_shape}, // tile shape - ${cluster_shape}, // cluster shape + ${output_cta_tile_shape}, // output cta tile shape + ${cluster_shape}, // cluster shape ${epi_tile_mn}, ${element_accumulator}, ${element_compute}, @@ -88,8 +88,8 @@ def __init__(self): ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>, ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>, ${element_accumulator}, - ${tile_shape}, // tile shape - ${cluster_shape}, // cluster shape + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape ${stages}, ${kernel_schedule} >::CollectiveOp; @@ -106,30 +106,54 @@ def __init__(self): def arch_number_to_type(self, arch: int) -> str: return f"cutlass::arch::Sm{arch}" - def tile_shape(self, operation) -> str: + def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: # For all three kinds of convolutions, the tile shape's K mode # differs from GEMM in that needs to be wrapped in a Shape. # For Wgrad convolutions specifically, # the N tile shape also needs to be wrapped in a Shape. - m_template = 'cute::_${tile_shape_m}' + m_template = 'cute::_${cta_m}' if operation.conv_kind == ConvKind.Wgrad: - n_template = 'cute::Shape' + n_template = 'cute::Shape' else: - n_template = 'cute::_${tile_shape_n}' - k_template = 'cute::Shape' + n_template = 'cute::_${cta_n}' + k_template = 'cute::Shape' - tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' values = { - 'tile_shape_m': operation.tile_description.tile_shape[0], - 'tile_shape_n': operation.tile_description.tile_shape[1], - 'tile_shape_k': operation.tile_description.tile_shape[2] + 'cta_m': cta_m, + 'cta_n': cta_n, + 'cta_k': cta_k } - return Template(tile_shape_template).substitute(values) + return Template(output_cta_tile_shape_template).substitute(values) + + def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: + mma_m = cta_m + mma_n = cta_n + mma_k = cta_k + + # For all three kinds of convolutions, the tile shape's K mode + # differs from GEMM in that needs to be wrapped in a Shape. + # For Wgrad convolutions specifically, + # the N tile shape also needs to be wrapped in a Shape. + m_template = 'cute::_${mma_m}' + if operation.conv_kind == ConvKind.Wgrad: + n_template = 'cute::Shape' + else: + n_template = 'cute::_${mma_n}' + k_template = 'cute::Shape' + + mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'mma_m': mma_m, + 'mma_n': mma_n, + 'mma_k': mma_k + } + return Template(mma_tile_shape_template).substitute(values) def cluster_shape(self, operation) -> str: - m_template = 'cute::_${cluster_shape_m}' - n_template = 'cute::_${cluster_shape_n}' - k_template = 'cute::_${cluster_shape_k}' + m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)' + n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)' + k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)' cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' values = { 'cluster_shape_m': operation.tile_description.cluster_shape[0], @@ -159,6 +183,10 @@ def emit(self, operation) -> str: opcode_class_epi = opcode_class_main tile_shape = operation.tile_description.tile_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + cta_m, cta_n, cta_k = tile_shape warp_count = operation.tile_description.warp_count epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule] @@ -189,19 +217,20 @@ def emit(self, operation) -> str: 'element_d': DataTypeTag[operation.D.element], 'layout_d': LayoutTag[operation.D.layout], 'align_d': int(operation.D.alignment), - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': opcode_class, - 'arch': self.arch_number_to_type(operation.arch), - 'tile_shape': self.tile_shape(operation), - 'cluster_shape': self.cluster_shape(operation), - 'opcode_class_epi': opcode_class_epi, - 'opcode_class_main': opcode_class_main, - 'epi_tile_mn': epi_tile_mn, - 'stages': self.stage_count(operation), - 'kernel_schedule': kernel_schedule, - 'epilogue_schedule': epilogue_schedule, - 'tile_scheduler': tile_scheduler, - 'element_compute': DataTypeTag[operation.element_compute] + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': opcode_class, + 'arch': self.arch_number_to_type(operation.arch), + 'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k), + 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k), + 'cluster_shape': self.cluster_shape(operation), + 'opcode_class_epi': opcode_class_epi, + 'opcode_class_main': opcode_class_main, + 'epi_tile_mn': epi_tile_mn, + 'stages': self.stage_count(operation), + 'kernel_schedule': kernel_schedule, + 'epilogue_schedule': epilogue_schedule, + 'tile_scheduler': tile_scheduler, + 'element_compute': DataTypeTag[operation.element_compute] } return Template(self.template).substitute(values) diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index f739c15ab8..5e015492f4 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -178,16 +178,28 @@ def extended_name(self): if self.is_complex(): extended_name = "${core_name}" else: + # e.g. f16_f16_f32_void_f32 kernel if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: + self.A.element != self.tile_description.math_instruction.element_accumulator: extended_name = "${element_c}_${core_name}_${element_a}" if self.is_mixed_input(): extended_name += "_${element_b}" + + # e.g. f32_f32_f32_void_f32 kernel + elif self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element == self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}" + if self.is_mixed_input(): + extended_name += "_${element_b}" + + # e.g. f16_f16_f32_f32_f32 kernel elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: + self.A.element != self.tile_description.math_instruction.element_accumulator: extended_name = "${core_name}_${element_a}" if self.is_mixed_input(): extended_name += "_${element_b}" + + # e.g. f32_f32_f32_f32_f32 kernel else: extended_name = "${core_name}" diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 0ac604e74c..cbc9c326b3 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -36,13 +36,13 @@ import argparse import enum -from itertools import product +from itertools import chain, product import logging import os.path import shutil import sys import copy -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple _LOGGER = logging.getLogger(__name__) @@ -513,7 +513,7 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme new_operations = [ # None grouped kernel Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_), + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_), ] # Instance group conv kernel @@ -521,12 +521,12 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme tile.minimum_compute_capability >= 80: # SingleGroup kernel new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) # Analytic iterator supports MultipleGroup mode if iterator_algorithm == IteratorAlgorithm.Analytic: new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup)) + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup)) for new_operation in new_operations: manifest.append(new_operation) @@ -884,7 +884,7 @@ def short_math_name(self): prefix = '' if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: prefix = 'g' - return prefix + ShortDataTypeNames[self.accumulator_type()] + return prefix + DataTypeNames[self.accumulator_type()] def is_tensor_op(self): tensor_ops = [ @@ -1054,8 +1054,11 @@ def CreateConvOperator3x(manifest: Manifest, log_debug_line(f'conv_kind: {conv_kind}', log_indent_level) for triple in dims_and_alignments: - spatial_dimensionality = None # to be determined by loop below + assert(isinstance(triple, tuple) or isinstance(triple, list)) assert(len(triple) == 3) + + spatial_dimensionality = None # to be determined by loop below + for entry in triple: # [A, B, C] assert(len(entry) == 2) [dim, alignment] = entry @@ -6631,85 +6634,352 @@ def GenerateSM90_Conv3x(manifest, cuda_version, minimum_compute_capability = 90 maximum_compute_capability = 90 - spatial_dims = [2, 3] + spatial_dims = (2, 3) + + # This function only generates kernels that use TMA. + byte_alignment_required_by_tma = 16 + tma_byte_alignments = { + 'A': byte_alignment_required_by_tma, + 'B': byte_alignment_required_by_tma, + 'C': byte_alignment_required_by_tma, + } + + # For tuples of one element, the element needs to end with comma. + all_byte_alignments = ( + tma_byte_alignments, + ) + + # MMA shapes (MMA_M, MMA_N, MMA_K): + # + # Different hardware MMA instructions may have different MMA shapes. + # This function may generate kernels with different MMA shapes for + # different data types, either because the hardware only supports + # certain shapes for certain types, or for performance reasons + # (CUTLASS doesn't need to generate all valid kernels for the + # profiler library, just the best-performing ones). + # + # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) + # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, + # where 4, the "number of MMA instructions per tile," is determined + # through some combination of modeling and experiment. + # + # For performance on sm90, generally CUTLASS generates 64x128 + # instead of 128x64. + mma_64x64x16 = ( 64, 64, 16) + mma_64x64x8 = ( 64, 64, 8) + + num_mma_per_tile = 4 + + # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, + # but not included, because they tend not to perform as well. + cluster_shapes = ( + (2, 1, 1), + (1, 2, 1), + ) + + fp16 = DataType.f16 + bf16 = DataType.bf16 + fp32 = DataType.f32 + s8 = DataType.s8 + s32 = DataType.s32 + + # When generating kernels, the usual way is to specify 4 types, + # (A, B, Acc, C/D). Tests instead have 5 types, + # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), + # where ElementCompute is also called 'epi_type', + # and corresponds to the type of epilogue activations. + # This script maps tests' 5 types to 4 types + # by making ElementCompute the same as ElementOut. + + fp16_fp32_fp16_fp32 = { + 'a_type': fp16, # ElementAct(ivation) + 'b_type': fp16, # ElementF(i)lt(er) + 'c_type': fp32, # ElementAcc + 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) + 'acc_type': fp16, # ElementAcc + 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) + } + fp16_fp32_fp32_fp32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + } + fp32_fp32_fp32_fp32 = { + 'a_type': fp32, + 'b_type': fp32, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + } + s8_s32_s32_s32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + } - def make_dims_and_alignments_triple(dim: int): - byte_alignment_required_by_tma = 16 - return ((dim, byte_alignment_required_by_tma), # A - (dim, byte_alignment_required_by_tma), # B - (dim, byte_alignment_required_by_tma)) # C - dims_and_alignments = [make_dims_and_alignments_triple(dim) for dim in spatial_dims] + # Other NVIDIA libraries may have the habit of specifying data types like this. + bf16bf16_bf16f32_f32 = { + 'a_type': bf16, + 'b_type': bf16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + } + f16f16_f16f16_f16 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp16, + 'epi_type': fp16, + } + f16f16_f16f32_f32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp32, + 'epi_type': fp32, + } + f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 + + i8i8_i8i32_f32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + } - def make_math_instruction(data_types: Tuple[DataType, DataType, DataType], - instruction_shape: Tuple[int, int, int]) -> MathInstruction: + # Each element in the outermost iterable is one combination of + # + # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) + # + # for which to generate a kernel. spatial_dimension is the spatial + # dimension of the convolution: either 1, 2, or 3. byte_alignments + # is a triple of required minimum byte alignments for A, B, and C. + # + # Note that itertools functions produce a single-pass generator. + # The code doesn't need a multipass iterable, but if one did, one + # could call `tuple` or `list` on the generator. + # + # While this happens to use the same cluster sizes for each element, + # the code doesn't require that. Different convolution kinds, data + # types, or mma sizes might have different optimal cluster sizes. + combinations_of_parameters = chain( + # The following are all the kernels exercised in the unit tests. + # Please try to keep in sync with the unit tests. + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + s8_s32_s32_s32, + ), + all_byte_alignments, + ( + mma_64x64x16, + ), + cluster_shapes + ), + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp32_fp32_fp32_fp32, + ), + all_byte_alignments, + ( + mma_64x64x8, + ), + cluster_shapes + ), + product( + ( + ConvKind.Dgrad, + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + ), + all_byte_alignments, + ( + mma_64x64x16, + ), + cluster_shapes + ), + # Kernels not necessarily in the unit tests, but used elsewhere + # and thus useful to have generated for profiling. They may + # duplicate kernels above. All of them are 2-D. In general, + # CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the + # hardware permits 128 x 64. + ( + # Fprop + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 192, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 96, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 96, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)), + # + # f32f32_tf32f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, tma_byte_alignments, (256, 96, 8), (2, 1, 1)), + # + # i8i8_i8i32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (128, 256, 32), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, tma_byte_alignments, (256, 128, 32), (2, 1, 1)), + # + # Dgrad + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (128, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, tma_byte_alignments, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, tma_byte_alignments, (256, 128, 16), (2, 1, 1)), + ), + ) + + # SM >= 90 kernels don't actually use warp_count, but the + # TileDescription class needs it. The 4 in the default + # warp_count has nothing to do with num_mma_per_tile. + warp_count = [4, 1, 1] + + stages = 0 # zero means "deduce the number of stages automatically" + + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90 + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + schedule_pairs = ( + (mainloop_schedule, epilogue_schedule), + ) + tile_schedulers = ( + TileSchedulerType.Default, # -> void + ) + + def make_math_instruction(data_types: Dict[str, DataType], + mma_shape: Tuple[int, int, int]) -> MathInstruction: default_opcode = OpcodeClass.TensorOp default_math_op = MathOperation.multiply_add - [A_data_type, B_data_type, C_data_type] = data_types return MathInstruction( - instruction_shape, - A_data_type, B_data_type, C_data_type, + mma_shape, + data_types['a_type'], data_types['b_type'], data_types['c_type'], default_opcode, default_math_op ) - data_types_and_instruction_shapes = [ - ((DataType.f16, DataType.f16, DataType.f16), (64, 64, 16)), - ((DataType.f16, DataType.f16, DataType.f32), (64, 64, 16)), - ((DataType.bf16, DataType.bf16, DataType.f32), (64, 64, 16)), - ] - math_instructions = map(lambda x: make_math_instruction(*x), - data_types_and_instruction_shapes) - cluster_shapes = [ - [2, 1, 1], - [1, 1, 1], - ] - conv_kinds = [ - ConvKind.Fprop, - ConvKind.Dgrad - ] - mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90 - stages = 0 # zero means "deduce the number of stages automatically" - - # tile_descriptions is a 2-level list. - # Each inner list is for each cluster shape. - for math_inst in math_instructions: - tile_descriptions = [] - for cluster_shape in cluster_shapes: - tile_shape = [ - math_inst.instruction_shape[0], - math_inst.instruction_shape[1], - math_inst.instruction_shape[2] * 4 - ] - warp_count = [4, 1, 1] - tile_description = TileDescription( - tile_shape, stages, warp_count, math_inst, - minimum_compute_capability, maximum_compute_capability, - cluster_shape) - tile_descriptions.append(tile_description) - - # It's typical to get the data types from the math instruction. - data_type = { - "a_type" : math_inst.element_a, - "b_type" : math_inst.element_b, - "c_type" : math_inst.element_accumulator, - "d_type" : math_inst.element_accumulator, - "acc_type" : math_inst.element_accumulator, - "epi_type" : math_inst.element_accumulator - } - - for conv_kind in conv_kinds: - epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized - schedule_pairs = [ - (mainloop_schedule, epilogue_schedule) - ] - CreateConvOperator3x(manifest, - dims_and_alignments = dims_and_alignments, - tile_descriptions = tile_descriptions, - data_types = data_type, - schedule_pairs = schedule_pairs, - tile_schedulers = [TileSchedulerType.Default], # -> void - conv_kind = conv_kind, - log_indent_level = log_indent_level) + for (conv_kind, spatial_dim, data_types, byte_alignments, mma_shape, cluster_shape) in combinations_of_parameters: + math_inst = make_math_instruction(data_types, mma_shape) + tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2]) + tile_description = TileDescription(tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, cluster_shape) + assert(isinstance(spatial_dim, int)) + assert(isinstance(byte_alignments, dict)) + dims_and_alignments = ( + ( + (spatial_dim, byte_alignments['A']), + (spatial_dim, byte_alignments['B']), + (spatial_dim, byte_alignments['C']), + ), + ) + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = [tile_description], + data_types = data_types, + schedule_pairs = schedule_pairs, + tile_schedulers = tile_schedulers, + conv_kind = conv_kind, + log_indent_level = log_indent_level) def GenerateSM90(manifest, cuda_version): GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) @@ -6738,8 +7008,8 @@ def GenerateSM90(manifest, cuda_version): def numeric_log_level(log_level: str) -> int: """ - Converts the string identifier of the log level into the numeric identifier used - in setting the log level + Converts the string identifier of the log level + into the numeric identifier used in setting the log level. :param x: string representation of log level (e.g., 'INFO', 'DEBUG') :type x: str @@ -6762,8 +7032,18 @@ def define_parser(): parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") parser.add_argument("--architectures", default='53;60;61;70;75;80;90', help="Target compute architectures") - parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.') - parser.add_argument("--ignore-kernels", default='', help='Comma delimited list of kernels to exclude from build.') + parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' + + 'Specifying this as \"all\" includes ALL the kernels, ' + + 'while not specifying this includes only the default set of kernels.') + parser.add_argument("--ignore-kernels", default='', help='Comma-delimited list of kernels ' + + 'to exclude from build. For backwards compatibility reasons, ' + + 'this option only takes effect if --kernels is set to a nonempty value.') + parser.add_argument("--exclude-kernels", default='', help='Comma-delimited list of kernels ' + + 'to exclude from build. In contrast to --ignore-kernels, ' + + 'this option always takes effect, ' + + 'whether or not --kernels is set to a nonempty value. ' + + 'It also can exclude kernels from the filter file ' + + '(see --kernel-filter-file option below).') parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index aed0df3283..b31d8dd23e 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -506,6 +506,7 @@ def __init__(self, args = None): self.operations_enabled = [] self.selected_kernels = [] self.ignore_kernel_names = [] + self.exclude_kernel_names = [] self.compute_capabilities = [50,] self.curr_build_dir = '.' self.filter_by_cc = True @@ -546,6 +547,7 @@ def __init__(self, args = None): self.kernel_names = [x for x in args.kernels.split(',') if x != ''] self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] if args.kernel_filter_file is None: self.kernel_filter_list = [] @@ -612,41 +614,54 @@ def filter(self, operation): if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: return False + name = operation.procedural_name() + # eliminate duplicates - if operation.procedural_name() in self.operations_by_name.keys(): + if name in self.operations_by_name.keys(): return False # Filter based on list of valid substrings if len(self.kernel_names): - name = operation.procedural_name() enabled = False # compare against the include list for name_substr in self.kernel_names: if self._filter_string_matches(name_substr, name): - _LOGGER.debug("Kernel {kernel} included due to filter string '{filt}'.".format( - kernel = operation.procedural_name(), - filt = name_substr)) + _LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.") enabled = True break + else: + _LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.") # compare against the exclude list for name_substr in self.ignore_kernel_names: if self._filter_string_matches(name_substr, name): - _LOGGER.debug("Kernel {kernel} ignored due to filter string '{filt}'.".format( - kernel = operation.procedural_name(), - filt = name_substr)) + _LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.") enabled = False break - - if len(self.kernel_filter_list) > 0: - if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list): - _LOGGER.debug("Kernel {kernel} matched via kernel filter file.".format(kernel = operation.procedural_name())) - enabled = True else: - _LOGGER.debug("Kernel {kernel} culled due to no match in kernel filter file.".format(kernel = operation.procedural_name())) - enabled = False + _LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.") + if len(self.kernel_filter_list) > 0: + if self.filter_out_kernels(name, self.kernel_filter_list): + _LOGGER.debug(f"Kernel {name} matched via kernel filter file.") + enabled = True + else: + _LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.") + enabled = False + + # CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect + # if CUTLASS_LIBRARY_KERNELS was specified. + # Changing that would break backwards compatibility. + # Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS, + # that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified. + for name_substr in self.exclude_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.") # TODO: filter based on compute data type return enabled diff --git a/python/setup_library.py b/python/setup_library.py index b6f4dabf7b..870840324c 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='cutlass_library', - version='3.5.0', + version='3.5.1', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 680957da1a..24e30e9bc6 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='pycute', - version='3.5.0', + version='3.5.1', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a101902007..eb802d80a6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -33,3 +33,7 @@ else() add_custom_target(test_unit) endif() +if (CUTLASS_ENABLE_SELF_CONTAINED_INCLUDES_CHECK) + add_subdirectory(self_contained_includes) +endif() + diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt new file mode 100644 index 0000000000..6425b5cdbc --- /dev/null +++ b/test/self_contained_includes/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# The purpose of this target is to check if the following header files are self-contained, +# i.e. they can be included in a source file without needing to include other headers before it. + +set(header_files_to_check + cutlass/gemm/kernel/default_gemm.h + cutlass/gemm/kernel/default_gemm_complex.h + cutlass/gemm/kernel/gemm_universal_decl.h + # cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp + + cute/tensor_impl.hpp +) + +# for each header in _header_files: +# create a .cu file with the same name as the header's path, except with / replaced with % +# have the .cu file include that header +set(_gen_source_files "") +foreach(header_file ${header_files_to_check}) + string(REPLACE "/" "%" header_file_esc ${header_file}) + + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/${header_file_esc}.cu" + "#include <${header_file}>") + + list(APPEND _gen_source_files + "${CMAKE_CURRENT_BINARY_DIR}/${header_file_esc}.cu") +endforeach() + +# build all generated .cu files into a single library +cutlass_add_library(test_self_contained_includes MODULE ${_gen_source_files}) + diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu index 9f2178ab38..3443551400 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu @@ -60,7 +60,7 @@ TEST(SM89_Device_Conv2d_Fprop_Analytic_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc_ using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e4m3_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< @@ -104,7 +104,7 @@ TEST(SM89_Device_Conv2d_Fprop_Analytic_ImplicitGemm_fe5m2nhwc_fe4m3nhwc_fe4m3nhw using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e4m3_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< @@ -148,7 +148,7 @@ TEST(SM89_Device_Conv2d_Fprop_Analytic_ImplicitGemm_fe5m2nhwc_fe4m3nhwc_fe5m2nhw using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e5m2_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< @@ -192,7 +192,7 @@ TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e4m3_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< @@ -236,7 +236,7 @@ TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e4m3_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< @@ -280,7 +280,7 @@ TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e4m3_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< @@ -324,7 +324,7 @@ TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc using ElementB = cutlass::float_e4m3_t; using ElementOutput = cutlass::float_e4m3_t; using ElementAuxOutput = ElementOutput; - using ElementAccumulator = float;; + using ElementAccumulator = float; static int const kStages = 3; using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu index 944af8b6ee..c3015bfda1 100644 --- a/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu @@ -104,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -static void Conv2dFpropSM80TestResidaulBlock() { +static void Conv2dFpropSM80TestResidualBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -162,7 +162,7 @@ static void Conv2dFpropSM80TestResidaulBlock() { TEST(SM80_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - Conv2dFpropSM80TestResidaulBlock(); + Conv2dFpropSM80TestResidualBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu index 5a4fa085c9..6e235c7905 100644 --- a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm70.cu @@ -60,7 +60,7 @@ template < template class UnaryOp, bool TestSplitK = false > -void TestResidaulBlock() { +void Conv2dFpropSM70TestResidualBlock() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; @@ -117,7 +117,7 @@ void TestResidaulBlock() { TEST(SM70_Device_Conv2d_Fprop_With_Residual_Block_Plus_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128_32x2_64x64x32) { // Resnet - TestResidaulBlock(); + Conv2dFpropSM70TestResidualBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu index d3805e7c52..f2d3e584d8 100644 --- a/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu @@ -103,7 +103,7 @@ template < template class UnaryOp, bool TestSplitK = true > -void TestResidaulBlock() { +void Conv2dFpropSM75TestResidualBlock() { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; @@ -160,14 +160,14 @@ void TestResidaulBlock() { TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128_32x2_64x64x32) { // Resnet - TestResidaulBlock(); + Conv2dFpropSM75TestResidualBlock(); } TEST(SM75_Device_Conv2d_Fprop_With_Residual_Block_Multiply_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128_32x2_64x64x32) { // EfficientNet V2 // Do not run split-K tests since the activation op is not Identity. - TestResidaulBlock(); + Conv2dFpropSM75TestResidualBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu b/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu index bc0dee0e03..a3461f8e58 100644 --- a/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu @@ -104,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -static void Conv3dFpropSM80TestResidaulBlock() { +static void Conv3dFpropSM80TestResidualBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -162,7 +162,7 @@ static void Conv3dFpropSM80TestResidaulBlock() { TEST(SM80_Device_Conv3d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - Conv3dFpropSM80TestResidaulBlock(); + Conv3dFpropSM80TestResidualBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/test/unit/conv/device/conv3d_with_broadcast_testbed.h index cc7c06f7da..437dbd30bd 100644 --- a/test/unit/conv/device/conv3d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -204,17 +204,29 @@ class TestbedConv3dWithBroadcast { } void initialize( - cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) { + cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) { - tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); - tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); - tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + // to make the layout of tensors a little bit bigger than the problem size + cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64); + + cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size); + + if (non_packed_test) { + tensor_A_extent += stride_increment; + tensor_C_extent += stride_increment; + } + + tensor_A.resize(tensor_A_extent); + tensor_B.resize(tensor_B_extent); + tensor_C.resize(tensor_C_extent); + tensor_C_reference.resize(tensor_C_extent); + tensor_Z_computed.resize(tensor_C_extent); + tensor_Z_reference.resize(tensor_C_extent); tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); - tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(tensor_C_extent); tensor_Broadcast.resize({ 1, 1, @@ -282,6 +294,7 @@ class TestbedConv3dWithBroadcast { bool run( cutlass::conv::Conv3dProblemSize const &problem_size, cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + bool non_packed_test = false, ElementCompute alpha = ElementCompute(1), ElementCompute beta = ElementCompute(1)) { @@ -300,7 +313,7 @@ class TestbedConv3dWithBroadcast { << std::endl; #endif - initialize(problem_size); + initialize(problem_size, non_packed_test); // configure the operator Conv3d conv3d_op; @@ -479,6 +492,7 @@ class TestbedConv3dWithBroadcast { << problem_size.dilation_h << "x" << problem_size.dilation_w << "_" << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_") << Conv3d::ThreadblockShape::kM << "x" << Conv3d::ThreadblockShape::kN << "x" << Conv3d::ThreadblockShape::kK << "_" @@ -521,7 +535,8 @@ template bool TestAllConv3dWithBroadcast( const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(), - const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector()) { + const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(), + bool non_packed_test = false) { bool passed = true; @@ -595,17 +610,17 @@ bool TestAllConv3dWithBroadcast( // test mode = xcross passed = testbed.run( conv_problem, - cutlass::conv::SplitKMode::kSerial); - + cutlass::conv::SplitKMode::kSerial, non_packed_test); + if (!passed) { return false; } - + // test mode = convolution passed = testbed.run( conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); - + cutlass::conv::SplitKMode::kSerial, non_packed_test); + if (!passed) { return false; } @@ -651,6 +666,7 @@ bool TestAllConv3dWithBroadcast( passed = testbed.run( conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), split_k_mode, + false,/*non_packed_test*/ cutlass::from_real(alpha), cutlass::from_real(beta)); @@ -669,7 +685,8 @@ template , bool AddBroadcastFirst = false> bool TestSpecificConv3dWithBroadcast( - const Conv3dProblemVector & problem_sizes) { + const Conv3dProblemVector & problem_sizes, + bool non_packed_test = false) { bool passed = true; @@ -686,19 +703,19 @@ bool TestSpecificConv3dWithBroadcast( // Test // - // test mode = xcross + // test mode = xcross, non_packed_test = false passed = testbed.run( conv_problem, - cutlass::conv::SplitKMode::kSerial); + cutlass::conv::SplitKMode::kSerial, non_packed_test); if (!passed) { return false; } - // test mode = convolution + // test mode = convolution, non_packed_test = false passed = testbed.run( conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), - cutlass::conv::SplitKMode::kSerial); + cutlass::conv::SplitKMode::kSerial, non_packed_test); if (!passed) { return false; diff --git a/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu b/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu index 7872f8a466..bfb85d5126 100644 --- a/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu @@ -104,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -static void Deconv2dSM80TestResidaulBlock() { +static void Deconv2dSM80TestResidualBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -163,7 +163,7 @@ static void Deconv2dSM80TestResidaulBlock() { TEST(SM80_Device_Deconv2d_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - Deconv2dSM80TestResidaulBlock(); + Deconv2dSM80TestResidualBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu b/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu index e0d0171f7f..09817d71a0 100644 --- a/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu @@ -103,7 +103,7 @@ template < template class UnaryOp, bool TestSplitK = true > -static void Deconv3dSM80TestResidaulBlock() { +static void Deconv3dSM80TestResidualBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -162,7 +162,7 @@ static void Deconv3dSM80TestResidaulBlock() { TEST(SM80_Device_Deconv3d_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - Deconv3dSM80TestResidaulBlock(); + Deconv3dSM80TestResidualBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device_3x/conv_problem_sizes.hpp b/test/unit/conv/device_3x/conv_problem_sizes.hpp index cecff9f64c..d7dd062321 100644 --- a/test/unit/conv/device_3x/conv_problem_sizes.hpp +++ b/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -42,7 +42,7 @@ namespace test::conv::device { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template std::vector> inline get_conv_problem_vector(); @@ -297,7 +297,7 @@ get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride problem_shapes.push_back({ cutlass::conv::Mode::kCrossCorrelation, - {2, 8, 8, 64}, + {2, 7, 7, 64}, {256, 2, 5, 64}, {1, 1}, {0, 0}, @@ -319,7 +319,7 @@ get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation problem_shapes.push_back({ cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 16, 64}, + {2, 16, 15, 64}, {256, 2, 5, 64}, {1, 1}, {0, 0}, @@ -658,7 +658,7 @@ get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride problem_shapes.push_back({ cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 16, 32}, + {2, 15, 16, 32}, {256, 2, 5, 32}, {1, 1}, {0, 0}, @@ -680,7 +680,7 @@ get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation problem_shapes.push_back({ cutlass::conv::Mode::kCrossCorrelation, - {2, 16, 16, 32}, + {2, 16, 15, 32}, {256, 2, 5, 32}, {1, 1}, {0, 0}, @@ -688,6 +688,28 @@ get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { {2, 3}, 1 }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 64, 16, 128}, // nhwc + {640, 1, 1, 128}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 65, 16, 128}, // nhwc + {640, 1, 1, 128}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); return problem_shapes; } @@ -751,17 +773,39 @@ get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() { {2, 2, 3}, 1 }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 64, 16, 128}, // ndhwc + {640, 1, 1, 1, 128}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 65, 16, 128}, // ndhwc + {640, 1, 1, 1, 128}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); return problem_shapes; } ///////////////////////////////////////////////////////////////////////////////////////////////// -// Dgrad +// Unit Stride Dgrad ///////////////////////////////////////////////////////////////////////////////////////////////// // Specialization for 1D dgrad problems template<> std::vector> inline -get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad>() { +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() { using ProblemShape = cutlass::conv::ConvProblemShape; std::vector problem_shapes; problem_shapes.push_back({ @@ -884,7 +928,7 @@ get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad>() { // Specialization for 2D dgrad problems template<> std::vector> inline -get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad>() { +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() { using ProblemShape = cutlass::conv::ConvProblemShape; std::vector problem_shapes; problem_shapes.push_back({ @@ -1007,7 +1051,7 @@ get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad>() { // Specialization for 3D dgrad problems template<> std::vector> inline -get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad>() { +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() { using ProblemShape = cutlass::conv::ConvProblemShape; std::vector problem_shapes; // Filter-K = 16 for predication @@ -1082,6 +1126,134 @@ get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad>() { return problem_shapes; } +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // non-packed input/output strides. + // stride divides dilation + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {1}, // padding upper (pad_w) + {2}, // stride (stride_w) + {4}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // dilation divides stride + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {2}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride dilation dont divide + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {2}, // padding upper (pad_w) + {2}, // stride (stride_w) + {3}, // dilation (dilation_w) + 1 // group + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 stride divides dilation + // mode 1 dilation divides stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {2, 4}, + {4, 2}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 dilation divides stride + // mode 1 stride divides dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {4, 2}, + {2, 4}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // stride dilation dont divide + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {3, 2}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {2, 4, 2}, + {4, 2, 3}, + 1 + }); + return problem_shapes; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::test diff --git a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu index 7768c7f660..47d510b864 100644 --- a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu +++ b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f16.cu @@ -93,6 +93,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -137,6 +138,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -181,6 +183,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -225,6 +228,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,6 +277,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -317,6 +322,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -361,6 +367,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -405,6 +412,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f16 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu index b8c86f4ddc..7faffa266b 100644 --- a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu +++ b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -93,6 +93,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -137,6 +138,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -181,6 +183,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -225,6 +228,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,6 +277,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -317,6 +322,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -361,6 +367,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -405,6 +412,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32 using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu index 9c1394064e..6e397cfbd9 100644 --- a/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu +++ b/test/unit/conv/device_3x/fprop/sm90_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu @@ -93,6 +93,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -137,6 +138,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -181,6 +183,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -225,6 +228,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,6 +277,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -317,6 +322,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -361,6 +367,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } // @@ -405,6 +412,7 @@ TEST(SM90_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f using Conv = cutlass::conv::device::ConvUniversalAdapter; EXPECT_TRUE(test::conv::device::TestAllConv()); + EXPECT_TRUE(test::conv::device::TestAllConv(/*alpha=*/1.0, /*beta=*/1.0)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index e22c2cfffc..0545e78346 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -40,6 +40,7 @@ #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/conv/convolution.h" #include "cutlass/conv/convnd_problem_shape.hpp" +#include "../test/unit/gemm/device/gemm_testbed_3x.hpp" #include "thrust/universal_vector.h" #include "cutlass/util/distribution.h" @@ -64,6 +65,7 @@ namespace test::conv::device { ///////////////////////////////////////////////////////////////////////////////////////////////// + // Initializes a flat device buffer template static void @@ -104,7 +106,39 @@ initialize_values( } ///////////////////////////////////////////////////////////////////////////////////////////////// +// utils for sparse or dense conv parameters + template +struct DenseConvParams { + // Default Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + + // get the default arguments without sparse data + auto get_mainloop_arguments( + ProblemShape const& problem_shape, + thrust::universal_vector& tensor_A, + thrust::universal_vector& tensor_B + ) { + auto args = typename Conv::ConvKernel::MainloopArguments { + problem_shape, + tensor_A.data().get(), + tensor_B.data().get(), + }; + return args; + } +}; + +template +struct SparseConvParams { +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template struct ConvTestbed { // Kernel data types using ElementA = typename Conv::ConvKernel::ElementA; @@ -114,6 +148,11 @@ struct ConvTestbed { using ElementD = typename Conv::ConvKernel::ElementD; using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator; + // ConvTest for sparse kernel + static constexpr bool isSparseEnabled = isSparseEnabled_; + using ConvParams = cute::conditional_t, DenseConvParams>; + ConvParams params; + // // FusionOperation derived types/queries // @@ -134,6 +173,8 @@ struct ConvTestbed { static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && !cute::is_same_v; + static constexpr bool DisableSource = cute::is_void_v; + using StrideC = typename Conv::ConvKernel::StrideC; using StrideD = typename Conv::ConvKernel::StrideD; using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; @@ -141,6 +182,10 @@ struct ConvTestbed { static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; using ProblemShape = cutlass::conv::ConvProblemShape; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; using Schedule = typename Conv::DispatchPolicy::Schedule; /// Initialization @@ -148,6 +193,7 @@ struct ConvTestbed { cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform; cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform; cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros uint64_t seed = 6090; float epsilon = 0.0f; int split_p_slices = 1; @@ -160,7 +206,8 @@ struct ConvTestbed { thrust::universal_vector tensor_alpha; thrust::universal_vector tensor_beta; - void initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { + // Return true on success, else false + bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { tensor_A.resize(sizeof(ElementA) * problem_shape.size_A()); tensor_B.resize(sizeof(ElementB) * problem_shape.size_B()); tensor_C.resize(sizeof(ElementC) * problem_shape.size_C()); @@ -171,6 +218,12 @@ struct ConvTestbed { initialize_values(tensor_B, init_B, seed * 11); initialize_values(tensor_C, init_C, seed * 17); initialize_values(tensor_bias, init_bias, seed * 19); + bool flag = true; + if constexpr (isSparseEnabled) { + flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); + } + + return flag; } // Determine SMEM requirements and waive if not satisfied @@ -190,11 +243,16 @@ struct ConvTestbed { return max_smem_size >= Conv::ConvKernel::SharedStorageSize; } - /// Executes one test + // Executes one test bool run( ProblemShape const& problem_shape, ElementScalar alpha = ElementScalar(1), ElementScalar beta = ElementScalar(0) + , + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, + Splits splits = Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic ) { // Waive test if insufficient CUDA device @@ -205,7 +263,12 @@ struct ConvTestbed { return true; } - initialize(problem_shape); + bool ret = initialize(problem_shape); + + if (!ret) { + std::cerr << "initialize failed for the given problem_shape: \n"; + return false; + } cutlass::KernelHardwareInfo hw_info; cudaGetDevice(&hw_info.device_id); @@ -230,20 +293,27 @@ struct ConvTestbed { cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; }); } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{}; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + + auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B); + auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments { + {}, + tensor_C.data().get(), + stride_C, + tensor_D_computed.data().get(), + stride_D, + }; + auto args = typename Conv::Arguments { - { - problem_shape, - tensor_A.data().get(), - tensor_B.data().get(), - }, // MainloopArguments - { - {}, - tensor_C.data().get(), - stride_C, - tensor_D_computed.data().get(), - stride_D, - }, // EpilogueArguments + mainloop_args, // MainloopArguments + epilogue_args, // EpilogueArguments hw_info, scheduler_args }; @@ -462,6 +532,8 @@ struct ConvTestbed { for (size_t i = 0; i < size_t(size(reference)); ++i) { if (reference(i) != computed(i)) { passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); break; } } @@ -475,6 +547,8 @@ struct ConvTestbed { if (std::isnan(abs_error) || std::isnan(rel_error) || std::min(abs_error, rel_error) > epsilon) { passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); break; } } @@ -488,18 +562,20 @@ struct ConvTestbed { cute::print("\n"); for (size_t i = 0; i < size_t(size(A)); ++i) { - printf("[%ld]: A = %f\n", i, float(A(i))); + printf("[%llu]: A = %f\n", static_cast(i), float(A(i))); } for (size_t i = 0; i < size_t(size(B)); ++i) { - printf("[%ld]: B = %f\n", i, float(B(i))); + printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); } if constexpr (IsBiasEnabled) { for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { - printf("[%ld]: bias = %f\n", i, float(tensor_bias(i))); + printf("[%llu]: bias = %f\n", static_cast(i), + float(tensor_bias(i))); } } for (size_t i = 0; i < size_t(size(reference)); ++i) { - printf("[%ld]: ref = %f, computed = %f\n", i, float(reference(i)), float(computed(i))); + printf("[%llu]: ref = %f, computed = %f\n", static_cast(i), + float(reference(i)), float(computed(i))); } } #endif @@ -509,30 +585,56 @@ struct ConvTestbed { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f) { +template +bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f + ) { using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; bool passed = true; ConvTestbed testbed; testbed.epsilon = epsilon; auto problem_vector = get_conv_problem_vector< - Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp>(); + Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>(); + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } for (auto conv_problem : problem_vector) { #if CUTLASS_DEBUG_TRACE_LEVEL > 0 - print(conv_problem); + print(conv_problem); #endif - - passed = testbed.run( - conv_problem, - cutlass::from_real(alpha), - cutlass::from_real(beta)); - - if (!passed) { - printf("Failed test for "); print(conv_problem); - return false; - } + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(Splits{2}); + } + for (auto splits : problem_splits) { + + passed = testbed.run( + conv_problem, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ,RasterOrderOptions::Heuristic, // raster_order + MaxSwizzleSize(1), + splits, + decomp_mode + ); + if (!passed) { + printf("Failed test for "); print(conv_problem); + return false; + } + } // splits + } // decomposition_mode } return passed; diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index d0e10a7b47..9c68d4af8b 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -42,6 +42,7 @@ cutlass_test_unit_add_executable( tensor_view.cu matrix_coord.cu numeric_conversion.cu + numeric_conversion_subbyte.cu fast_numeric_conversion.cu functional.cu ) diff --git a/test/unit/core/complex.cu b/test/unit/core/complex.cu index 07455a25bf..880af5868c 100644 --- a/test/unit/core/complex.cu +++ b/test/unit/core/complex.cu @@ -32,14 +32,16 @@ \brief CUTLASS host-device template for complex numbers supporting all CUTLASS numeric types. */ -// Standard Library's std::complex used for reference checking #include +#include #include "../common/cutlass_unit_test.h" #include "cutlass/complex.h" #include "cutlass/constants.h" #include "cutlass/numeric_conversion.h" +#include "cutlass/tfloat32.h" +#include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -93,7 +95,6 @@ TEST(complex, f16_to_f32_conversion) { //////////////////////////////////////////////////////////////////////////////////////////////////// TEST(complex, exp_f32) { - cutlass::complex Z[] = { {1, 1}, {2 , cutlass::constants::pi()/2.0f }, @@ -126,6 +127,376 @@ TEST(complex, exp_f32) { } } +TEST(complex, absolute_value_real_and_imag) { + { + cutlass::complex z_d{3.0, 4.0}; + + auto abs_d = cutlass::abs(z_d); + static_assert(std::is_same_v); + EXPECT_EQ(abs_d, 5.0); + + auto real_d = cutlass::real(z_d); + static_assert(std::is_same_v); + EXPECT_EQ(real_d, 3.0); + + auto imag_d = cutlass::imag(z_d); + static_assert(std::is_same_v); + EXPECT_EQ(imag_d, 4.0); + } + + { + cutlass::complex z_f{3.0f, 4.0f}; + + auto abs_f = cutlass::abs(z_f); + static_assert(std::is_same_v); + EXPECT_EQ(abs_f, 5.0f); + + auto real_f = cutlass::real(z_f); + static_assert(std::is_same_v); + EXPECT_EQ(real_f, 3.0f); + + auto imag_f = cutlass::imag(z_f); + static_assert(std::is_same_v); + EXPECT_EQ(imag_f, 4.0f); + } + + { + cutlass::complex z_tf32{cutlass::tfloat32_t{3.0f}, cutlass::tfloat32_t{4.0f}}; + auto abs_tf32 = cutlass::abs(z_tf32); + static_assert(std::is_same_v); + EXPECT_EQ(abs_tf32, cutlass::tfloat32_t{5.0f}); + + auto real_tf32 = cutlass::real(z_tf32); + static_assert(std::is_same_v); + EXPECT_EQ(real_tf32, cutlass::tfloat32_t{3.0f}); + + auto imag_tf32 = cutlass::imag(z_tf32); + static_assert(std::is_same_v); + EXPECT_EQ(imag_tf32, cutlass::tfloat32_t{4.0f}); + } + + { + cutlass::complex z_i{3, 4}; + + // sqrt(int) isn't a valid overload, so cutlass::abs isn't tested. + auto real_i = cutlass::real(z_i); + static_assert(std::is_same_v); + EXPECT_EQ(real_i, 3); + + auto imag_i = cutlass::imag(z_i); + static_assert(std::is_same_v); + EXPECT_EQ(imag_i, 4); + } + + { + double x_d{3.0}; + + auto real_d = cutlass::real(x_d); + static_assert(std::is_same_v); + EXPECT_EQ(real_d, 3.0); + + auto imag_d = cutlass::imag(x_d); + static_assert(std::is_same_v); + EXPECT_EQ(imag_d, 0.0); + } + + { + float x_f{3.0f}; + + auto real_f = cutlass::real(x_f); + static_assert(std::is_same_v); + EXPECT_EQ(real_f, 3.0f); + + auto imag_f = cutlass::imag(x_f); + static_assert(std::is_same_v); + EXPECT_EQ(imag_f, 0.0f); + } + + { + cutlass::tfloat32_t x_tf32{3.0f}; + + auto real_tf32 = cutlass::real(x_tf32); + static_assert(std::is_same_v); + EXPECT_EQ(real_tf32, cutlass::tfloat32_t{3.0f}); + + auto imag_tf32 = cutlass::imag(x_tf32); + static_assert(std::is_same_v); + EXPECT_EQ(imag_tf32, cutlass::tfloat32_t{0.0f}); + } + + { + int x_i{3}; + + auto real_i = cutlass::real(x_i); + static_assert(std::is_same_v); + EXPECT_EQ(real_i, 3); + + auto imag_i = cutlass::imag(x_i); + static_assert(std::is_same_v); + EXPECT_EQ(imag_i, 0); + } +} + +// FakeReal and FakeComplex test whether cutlass::real and +// cutlass::imag correctly handle user-defined non-complex +// and complex number types. +namespace test { + +// These classes have no conversions to or from arithmetic types, so +// that the test can ensure that the implementation does not silently +// convert to, say, float or int. +class FakeReal { +public: + // cutlass::imag must be able to value-construct its noncomplex input. + FakeReal() = default; + + static CUTLASS_HOST_DEVICE FakeReal make_FakeReal(int val) { + return FakeReal{val}; + } + + friend CUTLASS_HOST_DEVICE bool operator==(FakeReal lhs, FakeReal rhs) { + return lhs.value_ == rhs.value_; + } + + friend CUTLASS_HOST_DEVICE FakeReal operator-(FakeReal const& x) { + return make_FakeReal(-x.value_); + } + +private: + CUTLASS_HOST_DEVICE FakeReal(int val) : value_(val) {} + int value_ = 0; +}; + +class FakeComplex { +public: + static CUTLASS_HOST_DEVICE FakeComplex + make_FakeComplex(FakeReal re, FakeReal im) { + return FakeComplex{re, im}; + } + + // Existence of member functions real and imag tell + // CUTLASS that FakeComplex is a complex number type. + CUTLASS_HOST_DEVICE FakeReal real() const { return real_; } + CUTLASS_HOST_DEVICE FakeReal imag() const { return imag_; } + + friend CUTLASS_HOST_DEVICE bool operator==(FakeComplex lhs, FakeComplex rhs) { + return lhs.real_ == rhs.real_ && lhs.imag_ == rhs.imag_; + } + +private: + CUTLASS_HOST_DEVICE FakeComplex(FakeReal re, FakeReal im) + : real_(re), imag_(im) + {} + + FakeReal real_{}; + FakeReal imag_{}; +}; + +CUTLASS_HOST_DEVICE FakeComplex conj(FakeComplex const& z) { + return FakeComplex::make_FakeComplex(z.real(), -z.imag()); +} + +// Variant of FakeComplex that has a hidden friend conj instead of a +// nonmember conj defined outside the class. +class FakeComplexWithHiddenFriendConj { +public: + static CUTLASS_HOST_DEVICE FakeComplexWithHiddenFriendConj + make_FakeComplexWithHiddenFriendConj(FakeReal re, FakeReal im) { + return FakeComplexWithHiddenFriendConj{re, im}; + } + + CUTLASS_HOST_DEVICE FakeReal real() const { return real_; } + CUTLASS_HOST_DEVICE FakeReal imag() const { return imag_; } + + friend CUTLASS_HOST_DEVICE bool + operator==(FakeComplexWithHiddenFriendConj lhs, + FakeComplexWithHiddenFriendConj rhs) + { + return lhs.real_ == rhs.real_ && lhs.imag_ == rhs.imag_; + } + + friend CUTLASS_HOST_DEVICE FakeComplexWithHiddenFriendConj + conj(FakeComplexWithHiddenFriendConj const& z) { + return FakeComplexWithHiddenFriendConj::make_FakeComplexWithHiddenFriendConj(z.real(), -z.imag()); + } + +private: + CUTLASS_HOST_DEVICE + FakeComplexWithHiddenFriendConj(FakeReal re, FakeReal im) + : real_(re), imag_(im) + {} + + FakeReal real_{}; + FakeReal imag_{}; +}; + +} // namespace test + +TEST(complex, real_and_imag_with_custom_types) { + using test::FakeReal; + using test::FakeComplex; + + { + FakeReal x = FakeReal::make_FakeReal(42); + auto x_r = cutlass::real(x); + static_assert(std::is_same_v); + EXPECT_EQ(x_r, FakeReal::make_FakeReal(42)); + auto x_i = cutlass::imag(x); + static_assert(std::is_same_v); + EXPECT_EQ(x_i, FakeReal::make_FakeReal(0)); + } + { + FakeComplex z = FakeComplex::make_FakeComplex( + FakeReal::make_FakeReal(3), FakeReal::make_FakeReal(4)); + auto z_r = cutlass::real(z); + static_assert(std::is_same_v); + EXPECT_EQ(z_r, FakeReal::make_FakeReal(3)); + auto z_i = cutlass::imag(z); + static_assert(std::is_same_v); + EXPECT_EQ(z_i, FakeReal::make_FakeReal(4)); + } +} + +namespace test { + +template +void conj_tester(T z, T z_c_expected, const char type_name[]) { + // Use cutlass::conj just like std::swap (the "std::swap two-step"). + using cutlass::conj; + auto z_c = conj(z); + static_assert(std::is_same_v); + constexpr bool is_cuComplex = std::is_same_v || + std::is_same_v; + if constexpr (is_cuComplex) { + EXPECT_EQ(z_c.x, z_c_expected.x); + EXPECT_EQ(z_c.y, z_c_expected.y) << "conj failed for type " << type_name; + } + else { + EXPECT_EQ(z_c, z_c_expected) << "conj failed for type " << type_name; + } + + auto z_c2 = cutlass::conjugate{}(z); + static_assert(std::is_same_v); + if constexpr (is_cuComplex) { + // cuFloatComplex and cuDoubleComplex don't report conj(z) as + // being well-formed, probably because they are type aliases of + // some kind. cutlass::conj works fine, though! + static_assert(! cutlass::platform::is_arithmetic_v && + (cutlass::detail::has_unqualified_conj_v || + cutlass::detail::has_cutlass_conj_v)); + + EXPECT_EQ(z_c2.x, z_c_expected.x); + EXPECT_EQ(z_c2.y, z_c_expected.y) + << "conjugate failed for type " << type_name; + } + else { + EXPECT_EQ(z_c2, z_c_expected) << "conjugate failed for type " << type_name; + } +} + +} // namespace test + +TEST(complex, conj_with_standard_arithmetic_types) { + { + double x = 42.0; + double x_c_expected = 42.0; + test::conj_tester(x, x_c_expected, "double"); + } + { + float x = 42.0f; + float x_c_expected = 42.0f; + test::conj_tester(x, x_c_expected, "float"); + } + { + int x = 42; + int x_c_expected = 42; + test::conj_tester(x, x_c_expected, "int"); + } +} + +TEST(complex, conj_with_cutlass_complex_types) { + { + cutlass::complex z{3.0, 4.0}; + cutlass::complex z_c_expected{3.0, -4.0}; + test::conj_tester(z, z_c_expected, "cutlass::complex"); + } + { + cutlass::complex z{3.0f, 4.0f}; + cutlass::complex z_c_expected{3.0f, -4.0f}; + test::conj_tester(z, z_c_expected, "cutlass::complex"); + } + { + cutlass::complex z{ + cutlass::tfloat32_t{3.0f}, cutlass::tfloat32_t{4.0f}}; + cutlass::complex z_c_expected{ + cutlass::tfloat32_t{3.0f}, cutlass::tfloat32_t{-4.0f}}; + test::conj_tester(z, z_c_expected, "cutlass::complex"); + } +} + +TEST(complex, conj_with_noncomplex_type_not_in_cutlass_namespace) { + test::FakeReal x = test::FakeReal::make_FakeReal(42); + test::FakeReal x_c_expected = test::FakeReal::make_FakeReal(42); + test::conj_tester(x, x_c_expected, "test::FakeReal"); +} + +TEST(complex, conj_with_noncomplex_type_in_cutlass_namespace) { + cutlass::tfloat32_t x{42.0f}; + cutlass::tfloat32_t x_c_expected{42.0f}; + test::conj_tester(x, x_c_expected, "cutlass::tfloat32_t"); +} + +TEST(complex, conj_with_complex_types_not_in_cutlass_namespace) { + using test::FakeReal; + + // conj defined as nonmember outside the class + { + test::FakeComplex z = test::FakeComplex::make_FakeComplex( + FakeReal::make_FakeReal(3), FakeReal::make_FakeReal(4)); + test::FakeComplex z_c_expected = test::FakeComplex::make_FakeComplex( + FakeReal::make_FakeReal(3), FakeReal::make_FakeReal(-4)); + test::conj_tester(z, z_c_expected, "test::FakeComplex"); + } + // conj defined as hidden friend + { + test::FakeComplexWithHiddenFriendConj z = + test::FakeComplexWithHiddenFriendConj::make_FakeComplexWithHiddenFriendConj( + FakeReal::make_FakeReal(3), + FakeReal::make_FakeReal(4)); + test::FakeComplexWithHiddenFriendConj z_c_expected = + test::FakeComplexWithHiddenFriendConj::make_FakeComplexWithHiddenFriendConj( + FakeReal::make_FakeReal(3), + FakeReal::make_FakeReal(-4)); + test::conj_tester(z, z_c_expected, "test::FakeComplexWithHiddenFriendConj"); + } +} + +TEST(complex, conj_with_cuda_std_complex_types) { + { + cuda::std::complex z{3.0, 4.0}; + cuda::std::complex z_c_expected{3.0, -4.0}; + test::conj_tester(z, z_c_expected, "cuda::std::complex"); + } + { + cuda::std::complex z{3.0f, 4.0f}; + cuda::std::complex z_c_expected{3.0f, -4.0f}; + test::conj_tester(z, z_c_expected, "cuda::std::complex"); + } +} + +TEST(complex, conj_with_cuComplex_types) { + { + cuDoubleComplex z = make_cuDoubleComplex(3.0, 4.0); + cuDoubleComplex z_c_expected = make_cuDoubleComplex(3.0, -4.0); + test::conj_tester(z, z_c_expected, "cuDoubleComplex"); + } + { + cuFloatComplex z = make_cuFloatComplex(3.0f, 4.0f); + cuFloatComplex z_c_expected = make_cuFloatComplex(3.0f, -4.0f); + test::conj_tester(z, z_c_expected, "cuFloatComplex"); + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace test { diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 75e12bdf14..1c966a9cc8 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -668,3 +668,4 @@ TYPED_TEST(VectorArrayConverterTest, array_263) { } ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/numeric_conversion_subbyte.cu b/test/unit/core/numeric_conversion_subbyte.cu new file mode 100644 index 0000000000..a670afce46 --- /dev/null +++ b/test/unit/core/numeric_conversion_subbyte.cu @@ -0,0 +1,69 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for conversion operators. +*/ + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/integer_subbyte.h" + +namespace test::core::host { + +template +void run_test() { + cutlass::Array dst; + dst.clear(); + + cutlass::Array src; + for (int k = 0; k < NumElements; ++k) { + src[k] = SrcValueType(k+1); + } + + cutlass::NumericArrayConverter converter; + dst = converter(src); + + for (int k = 0; k < NumElements; ++k) { + EXPECT_TRUE(static_cast(src[k]) == static_cast(dst[k])); + } +} + +} // namespace test::core::host + +TEST(NumericArrayConversion, Subbyte_int8_int8) { + test::core::host::run_test(); +} + +TEST(NumericArrayConversion, Subbyte_int8_int4) { + test::core::host::run_test(); +} + diff --git a/test/unit/cute/ampere/CMakeLists.txt b/test/unit/cute/ampere/CMakeLists.txt index fd701de656..6ac7f2f203 100644 --- a/test/unit/cute/ampere/CMakeLists.txt +++ b/test/unit/cute/ampere/CMakeLists.txt @@ -31,6 +31,7 @@ cutlass_test_unit_add_executable( cp_async.cu ldsm.cu cooperative_gemm.cu + cooperative_copy.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/cute/ampere/cooperative_copy.cu b/test/unit/cute/ampere/cooperative_copy.cu new file mode 100644 index 0000000000..aa0d7536df --- /dev/null +++ b/test/unit/cute/ampere/cooperative_copy.cu @@ -0,0 +1,633 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +using namespace cute; + +namespace cooperative_copy_mode { + struct global_shared {}; + struct global_global {}; + struct shared_shared {}; +} + +// gs --> global to/from shared +template +__device__ void +cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layout, SMemLayout const& smem_layout) +{ + using namespace cute; + extern __shared__ uint128_t smem_buf[]; + // Cast smem_buf to smem_uint8_ptr and move it by MaxVecBits bits + // This is to make sure tests pass on pointer aligned to MaxVecBits bits + uint8_t* smem_uint8_ptr = reinterpret_cast(smem_buf) + (MaxVecBits/8); + T* smem = reinterpret_cast(smem_uint8_ptr); + + Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), gmem_layout); + Tensor s_tensor = make_tensor(make_smem_ptr(smem), smem_layout); + + cooperative_copy(threadIdx.x, g_in_tensor, s_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + if(thread0()) { + for(int i = 0; i < size(s_tensor); ++i) { + s_tensor(i) += T(i); + } + } + __syncthreads(); + + cooperative_copy(threadIdx.x, s_tensor, g_out_tensor); +} + +// ss --> shared to shared +template +__device__ void +cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Layout2 const& layout2) +{ + using namespace cute; + extern __shared__ uint128_t smem_buf[]; + // Cast smem_buf to smem_uint8_ptr and move it by MaxVecBits bits + // This is to make sure tests pass on pointer aligned to MaxVecBits bits + T* smem1 = reinterpret_cast(smem_buf); + uint8_t* smem2_uint8_ptr = reinterpret_cast(smem_buf) + (MaxVecBits/8); + T* smem2 = reinterpret_cast(smem2_uint8_ptr) + cute::cosize(layout2); + + Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), layout1); + Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), layout2); + + Tensor s1_tensor = make_tensor(make_smem_ptr(smem1), layout2); + Tensor s2_tensor = make_tensor(make_smem_ptr(smem2), layout1); + + cooperative_copy>(threadIdx.x, g_in_tensor, s1_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + if(thread0()) { + for(int i = 0; i < size(s1_tensor); ++i) { + s1_tensor(i) += T(i); + } + } + __syncthreads(); + + cooperative_copy(threadIdx.x, s1_tensor, s2_tensor); + __syncthreads(); + + cooperative_copy>(threadIdx.x, s2_tensor, g_out_tensor); +} + +// gg --> global to global +template +__device__ void +cooperative_copy_default_gg(T const* g_in, T* g_out, Layout1 const& layout1, Layout2 const& layout2) +{ + using namespace cute; + + Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), layout1); + Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), layout2); + + cooperative_copy(threadIdx.x, g_in_tensor, g_out_tensor); +} + +template +__global__ void +cooperative_copy_default_kernel(T const* g_in, T* g_out, Layout1 const layout1, Layout2 const layout2) +{ + if constexpr(std::is_same_v) { + cooperative_copy_default_gs(g_in, g_out, layout1, layout2); + } else if constexpr (std::is_same_v) { + cooperative_copy_default_gg(g_in, g_out, layout1, layout2); + } else if constexpr (std::is_same_v) { + cooperative_copy_default_ss(g_in, g_out, layout1, layout2); + } +} + +// Mode - defines memory types of src and dst in cooperative_copy operation +// MaxVecBits - defines max vectorization in cooperative_copy operation, and enforces that +// alignment on used pointers to ensure correct testing +template +void test_cooperative_copy_default(Layout1 const& layout1, Layout2 const& layout2) +{ + using value_type = T; + CUTE_STATIC_ASSERT_V(cute::size(layout1) == cute::size(layout2)); + + auto gmem_layout_in = layout1; + auto gmem_layout_out = cute::conditional_return>(layout1, layout2); + +#if 0 + print(" "); print("layout1: "); print(layout1); print("\n"); + print(" "); print("layout2: "); print(layout2); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); + print(" "); print("maxvecbits: "); print(MaxVecBits); print("\n"); +#endif + + if constexpr (MaxVecBits < cute::sizeof_bits_v) { + GTEST_SKIP() << "Skipping test since MaxVecBits (=" << MaxVecBits + << ") < cute::sizeof_bits_v (=" << cute::sizeof_bits_v << ")"; + } else { + constexpr auto max_vec_bytes = MaxVecBits / 8; + static_assert((max_vec_bytes % sizeof(T)) == 0); + + uint32_t count = cute::cosize(gmem_layout_in); + // Extra elements to force MaxVecBits alignment in global memory + uint32_t extra_elements = max_vec_bytes / sizeof(value_type); + + // Allocate + thrust::host_vector h_in (count + extra_elements); + thrust::host_vector h_out(count + extra_elements); + + // Initialize + Tensor h_in_tensor = make_tensor(h_in.data() + extra_elements, gmem_layout_in); + Tensor h_out_tensor = make_tensor(h_out.data() + extra_elements, gmem_layout_out); + for (int i = 0; i < cute::size(h_in_tensor); ++i) { + h_in_tensor(i) = value_type(float(i)); + // For global-to-global copy need to compare against the same value + h_out_tensor(i) = std::is_same_v ? value_type(float(i)) : value_type(float(2 * i)); + } + + // To GPU + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(d_in.size(), value_type(float(-2))); + + // Adds (MaxVecBits/8) bytes to shared memory as we'll move pointer by that many bytes inside the kernel to enforce + // alignment to (MaxVecBits/8) bytes + size_t shared_memory_bytes = (sizeof(value_type) * count) + max_vec_bytes; + shared_memory_bytes += std::is_same_v * (sizeof(value_type) * count); + + // Launch + auto coop_copy = cooperative_copy_default_kernel; + ASSERT_EQ(cudaFuncSetAttribute(coop_copy, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_bytes)), cudaSuccess); + + auto d_in_ptr = thrust::raw_pointer_cast(d_in.data() + extra_elements); + auto d_out_ptr = thrust::raw_pointer_cast(d_out.data() + extra_elements); + coop_copy<<<1, ThreadBlockSize, shared_memory_bytes>>>(d_in_ptr, d_out_ptr, layout1, layout2); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Validate + thrust::host_vector h_result = d_out; + Tensor h_result_tensor = make_tensor(h_result.data() + extra_elements, gmem_layout_out); + for (int i = 0; i < cute::size(h_in_tensor); ++i) { + ASSERT_EQ(h_result_tensor(i), h_out_tensor(i)) + << i << " - result:" << h_result_tensor(i) << " expected:" << h_out_tensor(i); + } + } +} + +template +class SM80_CuTe_Ampere; + +template +class SM80_CuTe_Ampere>: public testing::Test +{ +public: + using mode = Mode; + static constexpr int max_vec_bits = MaxVecBits::value; +}; + +typedef testing::Types< + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, +> CooperativeCopyModeMaxVecBitsList; + +TYPED_TEST_SUITE(SM80_CuTe_Ampere, CooperativeCopyModeMaxVecBitsList); + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault1D) +{ + using value_type = float; + constexpr uint32_t count = 512; + auto gmem_layout = make_layout(make_shape(Int{})); + auto smem_layout = make_layout(make_shape(Int{})); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault1DFallback) +{ + using value_type = float; + constexpr uint32_t count = 99; + auto gmem_layout = make_layout(make_shape(Int{})); + auto smem_layout = make_layout(make_shape(Int{})); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault2D) +{ + using value_type = float; + constexpr uint32_t x = 32; + constexpr uint32_t y = 32; + auto gmem_layout = make_layout(make_shape(Int{}, Int{})); + auto smem_layout = make_layout(make_shape(Int{}, Int{})); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +#if 0 + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault2DDynamicStrides) +{ + using value_type = float; + constexpr uint32_t x = 32; + constexpr uint32_t y = 32; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), make_stride(1, x)); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), make_stride(1, x)); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + + + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault2DMixedStrides) +{ + using value_type = float; + constexpr uint32_t x = 32; + constexpr uint32_t y = 32; + auto gmem_layout = make_layout(make_shape(Int{}, Int{})); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), make_stride(1, x)); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +#endif + +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault2DFallback) +{ + using value_type = float; + constexpr uint32_t x = 37; + constexpr uint32_t y = 37; + auto gmem_layout = make_layout(make_shape(Int{}, Int{})); + auto smem_layout = make_layout(make_shape(Int{}, Int{})); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast Path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault2DCustomStride) +{ + using value_type = float; + constexpr uint32_t x = 16; + constexpr uint32_t y = 16; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{})); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), make_stride(Int<1>{}, Int{})); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault3D) +{ + using value_type = cute::half_t; + constexpr uint32_t x = 8; + constexpr uint32_t y = 8; + constexpr uint32_t z = 16; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}, Int{})); + auto smem_layout = make_layout(make_shape(Int{}, Int{}, Int{})); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefault2Dto3D) +{ + using value_type = double; + constexpr uint32_t x = 16; + constexpr uint32_t y = 16; + constexpr uint32_t z = 4; + auto gmem_layout = make_layout(make_shape(Int{}, Int{})); + auto smem_layout = make_layout(make_shape(Int{}, Int{}, Int{})); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultCustom1) +{ + using value_type = double; + auto gmem_layout = make_layout( + make_shape(Int<8>{}, make_shape(Int<2>{}, Int<2>{})), + make_stride(Int<2>{}, make_shape(Int<1>{}, Int<16>{})) + ); + auto smem_layout = make_layout( + make_shape(Int<8>{}, Int<4>{}), + make_stride(Int<4>{}, Int<1>{}) + ); + constexpr uint32_t thread_block_size = 8; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast Path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultCustom2) +{ + using value_type = float; + auto gmem_layout = make_layout( + make_shape(make_shape(Int<4>{}, Int<2>{}), make_shape(Int<2>{}, Int<2>{})), + make_stride(make_shape(Int<4>{}, Int<1>{}), make_shape(Int<16>{}, Int<2>{})) + ); + auto smem_layout = make_layout( + make_shape(make_shape(Int<2>{}, Int<2>{}, Int<2>{}), make_shape(Int<2>{}, Int<2>{})), + make_stride(make_shape(Int<16>{}, Int<4>{}, Int<1>{}), make_shape(Int<8>{}, Int<2>{})) + ); + constexpr uint32_t thread_block_size = 16; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast Path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzle1) +{ + using value_type = float; + auto gmem_layout = Layout, Stride<_64, _1>>{}; + auto smem_layout = composition(Swizzle<3, 3, 3>{}, Layout, Stride<_64, _1>>{}); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast Path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzle2) +{ + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int<64>{}, Int<64>{})); + auto smem_atom_layout = composition(Swizzle<3, 2, 3>{}, Layout, Stride<_32, _1>>{}); + auto smem_layout = tile_to_shape( + smem_atom_layout, + make_shape(shape<0>(gmem_layout), shape<1>(gmem_layout)) + ); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast Path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzle3) +{ + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int<64>{}, Int<64>{})); + auto smem_atom_layout = composition(Swizzle<2, 4, 3>{}, Layout, Stride<_64, _1>>{}); + auto smem_layout = tile_to_shape( + smem_atom_layout, + make_shape(shape<0>(gmem_layout), shape<1>(gmem_layout)) + ); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzle4) +{ + using value_type = cute::half_t; + auto gmem_atom_layout = composition(Swizzle<3, 2, 3>{}, Layout, Stride<_32, _1>>{}); + auto smem_layout = make_layout(make_shape(Int<64>{}, Int<64>{})); + auto gmem_layout = tile_to_shape( + gmem_atom_layout, + make_shape(shape<0>(smem_layout), shape<1>(smem_layout)) + ); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + +// Needs coalescing to work on fast path +// OK if we enforce slow path +// Problem: Wrong condition when we select between slow and fast path +TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultCoalesceToCompose) +{ + constexpr int m = 96; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenColMajor{}); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), GenColMajor{}); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); +} + + // Fast path (default): OK + // Slow path (enforced): OK + TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzle5) + { + constexpr int m = 64; + constexpr int n = 128; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenColMajor{}); + // auto smem_layout = make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + auto smem_atom_layout = + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{}); + auto smem_layout = tile_to_shape( + smem_atom_layout, + make_shape(shape<0>(gmem_layout), shape<1>(gmem_layout)) + ); + + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); + } + + // If condition not strict enought will go to fast path + // This test needs checking if CuTe can compose layouts + // Fast path (default): fail + // Slow path (enforced): Should go to vectorized naive path + TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzleNaiveVectorizable) + { + constexpr int m = 192; + constexpr int n = 64; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenColMajor{}); + // auto smem_layout = make_layout(make_shape(Int{}, Int{}), GenColMajor{}); + auto smem_atom_layout = + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{}); + auto smem_layout = tile_to_shape( + smem_atom_layout, + shape(gmem_layout) + ); + + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(gmem_layout, smem_layout); + } + + // fast path: ok (chosen) + // slow path: ok + TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultRowMajorSmall) + { + constexpr int m = 24; + constexpr int n = 8; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); + } + + // fast path: doesn't apply + // slow path: ok + TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSlowPath) + { + constexpr int m = 67; + constexpr int n = 67; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(gmem_layout, smem_layout); + } + + // fast path: doesn't apply + // slow path: should vectorize + TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopyDefaultSwizzleSlowPathVectorize) + { + constexpr int m = 68; + constexpr int n = 68; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + auto smem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + + constexpr uint32_t thread_block_size = 32; + test_cooperative_copy_default(gmem_layout, smem_layout); + } + + TYPED_TEST(SM80_CuTe_Ampere, CooperativeCopy48x48Swizzle) + { + constexpr int m = 48; + constexpr int n = 48; + using value_type = cute::half_t; + auto gmem_layout = make_layout(make_shape(Int{}, Int{}), GenRowMajor{}); + auto smem_layout = composition(Swizzle<2,2,3>{}, + Layout>>, + Stride, _16>>>{}); + + constexpr uint32_t thread_block_size = 8 * 32; + test_cooperative_copy_default(gmem_layout, smem_layout); + } diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index 2fcd01205d..02196204ad 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -298,3 +298,146 @@ TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_MMA) { test_cooperative_gemm_col_major_layout(); } + +TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) { + + using TA = cutlass::complex; + using TB = cutlass::complex; + using TC = cutlass::complex; + + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _4, _0>>, + Tile + >; + + using ALayout = Layout,Int<35>>, Stride, Int<1> >>; + using BLayout = Layout, Int<35>>, Stride, Int<1> >>; + using CLayout = Layout, Int<7>>, Stride< Int<1>, Int<30>>>; + + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment, // B + AutoVectorizingCopyWithAssumedAlignment, // C + thread_block_size, + tiled_mma_t, + MaxVecBits, + TA, + TB, + TC>(); + +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm10_F16F64F16_FMA) { + + using TA = cutlass::half_t; + using TB = double; + using TC = cutlass::half_t; + + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout, Stride<_1, _16, _0>>, + Tile + >; + + using ALayout = Layout,Int<64>>, Stride, Int< 1>>>; + using BLayout = Layout,Int<64>>, Stride, Int<64>>>; + using CLayout = Layout,Int<64>>, Stride, Int<64>>>; + + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment, // B + AutoVectorizingCopyWithAssumedAlignment, // C + thread_block_size, + tiled_mma_t, + MaxVecBits, + TA, + TB, + TC>(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemmComposedStride) { + + using T = cute::half_t; + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 16; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _2, _0>>, + Tile + >; + + using swizzle = cute::Swizzle<3, 3, 3>; + using offset = cute::_0; + using atom_tile_right = decltype(cute::make_layout(cute::Shape{}, cute::LayoutRight{})); + using FP16AtomLayoutRight = decltype(cute::composition(swizzle{}, offset{}, atom_tile_right{})); + + using shape = cute::Shape, cute::Int<128>>; + using global_a_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{})); + using global_b_layout = decltype(cute::make_layout(shape{}, cute::LayoutLeft{})); + using global_c_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{})); + + // This is for A row major, B col major according to CUTLASS default configs + using ALayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_a_layout{})); + using BLayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_b_layout{})); + using CLayout = global_c_layout; + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment, // B + AutoVectorizingCopyWithAssumedAlignment, // C + thread_block_size, + tiled_mma_t, + MaxVecBits, + T, + T, + T>(); +} + +TEST(SM89_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) { + using TA = cutlass::tfloat32_t; + using TB = cutlass::tfloat32_t; + using TC = float; + + constexpr uint32_t m = 9; + constexpr uint32_t n = 9; + constexpr uint32_t k = 9; + + constexpr uint32_t thread_block_size = 64; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{}); +} diff --git a/test/unit/cute/cooperative_gemm_common.hpp b/test/unit/cute/cooperative_gemm_common.hpp index 9f7f694619..5dec22ca0c 100644 --- a/test/unit/cute/cooperative_gemm_common.hpp +++ b/test/unit/cute/cooperative_gemm_common.hpp @@ -31,6 +31,7 @@ #pragma once +#include "cutlass/relatively_equal.h" #include "cutlass_unit_test.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -43,6 +44,16 @@ using namespace cute; +template +struct fp64_tester { + using value_type = double; +}; + +template +struct fp64_tester> { + using value_type = complex; +}; + template::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + static_assert(size<0>(gmem_a_layout_t{}) == size<0>(gmem_c_layout_t{})); // AM == CM static_assert(size<0>(gmem_b_layout_t{}) == size<1>(gmem_c_layout_t{})); // BN == CN static_assert(size<1>(gmem_a_layout_t{}) == size<1>(gmem_b_layout_t{})); // AK == BK @@ -184,7 +200,7 @@ void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, h_a_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); } if(i < size(gmem_b_layout_t{})) { - h_b_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); + h_b_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); } if(i < size(gmem_c_layout_t{})) { h_c_tensor(i) = static_cast((di*di) / size(gmem_a_layout_t{})); @@ -196,8 +212,10 @@ void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, thrust::device_vector d_c(h_c); thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); - const size_t shared_memory_size = - (sizeof(TA) * h_a.size()) + (sizeof(TB) * h_b.size()) + (sizeof(TC) * h_c.size()); + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + (sizeof(TC) * h_c.size()); auto kernel = cooperative_gemm_kernel< gmem_a_layout_t, gmem_b_layout_t, gmem_c_layout_t, smem_a_layout_t, smem_b_layout_t, smem_c_layout_t, @@ -234,24 +252,24 @@ void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, for (int n = 0; n < size<0>(h_b_tensor); n++) { const auto a_value = a_load_transform(h_a_tensor(m, k)); const auto b_value = b_load_transform(h_b_tensor(n, k)); - const auto a_value_fp64 = static_cast(a_value); - const auto b_value_fp64 = static_cast(b_value); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); } } } // C = A*B + C for (int i = 0; i < size(h_c_ref_tensor); i++) { - const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); - const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); } h_c_out = d_c_out; auto h_c_out_tensor = make_tensor(h_c_out.data(), gmem_c_layout_t{}); for (int i = 0; i < size(h_c_ref_tensor); i++) { - double h_c_ref_i = h_c_ref_tensor(i); - double h_c_out_i = h_c_out_tensor(i); + ABC_64 h_c_ref_i = h_c_ref_tensor(i); + ABC_64 h_c_out_i = h_c_out_tensor(i); double epsilon(0.1f); double nonzero_floor(std::numeric_limits::min()); bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index ddd23df31c..e6dea35f78 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -38,16 +38,19 @@ cutlass_test_unit_add_executable( composition.cpp constants.cpp core_unit.cpp + domain_distribute.cpp inverse_left.cpp inverse_right.cpp logical_divide.cpp logical_product.cpp - math.cpp + math.cpp mixedbits.cpp nullspace.cpp + packed_tuple.cpp pointer.cpp reverse.cpp transform.cpp tuple.cpp + tuple_find.cpp int_tuple.cpp ) diff --git a/test/unit/cute/core/array_subbyte.cpp b/test/unit/cute/core/array_subbyte.cpp index 0fd02f49c1..f3b94a8ff3 100644 --- a/test/unit/cute/core/array_subbyte.cpp +++ b/test/unit/cute/core/array_subbyte.cpp @@ -51,7 +51,7 @@ TEST(CuTe_core, ArraySubbyte) for (size_t i = 0; i < array1.size(); ++i) { array0[i+5] = array1[i]; } - + EXPECT_EQ(int4_t(array0.back()), int4_t(1)); for (size_t i = 0; i < array1.size(); ++i) { @@ -137,7 +137,7 @@ TEST(CuTe_core, Subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, uint8_t(13)); for (int i = 0; i < int(a.size()); ++i) { @@ -150,7 +150,7 @@ TEST(CuTe_core, Subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, int4_t(-5)); for (int i = 0; i < int(a.size()); ++i) { @@ -163,7 +163,7 @@ TEST(CuTe_core, Subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, uint2_t(-5)); for (int i = 0; i < int(a.size()); ++i) { @@ -176,7 +176,7 @@ TEST(CuTe_core, Subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, bool(1)); for (int i = 0; i < int(a.size()); ++i) { @@ -193,7 +193,7 @@ TEST(CuTe_core, Const_subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, uint8_t(13)); for (int i = 0; i < int(a.size()); ++i) { @@ -206,7 +206,7 @@ TEST(CuTe_core, Const_subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, int4_t(-5)); for (int i = 0; i < int(a.size()); ++i) { @@ -219,7 +219,7 @@ TEST(CuTe_core, Const_subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, uint2_t(-5)); for (int i = 0; i < int(a.size()); ++i) { @@ -232,7 +232,7 @@ TEST(CuTe_core, Const_subbyte_iterator) { array_subbyte a{}; - auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + auto tensor = make_tensor(a.begin(), make_shape(15)); fill(a, bool(1)); for (int i = 0; i < int(a.size()); ++i) { diff --git a/test/unit/cute/core/domain_distribute.cpp b/test/unit/cute/core/domain_distribute.cpp new file mode 100644 index 0000000000..55a4f76a67 --- /dev/null +++ b/test/unit/cute/core/domain_distribute.cpp @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#define CUTLASS_DEBUG_TRACE_LEVEL 1 + +#include "cutlass_unit_test.h" + +#include + +#include + +#include + +using namespace cute; + + +template +void +test_distribute(LayoutA const& layoutA, + LayoutB const& layoutB) +{ + auto layoutR = domain_distribute(shape(layoutA), shape(layoutB)); + + CUTLASS_TRACE_HOST("test_distribute()"); + CUTLASS_TRACE_HOST(layoutA << " <-> " << layoutB); + CUTLASS_TRACE_HOST(" => "); + CUTLASS_TRACE_HOST(layoutR); + + // Test that layout B is softly compatible with layout R + EXPECT_TRUE(softly_compatible(layoutB, layoutR)); + + // Post-condition on the codomain of the distribute + for (int i = 0; i < size(layoutR); ++i) { + for (int j = i+1; j < size(layoutR); ++j) { + EXPECT_TRUE(layoutR(i) < layoutR(j)); // Surjective and Ordered + } + } +} + + +TEST(CuTe_core, Distribute) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("DOMAIN DISTRIBUTE" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto shape_a = Shape,Shape<_8,_8>>{}; + auto shape_b = _128{}; + + test_distribute(shape_a, shape_b); + } + + { + auto shape_a = Shape,Shape<_8,_8>>{}; + auto shape_b = _128{}; + + test_distribute(shape_a, shape_b); + } + + { + auto shape_a = Shape,Shape<_8,_8>>{}; + auto shape_b = _128{} * _8{}; + + test_distribute(shape_a, shape_b); + } + + { + auto shape_a = Shape,Shape<_8,_8>>{}; + auto shape_b = _128{} * _8{}; + + test_distribute(shape_a, shape_b); + } + + { + auto shape_a = Shape>{}; + auto shape_b = _128{}; + + test_distribute(shape_a, shape_b); + } +} diff --git a/test/unit/cute/core/int_tuple.cpp b/test/unit/cute/core/int_tuple.cpp index cd237a278c..d68ff2a789 100644 --- a/test/unit/cute/core/int_tuple.cpp +++ b/test/unit/cute/core/int_tuple.cpp @@ -56,7 +56,7 @@ TEST(CuTe_core, WeaklyCongruent) EXPECT_TRUE (weakly_congruent(a0, a0)); EXPECT_TRUE (weakly_congruent(b0, b0)); EXPECT_TRUE (weakly_congruent(a0, b0)); - + auto a1 = Shape<_1, _1>{}; EXPECT_TRUE (weakly_congruent(a , a1)); EXPECT_FALSE(weakly_congruent(a0, a1)); @@ -93,7 +93,7 @@ TEST(CuTe_core, WeaklyCompatible) EXPECT_TRUE (weakly_compatible(a, a)); EXPECT_TRUE (weakly_compatible(b, b)); EXPECT_TRUE (weakly_compatible(c, c)); - EXPECT_FALSE(weakly_compatible(a, b)); + EXPECT_FALSE(weakly_compatible(a, b)); EXPECT_FALSE(weakly_compatible(a, c)); EXPECT_TRUE (weakly_compatible(c, a)); @@ -102,9 +102,9 @@ TEST(CuTe_core, WeaklyCompatible) EXPECT_TRUE (weakly_compatible(a , a0)); EXPECT_FALSE(weakly_compatible(a0, a )); EXPECT_TRUE (weakly_compatible(c , a0)); - EXPECT_FALSE(weakly_compatible(a0, c )); + EXPECT_FALSE(weakly_compatible(a0, c )); EXPECT_FALSE(weakly_compatible(b , a0)); - EXPECT_FALSE(weakly_compatible(a0, b )); + EXPECT_FALSE(weakly_compatible(a0, b )); auto a1 = Shape<_2,_8>{}; EXPECT_TRUE (weakly_compatible(a1, a1)); @@ -129,3 +129,50 @@ TEST(CuTe_core, WeaklyCompatible) EXPECT_TRUE (weakly_compatible(a2, a3)); EXPECT_FALSE(weakly_compatible(a3, a2)); } + +TEST(CuTe_core, SoftlyCompatible) +{ + using namespace cute; + + auto a = _16{}; + auto b = _12{}; + auto c = _8{}; + EXPECT_TRUE (softly_compatible(a, a)); + EXPECT_TRUE (softly_compatible(b, b)); + EXPECT_TRUE (softly_compatible(c, c)); + EXPECT_FALSE(softly_compatible(a, b)); + EXPECT_TRUE (softly_compatible(a, c)); + EXPECT_FALSE(softly_compatible(c, a)); + + auto a0 = Shape<_16>{}; + EXPECT_TRUE (softly_compatible(a0, a0)); + EXPECT_TRUE (softly_compatible(a , a0)); + EXPECT_FALSE(softly_compatible(a0, a )); + EXPECT_FALSE(softly_compatible(c , a0)); + EXPECT_FALSE(softly_compatible(a0, c )); + EXPECT_FALSE(softly_compatible(b , a0)); + EXPECT_FALSE(softly_compatible(a0, b )); + + auto a1 = Shape<_2,_8>{}; + EXPECT_TRUE (softly_compatible(a1, a1)); + EXPECT_TRUE (softly_compatible(a , a1)); + EXPECT_FALSE(softly_compatible(a0, a1)); + EXPECT_FALSE(softly_compatible(a1, a0)); + EXPECT_TRUE (softly_compatible(a1, Shape<_2,Shape<_2,_4>>{})); + + auto a2 = Shape>{}; + EXPECT_TRUE (softly_compatible(a2, a2)); + EXPECT_TRUE (softly_compatible(a , a2)); + EXPECT_FALSE(softly_compatible(c , a2)); + EXPECT_TRUE (softly_compatible(a0, a2)); + EXPECT_FALSE(softly_compatible(a2, a0)); + + auto a3 = Shape>>{}; + EXPECT_TRUE (softly_compatible(a3, a3)); + EXPECT_TRUE (softly_compatible(a , a3)); + EXPECT_FALSE(softly_compatible(c , a3)); + EXPECT_TRUE (softly_compatible(a0, a3)); + EXPECT_FALSE(softly_compatible(a3, a0)); + EXPECT_TRUE (softly_compatible(a2, a3)); + EXPECT_FALSE(softly_compatible(a3, a2)); +} diff --git a/test/unit/cute/core/packed_tuple.cpp b/test/unit/cute/core/packed_tuple.cpp new file mode 100644 index 0000000000..fbbcab0587 --- /dev/null +++ b/test/unit/cute/core/packed_tuple.cpp @@ -0,0 +1,581 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace pt_test { + +template +struct Nonempty { + T datum; + + Nonempty(T const& t) : datum{t} {} + + friend bool operator==(Nonempty const& lhs, Nonempty const& rhs) { + return lhs.datum == rhs.datum; + } + + friend bool operator!=(Nonempty const& lhs, Nonempty const& rhs) { + return !(lhs == rhs); + } +}; + +template +struct Empty { + template + friend bool operator==(Empty const&, Empty const&) { + return V == W; + } + + template + friend bool operator!=(Empty const& lhs, Empty const& rhs) { + return !(lhs == rhs); + } +}; + +// std::tuple +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(not cute::is_standard_layout_v>); // it's not + +#if ! defined(CUTLASS_USE_PACKED_TUPLE) +// cute::tuple +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(not cute::is_standard_layout_v>); // it's not +#endif // CUTLASS_USE_PACKED_TUPLE + +// cute::packed_tuple +static_assert(cute::is_standard_layout_v>); +static_assert(cute::is_standard_layout_v>); +static_assert(cute::is_standard_layout_v>); +static_assert(cute::is_standard_layout_v>); // it is +static_assert(cute::is_standard_layout_v>); // it is +static_assert(cute::is_standard_layout_v, int>>); // it is +static_assert(cute::is_standard_layout_v, Empty<0>>, int>>); // it is + +////////////////////////////////////////////////////////////////////// +// packed_tuple test starts here +////////////////////////////////////////////////////////////////////// + +template < + class ExpectedPackedType, + size_t ExpectedPackedSize, + class ... Args> +constexpr void +test_packed_type_alias([[maybe_unused]] ExpectedPackedType packed, std::tuple unpacked) +{ + using cute::packed_tuple; + + if constexpr ((cute::is_standard_layout_v && ...)) { + static_assert(cute::is_standard_layout_v>); + } + + if constexpr ((cute::is_empty_v && ...)) { + static_assert(cute::is_empty_v>); + } + + static_assert(cute::tuple_size_v> == sizeof...(Args)); + + auto test_element = [unpacked] (auto index) { + static_assert(cute::is_same_v< + std::tuple_element_t>, + std::tuple_element_t> + >); + + packed_tuple sl = cute::apply(unpacked, [](auto... a){ return cute::make_packed_tuple(a...); }); + EXPECT_EQ(std::get(unpacked), cute::get(sl)); + }; + cute::for_each(std::make_index_sequence(), test_element); +} + +void test_packed_type_aliases() { + using cute::packed_tuple; + test_packed_type_alias, 0>({}, {}); + + test_packed_type_alias, 1, int>({7}, {7}); + test_packed_type_alias, 1, double>({1.5}, {1.5}); + + // Make sure that class types are handled the same as scalar types + test_packed_type_alias>, 1, Nonempty>( + {Nonempty{7}}, {Nonempty{7}}); + test_packed_type_alias>, 1, Nonempty>( + {Nonempty{1.5}}, {Nonempty{1.5}}); + + test_packed_type_alias, 0, Empty<0>>({}, {}); + test_packed_type_alias, 0, Empty<0>, Empty<1>>( + {}, {Empty<0>{}, Empty<1>{}}); + test_packed_type_alias, 0, Empty<0>, Empty<1>, Empty<2>>( + {}, {Empty<0>{}, Empty<1>{}, Empty<2>{}}); + + test_packed_type_alias, 1, Empty<0>, int>( + {7}, {Empty<0>{}, 7}); + test_packed_type_alias, 1, int, Empty<0>>( + {7}, {7, Empty<0>{}}); + + test_packed_type_alias, 1, int, Empty<0>, Empty<1>>( + {7}, {7, Empty<0>{}, Empty<1>{}}); + test_packed_type_alias, 1, Empty<0>, int, Empty<1>>( + {7}, {Empty<0>{}, 7, Empty<1>{}}); + test_packed_type_alias, 1, Empty<0>, Empty<1>, int>( + {7}, {Empty<0>{}, Empty<1>{}, 7}); + + test_packed_type_alias, 2, int, double, Empty<0>>( + {7, 1.5}, {7, 1.5, Empty<0>{}}); + test_packed_type_alias, 2, int, Empty<0>, double>( + {7, 1.5}, {7, Empty<0>{}, 1.5}); + test_packed_type_alias, 2, int, double, Empty<0>>( + {7, 1.5}, {7, 1.5, Empty<0>{}}); + + test_packed_type_alias, 2, int, double, Empty<0>, Empty<1>>( + {7, 1.5}, {7, 1.5, Empty<0>{}, Empty<1>{}}); + test_packed_type_alias, 2, int, Empty<0>, double, Empty<1>>( + {7, 1.5}, {7, Empty<0>{}, 1.5, Empty<1>{}}); + test_packed_type_alias, 2, int, Empty<0>, Empty<1>, double>( + {7, 1.5}, {7, Empty<0>{}, Empty<1>{}, 1.5}); + test_packed_type_alias, 2, Empty<0>, int, Empty<1>, double>( + {7, 1.5}, {Empty<0>{}, 7, Empty<1>{}, 1.5}); + test_packed_type_alias, 2, Empty<0>, Empty<1>, int, double>( + {7, 1.5}, {Empty<0>{}, Empty<1>{}, 7, 1.5}); + + test_packed_type_alias, 3, Empty<0>, int, double, float>( + {7, 1.5, 2.5f}, {Empty<0>{}, 7, 1.5, 2.5f}); + test_packed_type_alias, 3, int, Empty<0>, double, float>( + {7, 1.5, 2.5f}, {7, Empty<0>{}, 1.5, 2.5f}); + test_packed_type_alias, 3, int, double, Empty<0>, float>( + {7, 1.5, 2.5f}, {7, 1.5, Empty<0>{}, 2.5f}); + test_packed_type_alias, 3, int, double, float, Empty<0>>( + {7, 1.5, 2.5f}, {7, 1.5, 2.5f, Empty<0>{}}); +} + +template +constexpr bool test_tuple_element() { + return cute::is_same_v, ExpectedElementType>; +} + +void test_tuple_elements() { + using cute::packed_tuple; + + static_assert(test_tuple_element>, 0, Empty<0>>()); + static_assert(test_tuple_element>, 0, Empty<0>>()); +} + +// A default-constructible type. +template +struct DefaultConstructible {}; + +void test_default_constructibility() { + using cute::packed_tuple; + { + [[maybe_unused]] packed_tuple<> t_p_0; + [[maybe_unused]] packed_tuple> t_p_1; + [[maybe_unused]] packed_tuple, DefaultConstructible<1>> t_p_2; + [[maybe_unused]] packed_tuple, int, DefaultConstructible<1>> t_p_3; + } +} + +void test_sizes_and_not_storing_empty_types() { + using cute::packed_tuple; + + [[maybe_unused]] packed_tuple< + int, + pt_test::Empty<0>, + double + > pt{42, pt_test::Empty<0>{}, 1.5}; + static_assert(cute::is_standard_layout_v); + // packed_result_type must only store the packed tuple, + // and not the integer_sequence(s) used to access it. + // The latter can be represented entirely at compile time as types. + struct { int i; double j; } IntDouble; + static_assert(sizeof(pt) == sizeof(IntDouble)); + + EXPECT_EQ(cute::get<0>(pt), 42); + EXPECT_EQ(cute::get<1>(pt), pt_test::Empty<0>{}); + EXPECT_EQ(cute::get<2>(pt), 1.5); + packed_tuple< + pt_test::Empty<0>, + pt_test::Empty<1>, + packed_tuple< + pt_test::Empty<0>, + pt_test::Empty<1>, + packed_tuple, packed_tuple<>> + > + > pt_empty{}; + static_assert(cute::is_empty_v); + static_assert(cute::is_standard_layout_v); + static_assert(sizeof(pt_empty) == 1); + + // Template arguments must be default constructible, + // and packed_tuple itself needs a default constructor. + [[maybe_unused]] packed_tuple< + packed_tuple>, + double, + pt_test::Empty<3>> pt2; + static_assert(cute::is_standard_layout_v); + + // cute::packed_tuple, like the original cute::tuple, does not + // promise to have working CTAD (constructor template argument + // deduction). + [[maybe_unused]] packed_tuple< + packed_tuple>, + pt_test::Empty<1> + > pt3{ + packed_tuple>{42, pt_test::Empty<0>{}}, + pt_test::Empty<1>{} + }; + static_assert(cute::is_standard_layout_v); + static_assert(cute::is_same_v< + cute::tuple_element_t<0, decltype(pt3)>, + packed_tuple>>); + static_assert(cute::is_same_v< + cute::tuple_element_t<1, decltype(pt3)>, + pt_test::Empty<1>>); + static_assert(cute::tuple_size_v> == 2u); + + packed_tuple> pt3_0 = cute::get<0>(pt3); + auto pt3_0_1 = cute::get<1>(pt3_0); + static_assert(cute::is_same_v>); + + EXPECT_EQ(cute::get<0>(cute::get<0>(pt3)), 42); + EXPECT_EQ(cute::get<1>(cute::get<0>(pt3)), pt_test::Empty<0>{}); +} + +} // namespace test + +TEST(CuTe_core, PackedTuple2) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("packed_tuple"); + CUTLASS_TRACE_HOST("-------------------------------"); + + pt_test::test_packed_type_aliases(); + pt_test::test_tuple_elements(); + pt_test::test_default_constructibility(); + pt_test::test_sizes_and_not_storing_empty_types(); +} + +TEST(CuTe_core, PackedTuple2Get) { + using cute::packed_tuple; + using pt_test::Empty; + using pt_test::Nonempty; + + { + using tuple_type = packed_tuple; + tuple_type pt{42}; + static_assert(cute::tuple_size_v == 1u); + static_assert(cute::is_same_v, int>); + EXPECT_EQ(cute::get<0>(pt), 42); + cute::get<0>(pt) = 43; + EXPECT_EQ(cute::get<0>(pt), 43); + } + { + using tuple_type = packed_tuple; + tuple_type const pt{42}; + EXPECT_EQ(cute::get<0>(pt), 42); + static_assert(cute::is_same_v(pt)), int const&>); + } + { + EXPECT_EQ(cute::get<0>(packed_tuple{42}), 42); + } + + { + using tuple_type = packed_tuple>; + tuple_type pt; + static_assert(cute::tuple_size_v == 1u); + static_assert(cute::is_same_v, pt_test::Empty<0>>); + EXPECT_EQ(cute::get<0>(pt), pt_test::Empty<0>{}); + } + { + using tuple_type = packed_tuple>; + tuple_type const pt; + EXPECT_EQ(cute::get<0>(pt), pt_test::Empty<0>{}); + } + { + using tuple_type = packed_tuple>; + EXPECT_EQ(cute::get<0>(tuple_type{}), pt_test::Empty<0>{}); + } + + { + using tuple_type = packed_tuple; + tuple_type pt{1, 2.5}; + static_assert(cute::tuple_size_v == 2u); + static_assert(cute::is_same_v, int>); + static_assert(cute::is_same_v, double>); + EXPECT_EQ(cute::get<0>(pt), 1); + cute::get<0>(pt) = 2; + EXPECT_EQ(cute::get<0>(pt), 2); + EXPECT_EQ(cute::get<1>(pt), 2.5); + cute::get<1>(pt) = 3.5; + EXPECT_EQ(cute::get<1>(pt), 3.5); + } + { + using tuple_type = packed_tuple; + tuple_type const pt{1, 2.5}; + EXPECT_EQ(cute::get<0>(pt), 1); + static_assert(cute::is_same_v(pt)), int const&>); + EXPECT_EQ(cute::get<1>(pt), 2.5); + static_assert(cute::is_same_v(pt)), double const&>); + } + { + using tuple_type = packed_tuple; + EXPECT_EQ(cute::get<0>(tuple_type{1, 2.5}), 1); + EXPECT_EQ(cute::get<1>(tuple_type{1, 2.5}), 2.5); + } + + { + using tuple_type = packed_tuple, double>; + tuple_type pt{Empty<0>{}, 2.5}; + static_assert(cute::tuple_size_v == 2u); + static_assert(cute::is_same_v, Empty<0>>); + static_assert(cute::is_same_v, double>); + EXPECT_EQ(cute::get<0>(pt), Empty<0>{}); + EXPECT_EQ(cute::get<1>(pt), 2.5); + cute::get<1>(pt) = 3.5; + EXPECT_EQ(cute::get<1>(pt), 3.5); + } + { + using tuple_type = packed_tuple, double>; + tuple_type const pt{Empty<0>{}, 2.5}; + EXPECT_EQ(cute::get<0>(pt), Empty<0>{}); + static_assert(cute::is_same_v(pt)), Empty<0>>); + EXPECT_EQ(cute::get<1>(pt), 2.5); + static_assert(cute::is_same_v(pt)), double const&>); + } + { + using tuple_type = packed_tuple, double>; + EXPECT_EQ(cute::get<0>(tuple_type{Empty<0>{}, 2.5}), Empty<0>{}); + EXPECT_EQ(cute::get<1>(tuple_type{Empty<0>{}, 2.5}), 2.5); + } + + { + using tuple_type = packed_tuple>; + tuple_type pt{1, 2.5, Nonempty{3.25f}}; + static_assert(cute::tuple_size_v == 3u); + static_assert(cute::is_same_v, int>); + static_assert(cute::is_same_v, double>); + static_assert(cute::is_same_v, Nonempty>); + EXPECT_EQ(cute::get<0>(pt), 1); + EXPECT_EQ(cute::get<1>(pt), 2.5); + EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); + + cute::get<0>(pt) = 42; + EXPECT_EQ(cute::get<0>(pt), 42); + cute::get<1>(pt) = 4.5; + EXPECT_EQ(cute::get<1>(pt), 4.5); + cute::get<2>(pt) = Nonempty{3.75f}; + EXPECT_EQ(cute::get<2>(pt), Nonempty{3.75f}); + } + { + using tuple_type = packed_tuple>; + tuple_type const pt{1, 2.5, Nonempty{3.25f}}; + EXPECT_EQ(cute::get<0>(pt), 1); + EXPECT_EQ(cute::get<1>(pt), 2.5); + EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); + } + { + using tuple_type = packed_tuple>; + EXPECT_EQ((cute::get<0>(tuple_type{1, 2.5, Nonempty{3.25f}})), 1); + EXPECT_EQ((cute::get<1>(tuple_type{1, 2.5, Nonempty{3.25f}})), 2.5); + EXPECT_EQ((cute::get<2>(tuple_type{1, 2.5, Nonempty{3.25f}})), Nonempty{3.25f}); + } + + { + using tuple_type = packed_tuple, Nonempty>; + packed_tuple, Nonempty> pt{1, Empty<0>{}, Nonempty{3.25f}}; + static_assert(cute::tuple_size_v == 3u); + static_assert(cute::is_same_v, int>); + static_assert(cute::is_same_v, Empty<0>>); + static_assert(cute::is_same_v, Nonempty>); + EXPECT_EQ(cute::get<0>(pt), 1); + EXPECT_EQ(cute::get<1>(pt), Empty<0>{}); + EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); + + cute::get<0>(pt) = 42; + EXPECT_EQ(cute::get<0>(pt), 42); + cute::get<2>(pt) = Nonempty{3.75f}; + EXPECT_EQ(cute::get<2>(pt), Nonempty{3.75f}); + } + { + using tuple_type = packed_tuple, Nonempty>; + tuple_type const pt{1, Empty<0>{}, Nonempty{3.25f}}; + EXPECT_EQ(cute::get<0>(pt), 1); + EXPECT_EQ(cute::get<1>(pt), Empty<0>{}); + EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); + } + { + using tuple_type = packed_tuple, Nonempty>; + EXPECT_EQ((cute::get<0>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), 1); + EXPECT_EQ((cute::get<1>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), Empty<0>{}); + EXPECT_EQ((cute::get<2>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), Nonempty{3.25f}); + } +} + +namespace pt_test { + +// An empty class type to which Empty is convertible. +template +struct ConvertibleFromEmpty { + constexpr ConvertibleFromEmpty() = default; + constexpr ConvertibleFromEmpty(Empty) {} + + template + friend constexpr bool operator==(ConvertibleFromEmpty const&, ConvertibleFromEmpty const&) { + return Value == OtherValue; + } + + template + friend constexpr bool operator!=(ConvertibleFromEmpty const& lhs, ConvertibleFromEmpty const& rhs) { + return !(lhs == rhs); + } +}; + +} // end namespace pt_test + +TEST(CuTe_core, PackedTupleConstexprDefaultConstruction) { + // Make sure that packed_tuple's default constructor is constexpr. + // MSVC makes this a bit more challenging than usual. + + using pt_test::Empty; + { + [[maybe_unused]] constexpr cute::detail::ESO_t> eso1{}; + [[maybe_unused]] constexpr cute::detail::ESO_t eso2{}; + } + { + [[maybe_unused]] constexpr cute::detail::ESO_t, Empty<1>> eso0{}; + [[maybe_unused]] constexpr cute::detail::ESO_t> eso1{}; + [[maybe_unused]] constexpr cute::detail::ESO_t, int64_t> eso2{}; + [[maybe_unused]] constexpr cute::detail::ESO_t eso3{}; + } +} + +TEST(CuTe_core, PackedTupleConvertingConstruction) { + using cute::packed_tuple; + using pt_test::ConvertibleFromEmpty; + using pt_test::Empty; + using pt_test::Nonempty; + + { + using tuple_type = cute::tuple>; + [[maybe_unused]] tuple_type t(7); + EXPECT_EQ(cute::get<0>(t), Nonempty(7)); + } + { + using tuple_type = packed_tuple>; + [[maybe_unused]] tuple_type t(7); + EXPECT_EQ(cute::get<0>(t), Nonempty(7)); + } + { + using tuple_type = cute::tuple>; + [[maybe_unused]] tuple_type t(Empty<0>{}); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + } + { + using tuple_type = packed_tuple>; + [[maybe_unused]] tuple_type t(Empty<0>{}); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + } + + { + using tuple_type = cute::tuple>; + [[maybe_unused]] tuple_type t(1.5f, 7); + EXPECT_EQ(cute::get<0>(t), 1.5f); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + { + using tuple_type = packed_tuple>; + [[maybe_unused]] tuple_type t(1.5f, 7); + EXPECT_EQ(cute::get<0>(t), 1.5f); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + + { + using tuple_type = cute::tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), Empty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + { + using tuple_type = packed_tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), Empty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + + { + using tuple_type = cute::tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + { + using tuple_type = packed_tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + + { + using inner_tuple_type = cute::tuple>; + using outer_tuple_type = cute::tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + { + using inner_tuple_type = packed_tuple>; + using outer_tuple_type = packed_tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + { + using inner_tuple_type = cute::tuple>; + using outer_tuple_type = cute::tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + { + using inner_tuple_type = packed_tuple>; + using outer_tuple_type = packed_tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + +} + + diff --git a/test/unit/cute/core/tuple_find.cpp b/test/unit/cute/core/tuple_find.cpp new file mode 100644 index 0000000000..0eeeb16613 --- /dev/null +++ b/test/unit/cute/core/tuple_find.cpp @@ -0,0 +1,103 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include + +namespace test { + +template +void test_tuple_find(Tuple const& t) { + auto index = cute::find(t); + static_assert(decltype(index)::value == ExpectedIndex); +} + +template class Tuple> +void test_tuple_find_all() { + using test::test_tuple_find; + using cute::_1; + using cute::_2; + using cute::_4; + + test_tuple_find<0, _1>(Tuple<_1>{}); + test_tuple_find<0, int>(Tuple{7}); + + test_tuple_find<0, _1>(Tuple<_1, _2>{}); + test_tuple_find<0, _1>(Tuple<_1, int>{_1{}, 7}); + test_tuple_find<0, float>(Tuple{15.5f, 7}); + test_tuple_find<1, _2>(Tuple<_1, _2>{}); + test_tuple_find<1, int>(Tuple<_1, int>{_1{}, 7}); + test_tuple_find<1, int>(Tuple{15.5f, 7}); + + test_tuple_find<0, _1>(Tuple<_1, _2, _4>{_1{}, _2{}, _4{}}); + test_tuple_find<0, _1>(Tuple<_1, _2, int>{_1{}, _2{}, 7}); + test_tuple_find<0, _1>(Tuple<_1, float, _4>{_1{}, 15.5f, _4{}}); + test_tuple_find<0, _1>(Tuple<_1, float, int>{_1{}, 15.5f, 7}); + test_tuple_find<0, double>(Tuple{105.5, _2{}, _4{}}); + test_tuple_find<0, double>(Tuple{105.5, 15.5f, _4{}}); + test_tuple_find<0, double>(Tuple{105.5, 15.5f, 7}); + + test_tuple_find<1, _2>(Tuple<_1, _2, _4>{_1{}, _2{}, _4{}}); + test_tuple_find<1, _2>(Tuple<_1, _2, int>{_1{}, _2{}, 7}); + test_tuple_find<1, float>(Tuple<_1, float, _4>{_1{}, 15.5f, _4{}}); + test_tuple_find<1, float>(Tuple<_1, float, int>{_1{}, 15.5f, 7}); + test_tuple_find<1, _2>(Tuple{105.5, _2{}, _4{}}); + test_tuple_find<1, float>(Tuple{105.5, 15.5f, _4{}}); + test_tuple_find<1, float>(Tuple{105.5, 15.5f, 7}); + + test_tuple_find<2, _4>(Tuple<_1, _2, _4>{_1{}, _2{}, _4{}}); + test_tuple_find<2, int>(Tuple<_1, _2, int>{_1{}, _2{}, 7}); + test_tuple_find<2, _4>(Tuple<_1, float, _4>{_1{}, 15.5f, _4{}}); + test_tuple_find<2, int>(Tuple<_1, float, int>{_1{}, 15.5f, 7}); + test_tuple_find<2, _4>(Tuple{105.5, _2{}, _4{}}); + test_tuple_find<2, _4>(Tuple{105.5, 15.5f, _4{}}); + test_tuple_find<2, int>(Tuple{105.5, 15.5f, 7}); +} + +} // end namespace test + + +TEST(CuTe_core, TupleFind) +{ + test::test_tuple_find_all(); +} + +// If cute::tuple is not simply an alias for cute::packed_tuple, +// then test cute::packed_tuple separately. +#if ! defined(CUTLASS_USE_PACKED_TUPLE) +TEST(CuTe_core, PackedTupleFind) +{ + test::test_tuple_find_all(); +} +#endif // CUTLASS_USE_PACKED_TUPLE diff --git a/test/unit/cute/hopper/CMakeLists.txt b/test/unit/cute/hopper/CMakeLists.txt index 0b6db66f22..f77aad93c0 100644 --- a/test/unit/cute/hopper/CMakeLists.txt +++ b/test/unit/cute/hopper/CMakeLists.txt @@ -29,6 +29,7 @@ add_custom_target( cutlass_test_unit_cute_hopper DEPENDS + cutlass_test_unit_cute_hopper_cooperative_gemm cutlass_test_unit_cute_hopper_stsm cutlass_test_unit_cute_hopper_tma_load cutlass_test_unit_cute_hopper_tma_store @@ -46,6 +47,11 @@ add_custom_target( test_unit_cute_hopper_bulk_store ) +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_cooperative_gemm + cooperative_gemm.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_cute_hopper_stsm stsm.cu diff --git a/test/unit/cute/hopper/cooperative_gemm.cu b/test/unit/cute/hopper/cooperative_gemm.cu new file mode 100644 index 0000000000..bab7122eda --- /dev/null +++ b/test/unit/cute/hopper/cooperative_gemm.cu @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +#define USE_FP8 1 + +#if USE_FP8 +TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF8) { + + using TA = uint8_t; + using TB = uint8_t; + using TC = uint32_t; + + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 16; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _2, _0>>, + Tile<_32, _32, _32> + >; + + using swizzle = Swizzle<2, 4, 3>; + + // This is for A row major, B col major according to CUTLASS default configs + using ALayout = decltype(composition(swizzle{}, Layout, Stride<_64, _1>>{})); + using BLayout = decltype(composition(swizzle{}, Layout, Stride<_1, _64>>{})); + + using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{})); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment, // B + AutoVectorizingCopyWithAssumedAlignment, // C + thread_block_size, + tiled_mma_t, + MaxVecBits, + TA, + TB, + TC>(); + +} + +#else + +TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF16) { + + using TA = half_t; + using TB = half_t; + using TC = half_t; + + constexpr uint32_t thread_block_size = 64; + constexpr int MaxVecBits = 16; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _0, _0>>, + Tile<_32, _32, _32> + >; + + using swizzle = Swizzle<3, 3, 3>; + + // This is for A row major, B col major according to CUTLASS default configs + using ALayout = decltype(composition(swizzle{}, + Layout, Stride<_64, _1>>{})); + + using BLayout = decltype(composition(swizzle{}, + Layout, Stride<_1, _64>>{})); + + using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{})); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment, // B + AutoVectorizingCopyWithAssumedAlignment, // C + thread_block_size, + tiled_mma_t, + MaxVecBits, + TA, + TB, + TC>(); + +} + +#endif diff --git a/test/unit/cute/hopper/tma_load_testbed.hpp b/test/unit/cute/hopper/tma_load_testbed.hpp index fd1fe6e7f1..0c8ed91d69 100644 --- a/test/unit/cute/hopper/tma_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_load_testbed.hpp @@ -122,6 +122,11 @@ tma_test_device_cute(T const* g_in, T* g_out, } #endif + // Test L2 prefetch + if (threadIdx.x == 0) { + prefetch(tma, tAgA); + } + // Loop over the TMA stages, using smem as our buffer for (int stage = 0; stage < size<1>(tAgA); ++stage) { diff --git a/test/unit/cute/hopper/tma_store_testbed.hpp b/test/unit/cute/hopper/tma_store_testbed.hpp index 08fbcba015..ebdec55abe 100644 --- a/test/unit/cute/hopper/tma_store_testbed.hpp +++ b/test/unit/cute/hopper/tma_store_testbed.hpp @@ -117,6 +117,9 @@ tma_test_device_cute(T const* g_in, T* g_out, } #endif + // Test L2 prefetch + cooperative_prefetch<128>(threadIdx.x, gA); + // Loop over the TMA stages, using smem as our buffer for (int stage = 0; stage < size<1>(tBgB); ++stage) { diff --git a/test/unit/cute/msvc_compilation/tuple.cpp b/test/unit/cute/msvc_compilation/tuple.cpp index 7cf843062a..a8a31dd3ca 100644 --- a/test/unit/cute/msvc_compilation/tuple.cpp +++ b/test/unit/cute/msvc_compilation/tuple.cpp @@ -53,6 +53,8 @@ class ConvertibleTo { template using IC = std::integral_constant; +#if ! defined(CUTLASS_USE_PACKED_TUPLE) + TEST(CuTe_core_msvc_compilation, TupleAssignment) { CUTLASS_TRACE_HOST("-------------------------------"); @@ -89,29 +91,22 @@ TEST(CuTe_core_msvc_compilation, TupleAssignment) using tuple_0d_type = cute::tuple<>; using tuple_1d_d_type = cute::tuple; - using tuple_1d_s_type = cute::tuple; using tuple_2d_dd_type = cute::tuple; - using tuple_2d_ss_type = cute::tuple; [[maybe_unused]] tuple_0d_type t0; // Symptom: "illegal member initialization: 'TupleBase' is not a base or member" [[maybe_unused]] tuple_1d_d_type t1{ 42 }; - - [[maybe_unused]] tuple_1d_s_type t2; - [[maybe_unused]] tuple_1d_d_type t1a{ 43 }; t1 = t1a; [[maybe_unused]] tuple_2d_dd_type t3{ 42, size_t(43u) }; - [[maybe_unused]] tuple_2d_ss_type t4; - t3 = t4; - [[maybe_unused]] tuple_2d_dd_type t3a{ 44, size_t(45u) }; // Symptom: "illegal member initialization: // 'TupleBase' is not a base or member" t3 = t3a; } +#endif // CUTLASS_USE_PACKED_TUPLE TEST(CuTe_core_msvc_compilation, TupleGetSingleInteger) { diff --git a/test/unit/cute/volta/CMakeLists.txt b/test/unit/cute/volta/CMakeLists.txt index d6688aa30d..27ebcb29fd 100644 --- a/test/unit/cute/volta/CMakeLists.txt +++ b/test/unit/cute/volta/CMakeLists.txt @@ -29,6 +29,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_volta vectorization_auto.cu - cooperative_copy.cu cooperative_gemm.cu ) diff --git a/test/unit/cute/volta/cooperative_copy.cu b/test/unit/cute/volta/cooperative_copy.cu deleted file mode 100644 index 2fc80b366a..0000000000 --- a/test/unit/cute/volta/cooperative_copy.cu +++ /dev/null @@ -1,486 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "cutlass_unit_test.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include - -using namespace cute; - -namespace cooperative_copy_mode { - struct global_shared {}; - struct global_global {}; - struct shared_shared {}; -} - -// gs --> global to/from shared -template -__device__ void -cooperative_copy_default_gs(T const* g_in, T* g_out) -{ - using namespace cute; - extern __shared__ float4 smem_buf[]; - // Cast smem_buf to smem_uint8_ptr and move it by MaxVecBits bits - // This is to make sure tests pass on pointer aligned to MaxVecBits bits - uint8_t* smem_uint8_ptr = reinterpret_cast(smem_buf) + (MaxVecBits/8); - T* smem = reinterpret_cast(smem_uint8_ptr); - - Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), GMemLayout{}); - Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), GMemLayout{}); - Tensor s_tensor = make_tensor(make_smem_ptr(smem), SMemLayout{}); - - cooperative_copy(threadIdx.x, g_in_tensor, s_tensor); - __syncthreads(); - - if(thread0()) { - for(int i = 0; i < size(s_tensor); ++i) { - s_tensor(i) += T(i); - } - } - __syncthreads(); - - cooperative_copy(threadIdx.x, s_tensor, g_out_tensor); -} - -// ss --> shared to shared -template -__device__ void -cooperative_copy_default_ss(T const* g_in, T* g_out) -{ - using namespace cute; - extern __shared__ float4 smem_buf[]; - // Cast smem_buf to smem_uint8_ptr and move it by MaxVecBits bits - // This is to make sure tests pass on pointer aligned to MaxVecBits bits - T* smem1 = reinterpret_cast(smem_buf); - uint8_t* smem2_uint8_ptr = reinterpret_cast(smem_buf) + (MaxVecBits/8); - T* smem2 = reinterpret_cast(smem2_uint8_ptr) + cute::cosize(Layout2{}); - - Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), Layout1 {}); - Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), Layout2 {}); - - Tensor s1_tensor = make_tensor(make_smem_ptr(smem1), Layout2 {}); - Tensor s2_tensor = make_tensor(make_smem_ptr(smem2), Layout1 {}); - - cooperative_copy>(threadIdx.x, g_in_tensor, s1_tensor); - __syncthreads(); - - if(thread0()) { - for(int i = 0; i < size(s1_tensor); ++i) { - s1_tensor(i) += T(i); - } - } - __syncthreads(); - - cooperative_copy(threadIdx.x, s1_tensor, s2_tensor); - __syncthreads(); - - cooperative_copy>(threadIdx.x, s2_tensor, g_out_tensor); -} - -// gg --> global to global -template -__device__ void -cooperative_copy_default_gg(T const* g_in, T* g_out) -{ - using namespace cute; - - Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), Layout1{}); - Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), Layout2{}); - - cooperative_copy(threadIdx.x, g_in_tensor, g_out_tensor); -} - -template -__global__ void -cooperative_copy_default_kernel(T const* g_in, T* g_out) -{ - if constexpr(std::is_same_v) { - cooperative_copy_default_gs(g_in, g_out); - } else if constexpr (std::is_same_v) { - cooperative_copy_default_gg(g_in, g_out); - } else if constexpr (std::is_same_v) { - cooperative_copy_default_ss(g_in, g_out); - } -} - -// Mode - defines memory types of src and dst in cooperative_copy operation -// MaxVecBits - defines max vectorization in cooperative_copy operation, and enforces that -// alignment on used pointers to ensure correct testing -template -void test_cooperative_copy_default() -{ - using value_type = T; - static_assert(cute::size(Layout1{}) == cute::size(Layout2{})); - - using gmem_layout_in = Layout1; - using gmem_layout_out = std::conditional_t, Layout1, Layout2>; - -#if 0 - print(" "); print("layout1: "); print(Layout1{}); print("\n"); - print(" "); print("layout2: "); print(Layout2{}); print("\n"); - print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); -#endif - - if constexpr (MaxVecBits < cute::sizeof_bits_v) { - GTEST_SKIP() << "Skipping test since MaxVecBits (=" << MaxVecBits - << ") < cute::sizeof_bits_v (=" << cute::sizeof_bits_v << ")"; - } else { - constexpr auto max_vec_bytes = MaxVecBits / 8; - static_assert((max_vec_bytes % sizeof(T)) == 0); - - constexpr uint32_t count = cute::cosize(gmem_layout_in {}); - // Extra elements to force MaxVecBits alignment in global memory - constexpr uint32_t extra_elements = max_vec_bytes / sizeof(value_type); - - // Allocate - thrust::host_vector h_in(count + extra_elements); - thrust::host_vector h_out(count + extra_elements); - - // Initialize - Tensor h_in_tensor = make_tensor((h_in.data() + extra_elements), gmem_layout_in {}); - Tensor h_out_tensor = make_tensor((h_out.data() + extra_elements), gmem_layout_out {}); - for (int i = 0; i < cute::size(h_in_tensor); ++i) { - h_in_tensor(i) = value_type(float(i)); - // For global-to-global copy need to compare against the same value - h_out_tensor(i) = std::is_same_v ? value_type(float(i)) : value_type(float(2 * i)); - } - - // To GPU - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(d_in.size(), value_type(float(-2))); - - // Adds (MaxVecBits/8) bytes to shared memory as we'll move pointer by that many bytes inside the kernel to enforce - // alignment to (MaxVecBits/8) bytes - size_t shared_memory_bytes = (sizeof(value_type) * count) + max_vec_bytes; - shared_memory_bytes += std::is_same_v * (sizeof(value_type) * count); - - // Launch - auto coop_copy = cooperative_copy_default_kernel; - ASSERT_EQ(cudaFuncSetAttribute(coop_copy, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_bytes)), cudaSuccess); - - auto d_in_ptr = thrust::raw_pointer_cast(d_in.data() + extra_elements); - auto d_out_ptr = thrust::raw_pointer_cast(d_out.data() + extra_elements); - coop_copy<<<1, ThreadBlockSize, shared_memory_bytes>>>(d_in_ptr, d_out_ptr); - - cudaError_t result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - cudaError_t error = cudaGetLastError(); - FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; - } - - // Validate - thrust::host_vector h_result = d_out; - Tensor h_result_tensor = make_tensor((h_result.data() + extra_elements), gmem_layout_out {}); - for (int i = 0; i < cute::size(h_in_tensor); ++i) { - ASSERT_EQ(h_result_tensor(i), h_out_tensor(i)) - << i << " - result:" << h_result_tensor(i) << " expected:" << h_out_tensor(i); - } - } -} - -template -class SM70_CuTe_Volta; - -template -class SM70_CuTe_Volta>: public testing::Test -{ -public: - using mode = Mode; - static constexpr int max_vec_bits = MaxVecBits::value; -}; - -typedef testing::Types< - std::tuple>, - std::tuple>, - std::tuple>, - std::tuple>, - - std::tuple>, - std::tuple>, - std::tuple>, - std::tuple>, - - std::tuple>, - std::tuple>, - std::tuple>, - std::tuple>, -> CooperativeCopyModeMaxVecBitsList; - -TYPED_TEST_SUITE(SM70_CuTe_Volta, CooperativeCopyModeMaxVecBitsList); - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefault1D) -{ - using value_type = float; - constexpr uint32_t count = 512; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}))); - constexpr uint32_t thread_block_size = 64; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefault1DFallback) -{ - using value_type = float; - constexpr uint32_t count = 99; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}))); - constexpr uint32_t thread_block_size = 128; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2D) -{ - using value_type = float; - constexpr uint32_t x = 32; - constexpr uint32_t y = 32; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - constexpr uint32_t thread_block_size = 64; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2DFallback) -{ - using value_type = float; - constexpr uint32_t x = 37; - constexpr uint32_t y = 37; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - constexpr uint32_t thread_block_size = 64; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2DCustomStride) -{ - using value_type = float; - constexpr uint32_t x = 16; - constexpr uint32_t y = 16; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), make_stride(Int{}, Int<1>{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), make_stride(Int<1>{}, Int{}))); - constexpr uint32_t thread_block_size = 64; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG3D) -{ - using value_type = cute::half_t; - constexpr uint32_t x = 8; - constexpr uint32_t y = 8; - constexpr uint32_t z = 16; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); - constexpr uint32_t thread_block_size = 64; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG3DFallback) -{ - using value_type = cute::half_t; - constexpr uint32_t x = 44; - constexpr uint32_t y = 24; - constexpr uint32_t z = 14; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); - constexpr uint32_t thread_block_size = 128; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2Dto3D) -{ - using value_type = double; - constexpr uint32_t x = 16; - constexpr uint32_t y = 16; - constexpr uint32_t z = 4; - using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); - constexpr uint32_t thread_block_size = 64; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSGCustom1) -{ - using value_type = double; - using gmem_layout_t = decltype(make_layout( - make_shape(Int<8>{}, make_shape(Int<2>{}, Int<2>{})), - make_stride(Int<2>{}, make_shape(Int<1>{}, Int<16>{})) - )); - using smem_layout_t = decltype(make_layout( - make_shape(Int<8>{}, Int<4>{}), - make_stride(Int<4>{}, Int<1>{}) - )); - constexpr uint32_t thread_block_size = 8; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSGCustom2) -{ - using value_type = float; - using gmem_layout_t = decltype(make_layout( - make_shape(make_shape(Int<4>{}, Int<2>{}), make_shape(Int<2>{}, Int<2>{})), - make_stride(make_shape(Int<4>{}, Int<1>{}), make_shape(Int<16>{}, Int<2>{})) - )); - using smem_layout_t = decltype(make_layout( - make_shape(make_shape(Int<2>{}, Int<2>{}, Int<2>{}), make_shape(Int<2>{}, Int<2>{})), - make_stride(make_shape(Int<16>{}, Int<4>{}, Int<1>{}), make_shape(Int<8>{}, Int<2>{})) - )); - constexpr uint32_t thread_block_size = 16; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSGSwizzle1) -{ - using value_type = float; - using gmem_layout_t = Layout, Stride<_64, _1>>; - using smem_layout_t = decltype(composition(Swizzle<3, 3, 3>{}, Layout, Stride<_64, _1>>{})); - constexpr uint32_t thread_block_size = 128; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSGSwizzle2) -{ - using value_type = cute::half_t; - using gmem_layout_t = decltype(make_layout(make_shape(Int<64>{}, Int<64>{}))); - using smem_atom_layout_t = decltype(composition(Swizzle<3, 2, 3> {}, Layout, Stride<_32, _1>>{})); - using smem_layout_t = decltype(tile_to_shape( - smem_atom_layout_t{}, - make_shape(shape<0>(gmem_layout_t{}), shape<1>(gmem_layout_t{}))) - ); - constexpr uint32_t thread_block_size = 128; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSGSwizzle3) -{ - using value_type = cute::half_t; - using gmem_layout_t = decltype(make_layout(make_shape(Int<64>{}, Int<64>{}))); - using smem_atom_layout_t = decltype(composition(Swizzle<2, 4, 3> {}, Layout, Stride<_64, _1>>{})); - using smem_layout_t = decltype(tile_to_shape( - smem_atom_layout_t{}, - make_shape(shape<0>(gmem_layout_t{}), shape<1>(gmem_layout_t{}))) - ); - constexpr uint32_t thread_block_size = 128; - test_cooperative_copy_default(); -} - -TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSGSwizzle4) -{ - using value_type = cute::half_t; - using gmem_atom_layout_t = decltype(composition(Swizzle<3, 2, 3> {}, Layout, Stride<_32, _1>>{})); - using smem_layout_t = decltype(make_layout(make_shape(Int<64>{}, Int<64>{}))); - using gmem_layout_t = decltype(tile_to_shape( - gmem_atom_layout_t{}, - make_shape(shape<0>(smem_layout_t{}), shape<1>(smem_layout_t{}))) - ); - constexpr uint32_t thread_block_size = 128; - test_cooperative_copy_default(); -} diff --git a/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu b/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu index 1dd0e77735..8be67942ea 100644 --- a/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu +++ b/test/unit/epilogue/warp/fragment_iterator_tensor_op.cu @@ -78,7 +78,7 @@ TEST(SM75_Epilogue_warp_FragmentIterator, mma_f32_64x64x8) { std::cout << "Native accumulators:\n"; - for (int i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { + for (size_t i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { accum[i] = ElementC(i); std::cout << accum[i] << " "; @@ -106,7 +106,7 @@ TEST(SM75_Epilogue_warp_FragmentIterator, mma_f32_64x64x8) { std::cout << "Iteration " << iter << ":\n"; - for (int i = 0; i < FragmentIterator::Fragment::kElements; ++i) { + for (size_t i = 0; i < FragmentIterator::Fragment::kElements; ++i) { std::cout << frag[i] << " "; } @@ -153,8 +153,83 @@ TEST(SM75_Epilogue_warp_FragmentIterator, mma_f16_64x64x8) { std::cout << "Native accumulators:\n"; - for (int i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { - accum[i] = ElementC(i); + for (size_t i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { + accum[i] = ElementC((int)i); + + std::cout << (float)accum[i] << " "; + if (i && !((i + 1) % 4)) { + std::cout << "\n"; + } + } + + std::cout << std::endl; + + std::cout << "FragmentIterator::Policy = { \n" + << " kAccessesPerInstruction: " << FragmentIterator::Policy::kIterationsPerInstruction << "\n" + << " kAccumulatorRowStride: " << FragmentIterator::Policy::kAccumulatorRowStride << "\n" + << " kAccumulatorColumnStride: " << FragmentIterator::Policy::kAccumulatorColumnStride << "\n" + << " kIterations: " << FragmentIterator::Policy::kIterations << "\n" + << " }" << std::endl; + + FragmentIterator fragment_iterator(accum); + + for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) { + + typename FragmentIterator::Fragment frag; + + fragment_iterator.load(frag); + + std::cout << "Iteration " << iter << ":\n"; + + for (size_t i = 0; i < FragmentIterator::Fragment::kElements; ++i) { + std::cout << (float)frag[i] << " "; + } + + std::cout << std::endl; + + ++fragment_iterator; + } + #endif +} + +TEST(SM75_Epilogue_warp_FragmentIterator_column, mma_f32_64x64x8) { + + using Shape = cutlass::gemm::GemmShape<64, 64, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + Shape, + typename MmaTensorOp::Policy::Operator::Shape, + typename MmaTensorOp::Policy::Operator::ElementC, + typename MmaTensorOp::Policy::Operator::FragmentC, + cutlass::layout::ColumnMajor + >; + + // This test just prints things. + #if 0 + typename MmaTensorOp::FragmentC accum; + + std::cout << "Native accumulators:\n"; + + for (size_t i = 0; i < MmaTensorOp::FragmentC::kElements; ++i) { + accum[i] = ElementC((int)i); std::cout << (float)accum[i] << " "; if (i && !((i + 1) % 4)) { @@ -181,7 +256,7 @@ TEST(SM75_Epilogue_warp_FragmentIterator, mma_f16_64x64x8) { std::cout << "Iteration " << iter << ":\n"; - for (int i = 0; i < FragmentIterator::Fragment::kElements; ++i) { + for (size_t i = 0; i < FragmentIterator::Fragment::kElements; ++i) { std::cout << (float)frag[i] << " "; } @@ -191,4 +266,5 @@ TEST(SM75_Epilogue_warp_FragmentIterator, mma_f16_64x64x8) { } #endif } + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 072d4f0c15..ce7b02606d 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -310,6 +310,18 @@ cutlass_test_unit_add_executable( sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu ) +# Ptr Array test +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array + sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu +) + +# Group Gemm test +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm + sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu +) + # Fused epilogue tests cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_sm90 @@ -348,7 +360,6 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu ) - cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90 @@ -508,7 +519,7 @@ cutlass_test_unit_add_executable( gemm_f8t_f8n_f32t_tensor_op_f32_sm89.cu gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu gemm_f8t_f8n_f8t_tensor_op_f32_sm89.cu - gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu +# gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu b/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu index b5f8533ffa..02f34628ef 100644 --- a/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu @@ -162,7 +162,7 @@ TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { - using Element = cutlass::complex;; + using Element = cutlass::complex; using Gemm = cutlass::gemm::device::GemmComplex< Element, @@ -194,7 +194,7 @@ TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { - using Element = cutlass::complex;; + using Element = cutlass::complex; using Gemm = cutlass::gemm::device::GemmComplex< Element, diff --git a/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu b/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu index cb63a8ba6c..be739dbad2 100644 --- a/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu @@ -161,7 +161,7 @@ TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x64x16_32x32x16) { TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) { - using Element = cutlass::complex;; + using Element = cutlass::complex; using Gemm = cutlass::gemm::device::GemmComplex< Element, @@ -193,7 +193,7 @@ TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x64x16_64x32x16) TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) { - using Element = cutlass::complex;; + using Element = cutlass::complex; using Gemm = cutlass::gemm::device::GemmComplex< Element, @@ -225,7 +225,7 @@ TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 64x128x16_32x64x16) TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32, 128x128x16_32x64x16) { - using Element = cutlass::complex;; + using Element = cutlass::complex; using Gemm = cutlass::gemm::device::GemmComplex< Element, diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu index d22d53eece..94b46326ac 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu @@ -79,7 +79,7 @@ TEST(SM80_Device_GemmUniversal_DirectStore_f16n_f16t_f32n_tensor_op_f32, 128x128 cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, - 4, // This is the vector size of the epilogue. + 4, ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu index 3857521b8f..0764a012cf 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -72,7 +72,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -91,7 +91,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -109,7 +109,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -127,7 +127,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -145,7 +145,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -163,7 +163,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -199,7 +199,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x128_64x64x128 using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -217,7 +217,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x128_64x64x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -235,7 +235,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x128_64x32x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -253,7 +253,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu index 6218b32b79..8c652f38d2 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -73,7 +73,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -91,7 +91,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -109,7 +109,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -127,7 +127,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -145,7 +145,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -164,7 +164,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -182,7 +182,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -200,7 +200,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128 using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -218,7 +218,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -236,7 +236,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -254,7 +254,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu index 78ca22d360..b05e95d22f 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -72,7 +72,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -90,7 +90,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -108,7 +108,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -126,7 +126,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -144,7 +144,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -163,7 +163,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -199,7 +199,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x128_64x64x128 using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -217,7 +217,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x128_64x64x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -235,7 +235,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x128_64x32x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -253,7 +253,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x128_32x32x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -266,6 +266,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x128_32x32x128) EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } + //////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu index ad3f59c888..f2ddaa5a94 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -72,7 +72,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -90,7 +90,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -108,7 +108,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -126,7 +126,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -144,7 +144,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -162,7 +162,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -180,7 +180,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -198,7 +198,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x128_64x64x128 using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -216,7 +216,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x128_64x64x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -234,7 +234,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x128_64x32x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -252,7 +252,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x128_32x32x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -266,6 +266,205 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x128_32x32x128) EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x32x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x32x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x32x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x32x128_64x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x128x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x256x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} +#endif + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x256x128_32x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x16x64_32x16x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 64>, + cutlass::gemm::GemmShape<32, 16, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x16x128_32x16x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 128>, + cutlass::gemm::GemmShape<32, 16, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x16x64_64x16x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 64>, + cutlass::gemm::GemmShape<64, 16, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x16x128_64x16x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 128>, + cutlass::gemm::GemmShape<64, 16, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu index 877e856f89..18871125d3 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -72,7 +72,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -90,7 +90,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -109,7 +109,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -127,7 +127,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -145,7 +145,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -163,7 +163,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -199,7 +199,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x128_64x64x128 using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -218,7 +218,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x128_64x64x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -236,7 +236,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x128_64x32x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -254,7 +254,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x128_32x32x128) using ElementOutput = cutlass::half_t; using ElementAccumulator = cutlass::half_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu index 1a15f43cdb..3b1e85e750 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -72,7 +72,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -90,7 +90,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -108,7 +108,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -126,7 +126,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -144,7 +144,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -162,7 +162,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -180,7 +180,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -198,7 +198,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128 using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -216,7 +216,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -234,7 +234,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -252,7 +252,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -266,6 +266,206 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x128x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x256x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} +#endif + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x256x128_32x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x32x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x32x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x32x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x32x128_64x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x16x64_32x16x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 64>, + cutlass::gemm::GemmShape<32, 16, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x16x128_32x16x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 128>, + cutlass::gemm::GemmShape<32, 16, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x16x64_64x16x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 64>, + cutlass::gemm::GemmShape<64, 16, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x16x128_64x16x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 128>, + cutlass::gemm::GemmShape<64, 16, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu index 0cd439cc87..0da78ec4e3 100644 --- a/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -54,7 +54,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -72,7 +72,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -90,7 +90,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -108,7 +108,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -126,7 +126,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -144,7 +144,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -162,7 +162,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -180,7 +180,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, diff --git a/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu index ce11a2590b..3b05ae1857 100644 --- a/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu @@ -35,7 +35,7 @@ #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "../../common/cutlass_unit_test.h" @@ -57,7 +57,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -88,7 +88,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -119,7 +119,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -150,7 +150,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -212,7 +212,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -243,7 +243,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -274,7 +274,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -305,7 +305,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -336,7 +336,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -367,7 +367,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -398,7 +398,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, diff --git a/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu index 6b811413c4..8ce0f138f2 100644 --- a/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu @@ -35,7 +35,7 @@ #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "../../common/cutlass_unit_test.h" @@ -57,7 +57,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -88,7 +88,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -119,7 +119,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -150,7 +150,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -212,7 +212,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -243,7 +243,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -274,7 +274,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -305,7 +305,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -336,7 +336,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -367,7 +367,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -398,7 +398,7 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::ColumnMajor, float, @@ -424,6 +424,100 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x128x64_32x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x256x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} +#endif + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x256x64_32x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu index cdb8cb411c..382b0b2261 100644 --- a/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu @@ -35,7 +35,7 @@ #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "../../common/cutlass_unit_test.h" @@ -57,7 +57,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -88,7 +88,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -119,7 +119,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -150,7 +150,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -212,7 +212,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -243,7 +243,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -274,7 +274,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -305,7 +305,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -336,7 +336,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -367,7 +367,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -398,7 +398,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -423,6 +423,350 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x128x64_32x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900) +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x256x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} +#endif + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x256x64_32x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x32x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x32x64_32x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x32x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x32x64_64x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x16x32_32x16x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 32>, + cutlass::gemm::GemmShape<32, 16, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x16x64_32x16x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 64>, + cutlass::gemm::GemmShape<32, 16, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x16x32_64x16x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 32>, + cutlass::gemm::GemmShape<64, 16, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x16x64_64x16x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 64>, + cutlass::gemm::GemmShape<64, 16, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// #endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu index e75e2de58c..3988e6e4e0 100644 --- a/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu @@ -35,7 +35,7 @@ #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "../../common/cutlass_unit_test.h" @@ -57,7 +57,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -88,7 +88,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -119,7 +119,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -150,7 +150,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -181,7 +181,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -212,7 +212,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -243,7 +243,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -274,7 +274,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -305,7 +305,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -336,7 +336,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -367,7 +367,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, @@ -398,7 +398,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { using ElementOutput = float; using ElementAccumulator = float; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< float, cutlass::layout::RowMajor, float, diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu index c347134095..d21020f235 100644 --- a/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu +++ b/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu @@ -41,7 +41,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -65,7 +65,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_f32t_tensor_op_f32, 128x128x128_64x64 using LayoutC = cutlass::layout::RowMajor; static int const kStages = 3; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -89,7 +89,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe5m2n_f32t_tensor_op_f32, 128x128x128_64x64 using LayoutC = cutlass::layout::RowMajor; static int const kStages = 3; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -113,7 +113,7 @@ TEST(SM89_Device_Sparse_Gemm_fe5m2t_fe4m3n_f32t_tensor_op_f32, 128x128x128_64x64 using LayoutC = cutlass::layout::RowMajor; static int const kStages = 3; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -137,7 +137,7 @@ TEST(SM89_Device_Sparse_Gemm_fe5m2t_fe5m2n_f32t_tensor_op_f32, 128x128x128_64x64 using LayoutC = cutlass::layout::RowMajor; static int const kStages = 3; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu index dc50c8f2d5..0733bc7097 100644 --- a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu +++ b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu @@ -43,7 +43,7 @@ #include "cutlass/cutlass.h" #include "cutlass/epilogue/thread/activation.h" #include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h" -#include "cutlass/gemm/device/gemm_sparse_with_absmax.h" +#include "cutlass/gemm/device/gemm_sparse_universal_with_absmax.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -78,7 +78,41 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_128x12 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_f32t_tensor_op_f32, identity_128x128x128_64x64x128) { + // Test with float D and Aux for testing split-K without needing relative equality checks + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = float; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -112,12 +146,12 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_fastac ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages, - kAlignment, kAlignment, false, cutlass::arch::OpMultiplyAddFastAccum + kAlignment, kAlignment, cutlass::arch::OpMultiplyAddFastAccum >; bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); @@ -146,7 +180,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, relu_128x128x12 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -179,7 +213,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe5m2n_fe4m3t_tensor_op_f32, identity_128x12 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -212,7 +246,7 @@ TEST(SM89_Device_Sparse_Gemm_fe5m2t_fe4m3n_fe4m3t_tensor_op_f32, identity_128x12 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -245,7 +279,7 @@ TEST(SM89_Device_Sparse_Gemm_fe5m2t_fe5m2n_fe4m3t_tensor_op_f32, identity_128x12 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -278,7 +312,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe5m2t_tensor_op_f32, identity_128x12 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -311,7 +345,7 @@ TEST(SM89_Device_Sparse_Gemm_fe5m2t_fe5m2n_fe5m2t_tensor_op_f32, identity_diff_a ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -344,7 +378,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_128x64 ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -377,7 +411,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_noScal ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, @@ -414,7 +448,7 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_noAux_ ElementAccumulator >; - using Gemm = cutlass::gemm::device::SparseGemmWithAbsmax< + using Gemm = cutlass::gemm::device::GemmSparseUniversalWithAbsmax< ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, cutlass::gemm::GemmShape<128, 128, 128>, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu index 7dbc3cd17b..bea3e946b8 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -55,7 +55,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -74,7 +74,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -93,7 +93,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -113,7 +113,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -132,7 +132,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -151,7 +151,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -170,7 +170,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -189,7 +189,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -208,7 +208,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x512_64x64x512) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -228,7 +228,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x512_64x32x512) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -247,7 +247,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x512_32x32x512) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -261,7 +261,272 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x512_32x32x512) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x128x512_32x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 512>, + cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x128x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x256x256_32x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 256>, + cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x256x512_32x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 512>, + cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 16x128x512_16x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 128, 512>, + cutlass::gemm::GemmShape<16, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 16x256x512_16x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 256, 512>, + cutlass::gemm::GemmShape<16, 64, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x32x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x32x512_32x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 512>, + cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x32x256_64x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 256>, + cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x32x512_64x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 512>, + cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x16x256_32x16x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 256>, + cutlass::gemm::GemmShape<32, 16, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x16x512_32x16x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 512>, + cutlass::gemm::GemmShape<32, 16, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x16x256_16x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 256>, + cutlass::gemm::GemmShape<64, 16, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x16x512_16x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 512>, + cutlass::gemm::GemmShape<64, 16, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + //////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) - diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu index da921e173b..4cb879b635 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu @@ -36,7 +36,7 @@ #include "../../common/cutlass_unit_test.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/util/reference/host/tensor_compare.h" @@ -55,7 +55,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -74,7 +74,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -93,7 +93,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -113,7 +113,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -132,7 +132,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -151,7 +151,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -170,7 +170,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -189,7 +189,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -208,7 +208,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x256_64x64x256) using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -228,7 +228,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x256_64x32x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -247,7 +247,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x256_32x32x256) { using ElementAccumulator = int32_t; using ElementCompute = int32_t; - using Gemm = cutlass::gemm::device::SparseGemm< + using Gemm = cutlass::gemm::device::GemmSparseUniversal< int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, @@ -261,9 +261,273 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x256_32x32x256) { EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); } +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x128x128_32x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x128x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 128, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x256x128_32x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x256x256_32x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 256, 256>, + cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 16x128x256_16x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 128, 256>, + cutlass::gemm::GemmShape<16, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 16x128x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 128, 256>, + cutlass::gemm::GemmShape<16, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x32x128_32x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x32x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x32x128_64x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x32x256_64x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 32, 256>, + cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x16x128_32x16x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 128>, + cutlass::gemm::GemmShape<32, 16, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x16x256_32x16x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 256>, + cutlass::gemm::GemmShape<32, 16, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x16x128_64x16x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 128>, + cutlass::gemm::GemmShape<64, 16, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x16x256_64x16x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::GemmSparseUniversal< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 16, 256>, + cutlass::gemm::GemmShape<64, 16, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} //////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) - diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index b0225b812e..43ab224dc0 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -346,7 +346,7 @@ struct HostCollectiveMainloop { tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); - + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); @@ -710,7 +710,7 @@ struct HostCollectiveEpilogue { using ActivationFunctor = non_void_t>; - static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; @@ -813,6 +813,7 @@ struct HostCollectiveEpilogue { auto scalar_coord = cutlass::make_Coord(1); auto col_vector_coord = cutlass::make_Coord(M); + auto row_vector_coord = cutlass::make_Coord(N); if constexpr (IsPerRowScaleEnabled) { alpha.resize(col_vector_coord); EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); @@ -849,8 +850,10 @@ struct HostCollectiveEpilogue { scale_D.sync_device(); } - if constexpr (IsBiasEnabled) { - bias.resize(col_vector_coord); + if constexpr ( + IsRowBiasEnabled + ) { + bias.resize(IsRowBiasEnabled ? col_vector_coord : row_vector_coord); EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); bias.sync_device(); } @@ -1029,10 +1032,9 @@ struct HostCollectiveEpilogue { file << "\n\n"; } - if constexpr (IsBiasEnabled) { + if constexpr (IsRowBiasEnabled) { file << "\n\nBias = \n" << bias.host_view(); } - if constexpr (IsAuxInEnabled) { file << "\n\nAux Input = \n" << tensor_Aux.host_view(); } @@ -1090,7 +1092,9 @@ struct HostCollectiveEpilogue { fusion_args.scale_d_ptr = scale_D.device_data(); } - if constexpr (IsBiasEnabled) { + if constexpr ( + IsRowBiasEnabled + ) { fusion_args.bias_ptr = bias.device_data(); } @@ -1153,7 +1157,7 @@ struct HostCollectiveEpilogue { auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), cute::make_layout(cute::make_shape(M, N, L), stride_d)); auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), - cute::make_layout(cute::make_shape(M, cute::_1{}))); + cute::make_layout(cute::make_shape(IsRowBiasEnabled ? M : N))); auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), @@ -1171,7 +1175,8 @@ struct HostCollectiveEpilogue { decltype(Aux), decltype(Valpha), decltype(Vbeta), - ActivationFunctor + ActivationFunctor, + cutlass::plus > epilogue_params{}; epilogue_params.C = C; @@ -1186,7 +1191,9 @@ struct HostCollectiveEpilogue { epilogue_params.scale_d = scale_D.at(coord_0); } - if constexpr (IsBiasEnabled or IsDeBiasEnabled) { + if constexpr (IsRowBiasEnabled + or IsDeBiasEnabled) + { epilogue_params.Bias = Bias; } diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp new file mode 100644 index 0000000000..e2d3f2d06a --- /dev/null +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -0,0 +1,1792 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for Ptr-Array and Grouped GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorBeta { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail{ + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + std::vector stride_a_host; + std::vector stride_b_host; + + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + + // for pointer array problem_shapes.groups() is 1 + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + for(int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + Arguments arguments; + + if constexpr (IsGroupGemm) { + arguments + = + { + device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() + }; + } + else { + arguments = + { + device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + + bool passed = true; + return passed; + } +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + ElementScalar alpha; + ElementScalar beta; + + std::vector> tensors_C; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_C; + cutlass::DeviceAllocation device_tensors_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is true, beta is passed as a host scalar instead of device vector + VectorBeta disable_vector_beta = VectorBeta::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M, N); + + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + + using FusionOp = typename Gemm::EpilogueOutputOp; + static_assert(cute::is_base_of_v); + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + std::vector> tensors_C; + cutlass::DeviceAllocation device_tensors_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + std::vector> tensors_Aux; + cutlass::DeviceAllocation device_tensors_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_D; + + // References + cutlass::HostTensor reference_dbias; + std::vector> references_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is true, beta is passed as a host scalar instead of device vector + VectorBeta disable_vector_beta = VectorBeta::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + auto c_coord = cutlass::make_Coord(M, N); + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + if constexpr (IsPerRowScaleEnabled) { + alpha.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (disable_vector_beta == VectorBeta::DISABLED) { + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + else { + beta.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + } + else { + alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsBiasEnabled) { + bias.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + tensors_Aux.clear(); + references_Aux.clear(); + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + for (int32_t i = 0; i < L; ++i) { + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); + tensors_Aux[i].sync_device(); + } + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + } + + static_assert(!IsGroupGemm or (IsGroupGemm and IsAuxOutEnabled)); + + if constexpr (IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); + tensors_Aux[i].sync_device(); + } + + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + std::vector ptr_Aux_host(L); + if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); + } + device_tensors_Aux.reset(L); + device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); + } + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {}, + device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {}, + device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = device_tensors_Aux.get(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + fusion_args.beta_ptr = beta.device_data(); // if disable_vector_beta is true this is nullptr + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { + fusion_args.activation.scale = ElementCompute(1); + } + + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); + } + + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensors_Aux[batch].host_data() : references_Aux[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_Aux)); + auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsBiasEnabled or IsDeBiasEnabled) { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (disable_vector_beta == VectorBeta::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + collective_mma_inputs.initialize(problem_shapes); + collective_epilogue.initialize(problem_shapes, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) + { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + + bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); + passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << batch << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file, batch); + collective_epilogue.print_tensors(file, batch); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = max(problem_shapes.groups(), L); + + bool passed = true; + for (int32_t i = 0; i < L; ++i) { + auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); + auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + passed &= compare_reference(problem_shapes, alpha, beta, i); + } + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + + // Fail test if insufficient CUDA device + if (!sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + + if (!this->initialize(problem_shapes, alpha, beta)) { + std::cerr << "Initialization failed \n"; + return false; + } + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(problem_shapes); + + if constexpr (IsGroupGemm) { + arguments = + { + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + else { + arguments = + { + cutlass::gemm::GemmUniversalMode::kArray, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return false; + } + + // + // Run the GEMM + // + + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_shapes, alpha, beta); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } + + return passed; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + return impl_.run( + problem_shapes, alpha, beta, iterations); + } +}; + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorBeta::DISABLED); + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + int batches[] = {5, 10}; + + bool passed = true; + + for (int batch : batches) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m, n, k}); + } + + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + passed = testbed.run( + ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()}, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_size{{m, n, k, batch}}; + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n"; + return false; + } + } // k + } // n + } // m + } // batch + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm90_evt_operations.hpp b/test/unit/gemm/device/sm90_evt_operations.hpp index 5ac8e659d8..201ba72539 100644 --- a/test/unit/gemm/device/sm90_evt_operations.hpp +++ b/test/unit/gemm/device/sm90_evt_operations.hpp @@ -409,14 +409,7 @@ using Sm90LinCombPerColumnBias = Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc - Sm90RowBroadcast< - ceil_div( - EpilogueDescriptor::StagesC, - size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) - ) + 1, - typename EpilogueDescriptor::TileShape, - ElementBias - > + Sm90RowBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias> > >; diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu index cb650af177..81436adace 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu @@ -130,14 +130,10 @@ bool testEVTAuxStoreWithoutD() { aux_store_D_block.reset(m * n); Gemm gemm_op_base; - auto stride_A = cutlass::make_cute_packed_stride( - typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{})); - auto stride_B = cutlass::make_cute_packed_stride( - typename GemmKernel::StrideB{}, cute::make_shape(n, k, cute::Int<1>{})); - auto stride_C = cutlass::make_cute_packed_stride( - typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{})); - auto stride_D = cutlass::make_cute_packed_stride( - typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{})); + auto stride_A = cutlass::make_cute_packed_stride(typename GemmKernel::StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(typename GemmKernel::StrideB{}, {n, k, 1}); + auto stride_C = cutlass::make_cute_packed_stride(typename GemmKernel::StrideC{}, {m, n, 1}); + auto stride_D = cutlass::make_cute_packed_stride(typename GemmKernel::StrideD{}, {m, n, 1}); auto arguments_base = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGemm, diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu new file mode 100644 index 0000000000..b93d936865 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_group_gemm, 128x128x64_2x2x1) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestAll(1.0, 1.0); + EXPECT_TRUE(result); +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu new file mode 100644 index 0000000000..dc581acf7f --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Ptr-Array GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool result = TestAll(1.0, 1.0); + EXPECT_TRUE(result); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_ptr_array, 128x128x64_2x2x1_NoSmemEpi) { + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(TestAll(1.0, 0.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu index 66b48386e9..e447c7a295 100644 --- a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu +++ b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu @@ -114,13 +114,14 @@ test_scheduler( << size<3>(problem_shape_mnkl) << " and grid size " << grid.x << "x" << grid.y << "x" << grid.z - << " splits=" << params.splits_ + << " splits=" << params.divmod_splits_.divisor << " k_iter=" << params.divmod_tiles_per_output_tile_.divisor << " big_units_=" << params.big_units_ << " big_groups_=" << params.big_groups_ << " sk_tiles=" << params.sk_tiles_ << " sk_units=" << params.sk_units_ - << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ + << " k_tiles_per_sk_unit=" << params.divmod_k_tiles_per_sk_unit_.divisor + << " k_tiles_per_sk_big_unit=" << params.divmod_k_tiles_per_sk_big_unit_.divisor << " units_per_problem=" << params.units_per_problem_ << " groups=" << params.divmod_sk_groups_.divisor << std::endl; }; diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h index bf2d2d3d74..eeac68c0a6 100644 --- a/test/unit/gemm/device/testbed_sparse.h +++ b/test/unit/gemm/device/testbed_sparse.h @@ -128,8 +128,8 @@ struct SparseTestbed { scope_max = 2; scope_min = 0; } else if (bits_input <= 8) { - scope_max = 2; - scope_min = -2; + scope_max = 1; + scope_min = -1; } else if (bits_output == 16) { scope_max = 5; scope_min = -5; @@ -353,14 +353,25 @@ struct SparseTestbed { // typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), - tensor_E_reordered.device_ref(), + split_k_slices, {alpha, beta}, - split_k_slices + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_E_reordered.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_E_reordered.layout().stride(0) }; Gemm gemm_op; @@ -391,7 +402,7 @@ struct SparseTestbed { bool passed = this->verify(problem_size, alpha, beta); if (!passed) { - std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << ", beta: " << beta << ", m: " << problem_size.m() << ", n: " << problem_size.n() << ", k:" < 1) { - continue; - } - - if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { - continue; - } - for (auto alpha : problem_alpha) { for (auto beta : problem_beta) { - cutlass::gemm::GemmCoord problem_size(m, n, k); passed = testbed.run( diff --git a/test/unit/gemm/device/testbed_with_absmax.h b/test/unit/gemm/device/testbed_with_absmax.h index 1224a9229f..2bccba4f3c 100644 --- a/test/unit/gemm/device/testbed_with_absmax.h +++ b/test/unit/gemm/device/testbed_with_absmax.h @@ -212,6 +212,9 @@ struct TestbedWithAmax { EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_D.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), underlying_testbed.tensor_D.host_view()); + if (!passed) { + std::cout << "Comparison of D failed" << std::endl; + } if (kScaleAux) { tensor_Aux.sync_host(); @@ -219,14 +222,23 @@ struct TestbedWithAmax { EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); - passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); - passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); + if (!cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux failed" << std::endl; + } + if (!cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux absmax failed" << std::endl; + } } if (kScaleOutput) { abs_max_D.sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); - passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); + if (!cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view())) { + passed = false; + std::cout << "Comparison of D absmax failed" << std::endl; + } } EXPECT_TRUE(passed) << " mismatched reference"; @@ -417,16 +429,31 @@ struct TestbedWithAmax { auto arguments = [&]() { if constexpr (IsSparseTestbed) { return typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - underlying_testbed.tensor_A.device_ref(), - underlying_testbed.tensor_B.device_ref(), - underlying_testbed.tensor_C.device_ref(), - underlying_testbed.tensor_D.device_ref(), - underlying_testbed.tensor_E_reordered.device_ref(), - tensor_Aux.device_ref(), + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + underlying_testbed.tensor_E_reordered.device_data(), + tensor_Aux.device_data(), tensor_Vector.device_data(), - 0, // stride vector - epilogue_params + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + underlying_testbed.tensor_E_reordered.layout().stride(0), + tensor_Aux.layout().stride(0), + 0 // stride vector }; } else { @@ -522,35 +549,47 @@ bool TestAllGemmWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) int M_problems[] = {kAlignmentM, 128 + 32}; int N_problems[] = {kAlignmentN, 512 - 2 * kAlignmentN}; - int K_problems[] = {Gemm::ThreadblockShape::kK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1)}; + int K_problems[] = {Gemm::ThreadblockShape::kK * 2}; double alpha_problems[] = {1.}; double beta_problems[] = {0.}; + int split_k_slices[] = { + 1, 2 + }; bool passed = true; for (int M : M_problems) { for (int N : N_problems) { for (int K : K_problems) { - for (double alpha : alpha_problems) { - for (double beta : beta_problems) { - TestbedWithAmax testbed(scaleA, scaleB, scaleC); + for (int split_k : split_k_slices) { + if (cutlass::sizeof_bits_v <= 8 && split_k > 1) { + // Don't test split-K with FP8 output. The kernel being tested will writie partial accumulations + // for different splits to global memory in FP8, while the reference kernel will not. This leads + // to mismatches that are difficult to capture without a permissive relative equality check threshold. + continue; + } + + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + TestbedWithAmax testbed(scaleA, scaleB, scaleC); - using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementAccumulator = typename Gemm::ElementAccumulator; - passed = testbed.run( - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, - 1, - cutlass::from_real(alpha), - cutlass::from_real(beta) - ); + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); - EXPECT_TRUE(passed) - << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta << ", split_k:" << split_k; - if (!passed) { + if (!passed) { - return passed; + return passed; + } } } } diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse.cu b/test/unit/gemm/threadblock/mma_multistage_sparse.cu index 46b8a1e3fd..1625146b12 100644 --- a/test/unit/gemm/threadblock/mma_multistage_sparse.cu +++ b/test/unit/gemm/threadblock/mma_multistage_sparse.cu @@ -179,6 +179,111 @@ TEST(SM80_sparse_gemm_threadblock_congruous, //////////////////////////////////////////////////////////////////////////////// +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x32x64_32x32x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 32, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_32x256x128_32x64x128_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(32, 256, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x16x64_32x16x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 16, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 16, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + TEST(SM80_sparse_gemm_threadblock_congruous, tensor_op_128x128x64_64x64x64_16x8x32_4stage) { using ElementA = cutlass::half_t; diff --git a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu index e9a00a88ad..eb7d8023d0 100644 --- a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu +++ b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu @@ -104,7 +104,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_ using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using ElementA = int8_t; - using ElementB = cutlass::half_t;; + using ElementB = cutlass::half_t; using ElementC = float; using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< cutlass::sizeof_bits::value, 64>; diff --git a/test/unit/gemm/warp/gemm_sparse_sm80.cu b/test/unit/gemm/warp/gemm_sparse_sm80.cu index d3acf71c0d..f7f83e94ef 100644 --- a/test/unit/gemm/warp/gemm_sparse_sm80.cu +++ b/test/unit/gemm/warp/gemm_sparse_sm80.cu @@ -327,6 +327,48 @@ TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x32) //////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x32x64_32x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x16x64_32x16x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 16>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x64x128_64x32x128_16x8x32) { using Shape = cutlass::gemm::GemmShape<64, 32, 128>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; @@ -402,7 +444,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x64x128_16x8x64 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -423,7 +465,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x32x128_16x8x64 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -444,7 +486,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x64x128_16x8x64 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -465,7 +507,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x32x128_16x8x64 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -486,7 +528,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x16x128_16x8x64 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -507,7 +549,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x64x256_64x32x256_16x8x64) using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -528,7 +570,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x128x256_32x64x256_16x8x64) using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -549,7 +591,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x64x256_32x32x256_16x8x64) using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -570,7 +612,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x32x256_32x16x256_16x8x64) using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -927,7 +969,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x64x256_16x8x12 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -948,7 +990,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x32x256_16x8x12 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -969,7 +1011,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x64x256_16x8x12 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -990,7 +1032,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x32x256_16x8x12 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -1011,7 +1053,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x16x256_16x8x12 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -1032,7 +1074,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x64x512_64x32x512_16x8x128 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -1053,7 +1095,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x128x512_32x64x512_16x8x128 using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -1074,7 +1116,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x64x512_32x32x512_16x8x128) using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() @@ -1095,7 +1137,7 @@ TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x32x512_32x16x512_16x8x128) using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, - cutlass::layout::RowMajor>::Type; + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>::Type; test::gemm::warp::SparseTestbed >() diff --git a/test/unit/nvrtc/kernel/thread/contraction.hpp b/test/unit/nvrtc/kernel/thread/contraction.hpp index d2dddfe026..f90e882e4b 100644 --- a/test/unit/nvrtc/kernel/thread/contraction.hpp +++ b/test/unit/nvrtc/kernel/thread/contraction.hpp @@ -49,16 +49,16 @@ struct ContractionKernel { using ElementScalar = float; using ElementAccum = float; -using EpilogueThread = cutlass::epilogue::thread::LinearCombination; +using EpilogueThread = cutlass::epilogue::thread::LinearCombination; static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN; /// Kernel config -typedef int64_t stride_type; +typedef int64_t stride_type; typedef int32_t extent_type; static constexpr const stride_type* stride_null = nullptr; @@ -117,7 +117,7 @@ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; using Kernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, + ProblemShape, CollectiveOp, CollectiveEpilogue>; diff --git a/test/unit/transform/CMakeLists.txt b/test/unit/transform/CMakeLists.txt index 0d768258b5..4912eca2c3 100644 --- a/test/unit/transform/CMakeLists.txt +++ b/test/unit/transform/CMakeLists.txt @@ -27,15 +27,18 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. add_subdirectory(threadblock) +add_subdirectory(kernel) add_custom_target( cutlass_test_unit_transform DEPENDS cutlass_test_unit_transform_threadblock - ) + cutlass_test_unit_transform_filter_format +) add_custom_target( test_unit_transform DEPENDS test_unit_transform_threadblock - ) + test_unit_transform_kernel +) diff --git a/test/unit/transform/kernel/CMakeLists.txt b/test/unit/transform/kernel/CMakeLists.txt new file mode 100644 index 0000000000..d337b31ed9 --- /dev/null +++ b/test/unit/transform/kernel/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_transform_filter_format + filter_format_transformer.cu +) diff --git a/test/unit/transform/kernel/filter_format_transformer.cu b/test/unit/transform/kernel/filter_format_transformer.cu new file mode 100644 index 0000000000..ce489afd06 --- /dev/null +++ b/test/unit/transform/kernel/filter_format_transformer.cu @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests cutlass::transform::kernel::ConvFilterFormatTransformer +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/kernel/filter_format_transformer.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" + +#include "thrust/universal_vector.h" +#include "thrust/host_vector.h" +#include "thrust/device_vector.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto verify_ckrs_to_crsk(thrust::host_vector const &S, thrust::host_vector const &D, Shape_S shape_s) { + using namespace cute; + + int32_t errors = 0; + int32_t const kErrorLimit = 10; + + if (S.size() != D.size()) { + return false; + } + + auto shape_d = select<2, 0, 1, 3>(shape_s); + + for (int i = 0; i < (int)S.size(); ++i) { + auto [s, r, k, c] = idx2crd(i, shape_s); + auto d_idx = crd2idx(make_coord(k, s, r, c), shape_d); + + if (S[i] != D[d_idx]) { + std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << d_idx << "]: " << D[d_idx] << std::endl; + + if (++errors >= kErrorLimit) { + std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl; + return false; + } + } + } + + return errors == 0; +} + +template +auto verify_ckrs_to_krsc(thrust::host_vector const &S, thrust::host_vector const &D, Shape_S shape_s) { + using namespace cute; + + int32_t errors = 0; + int32_t const kErrorLimit = 10; + + if (S.size() != D.size()) { + return false; + } + + auto shape_d = select<3, 0, 1, 2>(shape_s); + + for (int i = 0; i < (int)S.size(); ++i) { + auto [s, r, k, c] = idx2crd(i, shape_s); + auto d_idx = crd2idx(make_coord(c, s, r, k), shape_d); + + if (S[i] != D[d_idx]) { + std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << d_idx << "]: " << D[d_idx] << std::endl; + + if (++errors >= kErrorLimit) { + std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl; + return false; + } + } + } + + return errors == 0; +} + +template +bool transform_test() { + using namespace cute; + + using TransformKernel = cutlass::transform::kernel::ConvFilterFormatTransformer; + using Transform = cutlass::transform::device::TransformUniversalAdapter; + + auto s = 3; + auto r = 3; + auto k = 64 + Alignment / (int)(sizeof(Element)); + auto c = 64 + Alignment / (int)(sizeof(Element)); + + thrust::host_vector h_S(s * r * k * c); + thrust::host_vector h_D(s * r * k * c); + + // + // Initialize + // + + for (int i = 0; i < (int)h_S.size(); ++i) { + h_S[i] = static_cast(i); + h_D[i] = Element{}; + } + + thrust::device_vector d_S = h_S; + thrust::device_vector d_D = h_D; + + Transform transform_op; + + const void* src_ptr = static_cast(d_S.data().get()); + void* dst_ptr = static_cast(d_D.data().get()); + + typename TransformKernel::FilterExtent filter_extent; + filter_extent[0] = k; + filter_extent[1] = r; + filter_extent[2] = s; + filter_extent[3] = c; + + auto args = typename Transform::Arguments { + src_ptr, + dst_ptr, + filter_extent + }; + + cutlass::Status status = cutlass::Status::kInvalid; + + size_t workspace_size = Transform::get_workspace_size(args); + thrust::universal_vector workspace(workspace_size); + + status = transform_op.initialize(args, workspace.data().get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return false; + } + + status = transform_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: " + << cudaGetErrorString(result); + + // Verification + h_D = d_D; + auto tensor_shape_S = make_shape(s, r, k, c); + + bool passed = false; + if constexpr(DstFormat == cutlass::transform::kernel::FilterFormat::KTRSC) { + // KTRSC + passed = verify_ckrs_to_krsc(h_S, h_D, tensor_shape_S); + } + else if constexpr(DstFormat == cutlass::transform::kernel::FilterFormat::CTRSK) { + // CTRSK; + passed = verify_ckrs_to_crsk(h_S, h_D, tensor_shape_S); + } + + return passed; +} + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + +TEST(Transform_kernel_ConvFilterFormatTransformer, ckrs_to_crsk) { + bool passed = true; + + // fp16 kernel with alignment bytes from 16 to 2. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + // fp8 kernel with alignment bytes from 16 to 1. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + // int8 kernel with alignment bytes from 16 to 1. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + // fp32 kernel with alignment bytes from 16 to 4. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + EXPECT_TRUE(passed); +} + +// CKRS -> KRSC +TEST(Transform_kernel_ConvFilterFormatTransformer, ckrs_to_krsc) { + bool passed = true; + + // fp16 kernel with alignment bytes from 16 to 2. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + // fp8 kernel with alignment bytes from 16 to 1. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + // int8 kernel with alignment bytes from 16 to 1. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + // fp32 kernel with alignment bytes from 16 to 4. + passed &= transform_test(); + passed &= transform_test(); + passed &= transform_test(); + + EXPECT_TRUE(passed); +} + +#endif diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 60a6cca599..f8a28fe6b9 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -278,6 +278,7 @@ execute_process( --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" --kernels "${CUTLASS_LIBRARY_KERNELS}" --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}" + --exclude-kernels "${CUTLASS_LIBRARY_EXCLUDE_KERNELS}" --kernel-filter-file "${KERNEL_FILTER_FILE}" --selected-kernel-list "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}" --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index d2432f959c..c609367cb2 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -292,6 +292,7 @@ struct GemmUniversalArguments { int sm_count{0}; library::RasterOrder raster_order{}; + int swizzle_size{1}; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/conv_operation_3x.hpp b/tools/library/src/conv_operation_3x.hpp index 4d84765bf1..15fb330300 100644 --- a/tools/library/src/conv_operation_3x.hpp +++ b/tools/library/src/conv_operation_3x.hpp @@ -595,12 +595,10 @@ class ConvOperation3x : public Operation { const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); - const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); // cutlass::library::Conv2dConfiguration has no member stride_d. // The code below imitates the testbed, // which just sets D's strides to C's strides. - const TensorStride stride_D = stride_C; const int num_groups = config.problem_size.groups; if (num_groups != 1) { @@ -773,9 +771,7 @@ class ConvOperation3x : public Operation { const TensorStride stride_A = coord_to_array_strides(input_stride_a); const TensorStride stride_B = coord_to_array_strides(input_stride_b); - const TensorStride stride_C = coord_to_array_strides(input_stride_c); - const TensorStride stride_D = stride_C; const int num_groups = config.problem_size.groups; if (num_groups != 1) { CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); @@ -852,6 +848,12 @@ class ConvOperation3x : public Operation { std::cerr << "ConvOperation3x::update_operator_arguments_from_arguments\n"; #endif + auto status = UpdateFusionArgs::update_( + out_args.epilogue.thread, in_args); + if (status != Status::kSuccess) { + return status; + } + out_args.mainloop.ptr_A = reinterpret_cast(in_args.A); out_args.mainloop.ptr_B = reinterpret_cast(in_args.B); diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index e50f3a1bc8..4f743f74b7 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -250,6 +250,10 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { /* Query device SM count to pass onto the kernel as an argument, where needed */ operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + if constexpr (!std::is_const_v) { using Enum_t = decltype(operator_args.scheduler.raster_order); switch (arguments->raster_order) { diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h index d5fa06e74c..ab924b5f01 100644 --- a/tools/library/src/reference/conv_reference_operation.h +++ b/tools/library/src/reference/conv_reference_operation.h @@ -489,7 +489,7 @@ template < typename InnerProductOp_ = multiply_add > void make_conv_fprop(Manifest &manifest) { - +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) manifest.append(new ConvReferenceOperation< Provider::kReferenceHost, cutlass::conv::Operator::kFprop, @@ -515,6 +515,7 @@ void make_conv_fprop(Manifest &manifest) { ConvertOp_, InnerProductOp_ >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) } /// Constructs Dgrad and Wgrad reference operators. @@ -532,7 +533,7 @@ template < typename InnerProductOp_ = multiply_add > void make_conv_backwards(Manifest &manifest) { - +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) manifest.append(new ConvReferenceOperation< Provider::kReferenceHost, cutlass::conv::Operator::kDgrad, @@ -584,6 +585,7 @@ void make_conv_backwards(Manifest &manifest) { ConvertOp_, InnerProductOp_ >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) } /// Six operators for the price of one. diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h index f6b5d911dd..fd58d4f0ac 100644 --- a/tools/library/src/reference/gemm_reference_operation.h +++ b/tools/library/src/reference/gemm_reference_operation.h @@ -293,7 +293,7 @@ template < typename InnerProductOp_ = multiply_add > void make_gemm(Manifest &manifest) { - +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) manifest.append(new GemmReferenceOperation< Provider::kReferenceHost, ElementA_, LayoutA_, TransformA, @@ -317,6 +317,7 @@ void make_gemm(Manifest &manifest) { ConvertOp_, InnerProductOp_ >); +#endif } /// Helper to create NN, NT, TN, and TT GEMM layouts. diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 7408701801..8e1292f986 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -84,6 +84,8 @@ class GemmOperationProfiler : public OperationProfiler { int batch_count{1}; cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + // gemm with parallel interleaved reduction // gemm epilogue (alpha, beta) = (1.0, 0.0) // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) diff --git a/tools/profiler/include/cutlass/profiler/options.h b/tools/profiler/include/cutlass/profiler/options.h index a05357e5ca..e945d17344 100644 --- a/tools/profiler/include/cutlass/profiler/options.h +++ b/tools/profiler/include/cutlass/profiler/options.h @@ -72,7 +72,7 @@ class Options { // Methods // - Library(CommandLine const &cmdline); + explicit Library(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -94,7 +94,7 @@ class Options { // Methods // - Device(CommandLine const &cmdline); + explicit Device(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -128,7 +128,7 @@ class Options { // Methods // - Initialization(CommandLine const &cmdline); + explicit Initialization(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -170,7 +170,7 @@ class Options { // Methods // - Verification(CommandLine const &cmdline); + explicit Verification(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -186,22 +186,26 @@ class Options { struct Profiling { /// Number of workspaces to rotate through to avoid cache-resident working sets - int workspace_count; + int workspace_count{0}; /// Number of iterations to warmup each kernel prior to profiling - int warmup_iterations; + int warmup_iterations{10}; /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration - int iterations; + int iterations{100}; /// Number of ms to sleep between profiling periods (ms) - int sleep_duration; + int sleep_duration{50}; /// If true, profiling is actually conducted. - bool enabled; + bool enabled{true}; /// If true, profiling returns an error code if no kernels are found to match the filters. - bool error_on_no_match = false; + bool error_on_no_match{false}; + + /// If true, profiling returns an error code if no kernel are profiled + // Sometimes the kernel matches but failed to profile (e.g. can_implement() error) + bool error_if_nothing_is_profiled{false}; /// List of providers of each functionality to be profiled ProviderVector providers; @@ -210,7 +214,7 @@ class Options { // Methods // - Profiling(CommandLine const &cmdline); + explicit Profiling(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -255,7 +259,7 @@ class Options { // Methods // - Report(CommandLine const &cmdline); + explicit Report(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -277,7 +281,7 @@ class Options { // Methods // - About(CommandLine const &cmdline); + explicit About(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out, int indent = 0) const; @@ -320,7 +324,7 @@ class Options { public: - Options(CommandLine const &cmdline); + explicit Options(CommandLine const &cmdline); void print_usage(std::ostream &out) const; void print_options(std::ostream &out) const; diff --git a/tools/profiler/include/cutlass/profiler/problem_space.h b/tools/profiler/include/cutlass/profiler/problem_space.h index 23143d50dc..00391c9b13 100644 --- a/tools/profiler/include/cutlass/profiler/problem_space.h +++ b/tools/profiler/include/cutlass/profiler/problem_space.h @@ -256,7 +256,7 @@ struct ScalarArgument : public KernelArgument { virtual std::ostream &print(std::ostream &out) const; }; - using ValueCollection = std::vector; + using ValueCollection = std::vector; /// Abstract base class to iterate over values within arguments struct ScalarValueIterator : public KernelArgument::ValueIterator { @@ -271,7 +271,7 @@ struct ScalarArgument : public KernelArgument { // Methods // - ScalarValueIterator(ScalarArgument const *argument = nullptr); + explicit ScalarValueIterator(ScalarArgument const *argument = nullptr); virtual void operator++(); virtual bool operator==(ValueIterator const &it) const; @@ -292,7 +292,7 @@ struct ScalarArgument : public KernelArgument { // /// Default ctor - ScalarArgument( + explicit ScalarArgument( ArgumentDescription const *description ): KernelArgument(description) { } @@ -632,7 +632,7 @@ struct TensorArgument : public KernelArgument { // Methods // - TensorValueIterator(TensorArgument const *argument_); + explicit TensorValueIterator(TensorArgument const *argument_); virtual void operator++(); virtual bool operator==(ValueIterator const &it) const; @@ -649,7 +649,7 @@ struct TensorArgument : public KernelArgument { // /// Default ctor - TensorArgument( + explicit TensorArgument( ArgumentDescription const *description ): KernelArgument(description) { } @@ -690,7 +690,7 @@ struct EnumeratedTypeArgument : public KernelArgument { virtual std::ostream &print(std::ostream &out) const; }; - using ValueCollection = std::vector; + using ValueCollection = std::vector; /// Abstract base class to iterate over values within arguments struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { @@ -705,7 +705,7 @@ struct EnumeratedTypeArgument : public KernelArgument { // Methods // - EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); + explicit EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); virtual void operator++(); virtual bool operator==(ValueIterator const &it) const; @@ -725,7 +725,7 @@ struct EnumeratedTypeArgument : public KernelArgument { // /// Default ctor - EnumeratedTypeArgument(ArgumentDescription const *description): + explicit EnumeratedTypeArgument(ArgumentDescription const *description): KernelArgument(description) {} virtual bool not_null() const { @@ -819,7 +819,7 @@ class ProblemSpace { // /// Default ctor - ProblemSpace() {} + ProblemSpace() = default; /// Constructs a problem space from a vector of arguments. This vector must outlive /// the ProblemSpace object, which stores pointers to objects within the diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index 8889bb107b..ebdaf66e2e 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -143,16 +143,19 @@ void CutlassProfiler::enumerate_() { /// Profiles all operations int CutlassProfiler::profile_() { - int result = 0; + // Keep track of all device memory tensor in map DeviceContext device_context; - // For all profilers + + int result = 0; + // For all profilers (e.g. gemm/sparse_gemm/conv2d...) for (auto & profiler : operation_profilers_) { if (options_.operation_kind == library::OperationKind::kInvalid || - options_.operation_kind == profiler->kind()) { + options_.operation_kind == profiler->kind()) { result = profiler->profile_all(options_, library::Singleton::get().manifest, device_context); + // If some profile failed, terminate immediately if (result) { return result; } diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index daee075656..39628d6bfa 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -75,6 +75,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, {ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"}, + {ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"}, }, { library::Provider::kCUBLAS} ) { @@ -191,6 +192,14 @@ Status GemmOperationProfiler::GemmProblem::parse( this->mode = library::GemmUniversalMode::kBatched; } + if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) { + // default value + this->swizzle_size = 1; + if (this->swizzle_size <= 0) { + return Status::kErrorInvalidProblem; + } + } + if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) { // default value this->raster_order = library::RasterOrder::kHeuristic; @@ -329,6 +338,8 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "split_k_slices", problem_space, split_k_slices); set_argument(result, "batch_count", problem_space, batch_count); set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); + set_argument(result, "swizzle_size", problem_space, swizzle_size); + set_argument(result, "alpha", problem_space, library::lexical_cast(alpha, operation_desc.element_epilogue)); @@ -383,6 +394,7 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_.arguments.alpha = problem_.alpha.data(); gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size; gemm_workspace_.arguments.raster_order = problem_.raster_order; // initialize reduction operation for parallel splitKMode if (problem_.split_k_mode == library::SplitKMode::kParallel) { diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index 71d1e55251..daf73f4bfe 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -345,6 +345,7 @@ int OperationProfiler::profile_all( // For each operation in manifest int matched_operation_count = 0; + int profiled_operation_count = 0; for (auto const& operation_ptr : manifest) { library::Operation const *operation = operation_ptr.get(); @@ -434,7 +435,7 @@ int OperationProfiler::profile_all( // If there was an internal error, consume the CUDA error and move to the next operation. (void)cudaGetLastError(); - report.append_results(results_); + report.append_result(model_result_); continue; } else if (status != Status::kSuccess) { @@ -522,25 +523,42 @@ int OperationProfiler::profile_all( operation, problem_space, problem); + + // Count op as profiled, even it failed to profile + profiled_operation_count++; } report.append_results(results_); results_.clear(); - } + } // if op satisfied compute capacity if (!continue_profiling) { + // break out of `for op in manifest` loop and move to next problem + // `for each problem in problem space` conditional check on not continue profiling break; } - } + } // for op in manifest // If we did not find any kernels that match our filters and error_on_no_match was set, report an error if (options.profiling.error_on_no_match && matched_operation_count <= 0) { #if !NDEBUG - std::cout << "Error: No matching kernels found with kernel selection filters [--error_on_no_match]" << std::endl; + std::cerr << "Error: No matching kernels found with kernel selection filters [--error_on_no_match]" << std::endl; #endif - retval = 1; + retval |= 1; + // Stop profiling on error no match + continue_profiling = false; } - } + + if (options.profiling.error_if_nothing_is_profiled && options.profiling.enabled && profiled_operation_count <= 0) { + #if !NDEBUG + std::cerr << "Error: No kernels profiled found with kernel selection filters [--error_if_nothing_is_profiled]" << std::endl; + #endif + retval |= 1; + // Stop profiling on error no match + continue_profiling = false; + } + + } // for each problem in problem space return retval; } diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index e98dd27bc3..e2259aa008 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -726,11 +726,13 @@ Options::Options(cutlass::CommandLine const &cmdline): else if (cmdline.check_cmd_line_flag("kernels")) { cmdline.get_cmd_line_arguments("kernels", operation_names); profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match"); + profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled"); } if (cmdline.check_cmd_line_flag("ignore-kernels")) { cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names); profiling.error_on_no_match = cmdline.check_cmd_line_flag("error-on-no-match"); + profiling.error_if_nothing_is_profiled = cmdline.check_cmd_line_flag("error-if-nothing-is-profiled"); } // Prevent launches on the device for anything other than CUTLASS operation diff --git a/tools/profiler/src/problem_space.cpp b/tools/profiler/src/problem_space.cpp index f4b1c9b0ba..bd76bdbb3a 100644 --- a/tools/profiler/src/problem_space.cpp +++ b/tools/profiler/src/problem_space.cpp @@ -395,10 +395,6 @@ std::unique_ptr EnumeratedTypeArgument::end() con ////////////////////////////////////////////////////////////////////////////////////////////////// -ProblemSpace::Iterator::Iterator() { - -} - ProblemSpace::Iterator::Iterator(ProblemSpace const &problem_space) { for (auto const & arg_ptr : problem_space.arguments) { construct_(arg_ptr.get()); diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index 6326715aee..b859153679 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -104,23 +104,15 @@ class HostTensor { /// Constant reference to element in tensor using ConstReference = typename ConstTensorRef::Reference; - /// Note: Below is used to handle packing of subbyte elements - /// kBitsStoredVec : The bits of store vec that could be divisiable by the element - /// kElementsPerStoredVec : The number of elements could be stored in per store vec - /// kNumStoragePerStoredVec : How much storage(i.e. sizeof(element storage)) the store vec needs to consume. - /// Usually the element storage of subbyte is uint8_t. - /// Example - /// int2: kBitsStoredVec = 8; kElementsPerStoredVec = 4; kNumStoragePerStoredVec = 1 uint8_t; - /// int4: kBitsStoredVec = 8; kElementsPerStoredVec = 2; kNumStoragePerStoredVec = 1 uint8_t; - static constexpr int kBitsStoredVec = (sizeof_bits::value < 8) ? cutlass::lcm(sizeof_bits::value, 8) : sizeof_bits::value; - static constexpr int kElementsPerStoredVec = kBitsStoredVec / sizeof_bits::value; - static constexpr int kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8); - - static_assert(kBitsStoredVec != 0, "kBitsStoredVec can not be zero"); - static_assert(kElementsPerStoredVec != 0, "kElementsPerStoredVec can not be zero"); - static_assert(kNumStoragePerStoredVec != 0, "kNumStoragePerStoredVec can not be zero"); - - private: +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit; // // Data members @@ -133,13 +125,17 @@ class HostTensor { Layout layout_; /// Host-side memory allocation - /// avoid the std::vector specialization - std::vector, uint8_t, Element>> host_; + std::vector host_; /// Device-side memory - device_memory::allocation device_; + device_memory::allocation device_; - public: + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: // // Device and Host Methods // @@ -185,15 +181,15 @@ class HostTensor { device_.reset(); host_.clear(); - count = (count + kElementsPerStoredVec - 1) / kElementsPerStoredVec * kNumStoragePerStoredVec; - host_.resize(count); + size_t count_container = count_to_container_storage_unit_count(count); + host_.resize(count_container); // Allocate memory - Element* device_memory = nullptr; + StorageUnit* device_memory = nullptr; if (device_backed_) { - device_memory = device_memory::allocate(count); + device_memory = device_memory::allocate(count_container); } - device_.reset(device_memory, device_backed_ ? count : 0); + device_.reset(device_memory, device_backed_ ? count_container : 0); } /// Updates the extent and layout of the HostTensor. Allocates memory according to the new @@ -229,8 +225,9 @@ class HostTensor { layout_ = layout; LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); - if (static_cast(new_size) > host_.size()) { + if (static_cast(new_size_container) > host_.size()) { reserve(new_size, device_backed_); } } @@ -244,14 +241,14 @@ class HostTensor { resize(extent, Layout::packed(extent), device_backed_); } - /// Returns the number of elements stored in the host tensor + /// Returns the logical number of elements stored in the host tensor size_t size() const { - return host_.size() / kNumStoragePerStoredVec * kElementsPerStoredVec; + return layout_.capacity(extent_); } - /// Returns the logical capacity based on extent and layout. May differ from size(). + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). LongIndex capacity() const { - return layout_.capacity(extent_); + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; } /// Gets pointer to host data @@ -277,10 +274,10 @@ class HostTensor { } /// Gets pointer to device data - Element * device_data() { return device_.get(); } + Element * device_data() { return reinterpret_cast(device_.get()); } /// Gets pointer to device data - Element const * device_data() const { return device_.get(); } + Element const * device_data() const { return reinterpret_cast(device_.get()); } /// Gets pointer to device data with a pointer offset Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } @@ -389,7 +386,7 @@ class HostTensor { void sync_host() { if (device_backed()) { device_memory::copy_to_host( - host_data(), device_data(), size()); + host_.data(), device_.get(), device_.size()); } } @@ -397,7 +394,7 @@ class HostTensor { void sync_device() { if (device_backed()) { device_memory::copy_to_device( - device_data(), host_data(), size()); + device_.get(), host_.data(), host_.capacity()); } } @@ -412,8 +409,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_to_host( - host_data(), ptr_device, count); + host_.data(), reinterpret_cast(ptr_device), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -427,8 +425,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_device_to_device( - device_data(), ptr_device, count); + device_.get(), reinterpret_cast(ptr_device), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -442,8 +441,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_to_device( - device_data(), ptr_host, count); + device_.get(), reinterpret_cast(ptr_host), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -457,8 +457,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_host_to_host( - host_data(), ptr_host, count); + host_.data(), reinterpret_cast(ptr_host), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -472,8 +473,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_to_host( - ptr_host, device_data(), count); + reinterpret_cast(ptr_host), device_.get(), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -487,8 +489,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_device_to_device( - ptr_device, device_data(), count); + reinterpret_cast(ptr_device), device_.get(), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -502,8 +505,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_to_device( - ptr_device, host_data(), count); + reinterpret_cast(ptr_device), host_.data(), container_count); } /// Copy data from a caller-supplied device pointer into host memory. @@ -517,8 +521,9 @@ class HostTensor { else { count = __NV_STD_MIN(capacity(), count); } + size_t container_count = count_to_container_storage_unit_count(count); device_memory::copy_host_to_host( - ptr_host, host_data(), count); + reinterpret_cast(ptr_host), host_.data(), container_count); } }; diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index e0f2ec0b56..f5b5e36765 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -458,6 +458,7 @@ make_cute_packed_stride( // Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) template +CUTLASS_HOST_DEVICE cute::Stride, IntT, IntT>, cute::Int<0>> make_cute_packed_stride( cute::Stride, IntT, IntT>, cute::Int<0>> s, @@ -497,6 +498,71 @@ make_cute_packed_stride( return s_copy; } + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 05b877a235..1230863735 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -443,10 +443,10 @@ struct RandomUniformFunc { int int_scale_ = -1 ): seed(seed_), - range(static_cast(max_ - min)), + range(static_cast(max_) - static_cast(min)), max(static_cast(max_)), int_scale(int_scale_) { - + float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); } diff --git a/tools/util/include/cutlass/util/reference/host/conv.hpp b/tools/util/include/cutlass/util/reference/host/conv.hpp index 202091d95e..b5beb2a6d6 100644 --- a/tools/util/include/cutlass/util/reference/host/conv.hpp +++ b/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -125,7 +125,8 @@ template< class ShapePadding, class StrideTraversal, class ShapeDilation, - class EpilogueFusionParams> + class EpilogueFusionParams +> struct ConvReferenceImpl { using ElementAcc = typename EpilogueFusionParams::ElementAcc; using ElementC = typename EpilogueFusionParams::ElementC; @@ -145,7 +146,6 @@ struct ConvReferenceImpl { NumericConverter output_converter; EpilogueFusionParams& epi_fusion_params_; - TensorA const& tensor_a_; TensorB const& tensor_b_; TensorC const& tensor_c_; @@ -174,7 +174,8 @@ struct ConvReferenceImpl { padding_(padding), tstride_(tstride), dilation_(dilation), - epi_fusion_params_(epi_fusion_params) { + epi_fusion_params_(epi_fusion_params) + { static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); } @@ -211,7 +212,9 @@ struct ConvReferenceImpl { for (int32_t c = 0; c < C; ++c) { int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); if (detail::is_activation_in_bounds(tensor_a_, n, w, c)) { - accumulator += ElementAcc(tensor_a_(c, w, n) * tensor_b_(c, s, k)); + auto a = tensor_a_(c, w, n); + auto b = tensor_b_(c, s, k); + accumulator += ElementAcc(a * b); } } } @@ -256,7 +259,9 @@ struct ConvReferenceImpl { int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c)) { - accumulator += ElementAcc(tensor_a_(c, w, h, n) * tensor_b_(c, s, r, k)); + auto a = tensor_a_(c, w, h, n); + auto b = tensor_b_(c, s, r, k); + accumulator += ElementAcc(a * b); } } } @@ -308,7 +313,9 @@ struct ConvReferenceImpl { int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c)) { - accumulator += ElementAcc(tensor_a_(c, w, h, d, n) * tensor_b_(c, s, r, t, k)); + auto a = tensor_a_(c, w, h, d, n); + auto b = tensor_b_(c, s, r, t, k); + accumulator += ElementAcc(a * b); } } } @@ -516,9 +523,12 @@ struct ConvReferenceImpl { // Specialization for 1D wgrad kernel void wgrad_reference(cute::Int<1> spatial_dims) { - int32_t N = size<2>(tensor_a_); - int32_t Q = size<1>(tensor_a_); - int32_t K = size<0>(tensor_a_); + int32_t N = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); int32_t S = size<1>(tensor_d_); int32_t C = size<0>(tensor_d_); @@ -536,8 +546,14 @@ struct ConvReferenceImpl { for (int32_t n = 0; n < N; ++n) { for (int32_t q = 0; q < Q; ++q) { int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); - if (detail::is_activation_in_bounds(tensor_b_, n, w, c)) { - accumulator += ElementAcc(tensor_b_(c, w, n) * tensor_a_(k, q, n)); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, w, c); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, n); + auto xformed_act = + tensor_a_(k, q, n); + accumulator += ElementAcc(act * xformed_act); } } } @@ -555,10 +571,14 @@ struct ConvReferenceImpl { // Specialization for 2D wgrad kernel void wgrad_reference(cute::Int<2> spatial_dims) { - int32_t N = size<3>(tensor_a_); - int32_t P = size<2>(tensor_a_); - int32_t Q = size<1>(tensor_a_); - int32_t K = size<0>(tensor_a_); + int32_t N = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); int32_t R = size<2>(tensor_d_); int32_t S = size<1>(tensor_d_); int32_t C = size<0>(tensor_d_); @@ -580,8 +600,14 @@ struct ConvReferenceImpl { for (int32_t q = 0; q < Q; ++q) { int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); - if (detail::is_activation_in_bounds(tensor_b_, n, h, w, c)) { - accumulator += ElementAcc(tensor_b_(c, w, h, n) * tensor_a_(k, q, p, n)); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, h, w, c); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, n); + auto xformed_act = + tensor_a_(k, q, p, n); + accumulator += ElementAcc(act * xformed_act); } } } @@ -601,11 +627,16 @@ struct ConvReferenceImpl { // Specialization for 3D wgrad kernel void wgrad_reference(cute::Int<3> spatial_dims) { - int32_t N = size<4>(tensor_a_); - int32_t Z = size<3>(tensor_a_); - int32_t P = size<2>(tensor_a_); - int32_t Q = size<1>(tensor_a_); - int32_t K = size<0>(tensor_a_); + int32_t N = + size<4>(tensor_a_); + int32_t Z = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); int32_t T = size<3>(tensor_d_); int32_t R = size<2>(tensor_d_); int32_t S = size<1>(tensor_d_); @@ -631,8 +662,14 @@ struct ConvReferenceImpl { int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); - if (detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c)) { - accumulator += ElementAcc(tensor_b_(c, w, h, d, n) * tensor_a_(k, q, p, z, n)); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, d, n); + auto xformed_act = + tensor_a_(k, q, p, z, n); + accumulator += ElementAcc(act * xformed_act); } } } diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 84aa93634e..9508dc49bf 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -82,7 +82,6 @@ struct GettMainloopParams { }; ///////////////////////////////////////////////////////////////////////////////////////////////// - template< class ElementScalar_, class ElementScalingFactor_, @@ -117,7 +116,6 @@ struct GettEpilogueParams { using EngineD = typename TensorD::engine_type; using LayoutD = typename TensorD::layout_type; static constexpr bool PerColumnBias = PerColumnBias_; - ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); @@ -184,6 +182,8 @@ void gett_mainloop( static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; using ElementA = typename ElementTraits::type; using ElementB = typename ElementTraits::type; @@ -254,6 +254,8 @@ void gett_epilogue( static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + using cute::raw_pointer_cast; + using ElementCompute = typename EpilogueParams::ElementCompute; using ElementC = typename EpilogueParams::TensorC::value_type; using ElementD = typename EpilogueParams::TensorD::value_type; @@ -265,7 +267,6 @@ void gett_epilogue( using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; constexpr bool PerColBias = EpilogueParams::PerColumnBias; - constexpr bool IsScalingAndAmaxOutputNeeded = cute::is_same_v or cute::is_same_v; @@ -300,7 +301,7 @@ void gett_epilogue( // Output related converter NumericConverter destination_converter; - NumericConverter aux_destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; NumericConverter dBias_converter; // Epilogue operations @@ -417,6 +418,7 @@ void gett_epilogue( } } } + #if defined(_OPENMP) #pragma omp critical(Abs_Max_Data_Update) #endif